A shape-typed GPU tensor library in Java using ArrayFire

2023-12-14 / 13 min read

I’ve been coding for over a decade, and one preference I cannot shake is that I hate writing code without strong compile-time types. I avoided Python as much as I could (or, when I could, used mypy aggressively). When I first started playing around with ML frameworks many years ago I accepted the rationale for Python - you want to move quickly, experiment, try new things, and not get bogged down by types.

But then came my first attempt to build a model that wasn’t from a tutorial. Unsurprisingly, finding it hard at first to think about N-dimensional vectors, I spent a significant amount of time running my code, getting a shape mismatch error from TensorFlow, and trying to work out what the hell I’d done wrong. This problem never fully went away.

It also felt solvable. It would be great if I got feedback earlier in my development that I was making a mistake like this. I slept on this idea for a while, until a few things happened that made me decide to take a shot at it.

  1. I discovered ArrayFire - a C++ tensor library with GPU support (CUDA + OpenCL), kernel fusion and a clean C API
  2. Java introduced the Foreign Linker and Memory APIs, which meant creating Java bindings for a C library would be faster and much more pleasant to build and maintain

With a long-term plan to shift my career towards ML/AI, I wanted to learn across the whole stack. ArrayFire lacked up-to-date Java bindings, so I decided to see what I could do with a little twist.

Shape checking matmul #

Given two matrices:

  • A with dimensions X x Y
  • B with dimensions Y x Z

AB will have the resulting shape X x Z. How do we infer the resulting shape?

X, Y, Z are all integers, and with the possible exception of TypeScript, no language offers a performant type system that can do arithmetic on them. So instead, we use type variables to represent the dimension values. The framework provides a set of standard dimension size type variables A...Z.

var x = af.x(1);
var y = af.y(2);
var z = af.z(3);
var left = af.create(new float[]{1, 2})
		.reshape(x, y); // [X, Y, U, U]
var right = af.create(new float[]{1, 2, 3, 4, 5, 6})
		.reshape(y, z); // [Y, Z, U, U]
var result = af.matmul(left, right); // [X, Z, U, U]

We create two tensors and reshape them, giving each dimension an explicit type variable we have defined, and matmul can tell us what the resulting shape is at compile time.
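
Under the hood this is just ordinary Java generic inference. As a rough sketch of how (simplified, and not the library’s exact declaration - the bound names DataType and Num here are stand-ins), the matmul signature shares the inner dimension type variable between both arguments:

// Simplified sketch: the inner dimension D1 must unify across both arguments,
// and the result keeps D0 from the left and OD1 from the right.
<T extends DataType, D0 extends Num, D1 extends Num, OD1 extends Num, D2 extends Num, D3 extends Num>
Array<T, D0, OD1, D2, D3> matmul(Array<T, D0, D1, D2, D3> left, Array<T, D1, OD1, D2, D3> right);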

If your shapes don’t fit (say we try to multiply a Y x X and a Y x Z matrix), then you will get a compilation error:

// Changing last line to:
var result = af.matmul(left.transpose(), right);
// Results in:
// required: Array<T,D0,D1,D2,D3>,Array<T,D1,OD1,D2,D3>
// found:    Array<F32,Y,X,U,U>,Array<F32,Y,Z,U,U>
// reason: inference variable D1 has incompatible equality constraints Y,X

*U is a special dimension type variable: it stands for “Unit” and is always of size 1. Here it fills the unused dimensions, as these are only two-dimensional arrays out of the possible four that ArrayFire supports.*

Reducing over different dimensions #

Reducing a tensor (sum, mean, median, min, max, etc.) can operate across different dimensions, and whichever dimension it operates on is collapsed to size 1. We can encode this fairly cleanly with types:

var data = af.range(F32, 16).reshape(2, 2, 2, 2); // [N, N, N, N]
var mean0 = af.mean(data, D0); // [U, N, N, N]
var mean1 = af.mean(data, D1); // [N, U, N, N]
var mean2 = af.mean(data, D2); // [N, N, U, N]
var mean3 = af.mean(data, D3); // [N, N, N, U]

N is another special dimension type variable, which you get when you don’t give a dimension an explicit type variable during a reshape, as in the example just above. Methods in the library that can’t determine the resulting shape (e.g. convolution) will also return dimensions with type N.
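
For comparison, here is a quick sketch (using only calls already shown in this post) of the same reshape with and without explicit type variables:

// With explicit type variables each dimension is tracked individually;
// with plain integers they all fall back to N.
var typed = af.range(F32, 4).reshape(af.a(2), af.b(2)); // [A, B, U, U]
var untyped = af.range(F32, 4).reshape(2, 2);           // [N, N, U, U]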

SVD #

u, s, vt, and back again, no problem!

var a = af.a(2);
var b = af.b(3);
var matrix = af.create(F32, new float[]{1, 2, 3, 4, 5, 6}).reshape(a, b);
var svd = af.svd(matrix);
var u = svd.u(); // [A, A, U, U]
var s = svd.s(); // [A, U, U, U] (Non zero diagonal elements)
var vt = svd.vt(); // [B, B, U, U]

We can recreate the original matrix and get back the same type and shape of matrix:

var recreated = af.matmul(u, af.diag(s), af.index(vt, af.seq(a)));
// [A, B, U, U]

Tiling / broadcasting #

ArrayFire, along with other frameworks, supports broadcasting. As an example, if you wanted to center a batch of tensors, subtracting the mean of each tensor from itself, you might have code like:

var vectors = af.randu(F32, af.shape(64, 100)); // [64, 100, 1, 1]
var means = af.mean(vectors, D0); // [1, 100, 1, 1]
var centeredVectors = af.sub(vectors, means); // [64, 100, 1, 1]

This will produce a compilation error. It is impossible to compute the output shape types if you allow this, as the resulting shape depends on whether the left or right argument has the bigger dimensions (i.e. which one isn’t being tiled).

The middle ground that supports automatic broadcasting but still plays nicely with the type system looks like this:

var centeredVectors = af.sub(vectors, means.tile()); // [64, 100, 1, 1]

There still be ~~dragons~~ runtime errors if:

  • The argument marked as tileable is larger in any dimension than the non-tiled argument
  • The tileable argument doesn’t cleanly tile into the other argument (a sketch of this case follows below)
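
For example, the second case might look something like this (a sketch reusing the code above - both operands are typed [N, N, U, U], so the compiler is happy, but a dimension of 3 does not tile cleanly into 64):

var vectors = af.randu(F32, af.shape(64, 100)); // [64, 100, 1, 1]
var chunk = af.randu(F32, af.shape(3, 100));    // [3, 100, 1, 1]
var result = af.sub(vectors, chunk.tile());     // compiles, fails at runtime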

Data type transformations #

Types are also useful for the data type (F32, F16, etc.), as the output data type of an operation is not always the same as that of its inputs.

When summing a tensor in ArrayFire, if the type of the array is S16 the result of the sum will be S32, and U16, U8, B8 all sum into U32. We can encode this knowledge in the type system:

var data = af.create(U8, new byte[]{1, 2, 3, 4})
		.reshape(2, 2); // U8
var sum = af.sum(data); // U32

The type of the sum result has changed to U32. Additionally, the first dimension has changed from N to U, since summing across it reduces it to size 1.

A similar approach is used for converting data into and out of Java land and returning the correct array types:

var tensorF32 = af.range(F32, 8);
var dataF32 = af.data(tensorF32).java(); // float[]

var tensorU8 = af.range(U8, 8);
var dataU8 = af.data(tensorU8).java(); // byte[]

Autograd #

The majority of API methods now support automatic reverse mode differentiation. You can see a complete example training MNIST with a 2-layer neural network here.

Here's an easy-to-follow example that optimizes a set of random parameters through a fairly trivial set of operations to hit a single target number.

tidy(() -> {
    var a = params(() -> randu(F32, shape(5)), SGD.create());
    var b = randu(F32, shape(5));
    var latestLoss = Float.POSITIVE_INFINITY;
    for (int i = 0; i < 50 || latestLoss >= 1E-10; i++) {
        latestLoss = tidy(() -> {
            var mul = mul(a, b);
            var loss = pow(sub(sum(mul), 5), 2);
            optimize(loss);
            return data(loss).get(0);
        });
    }
    System.out.println(latestLoss);
});

The API for autograd relies on params objects, which are effectively mutable wrappers around tensors/arrays. Whenever optimize is called, it works backwards through an implicit graph, which is kept up to date whenever API methods such as pow or mul are called, to work out which gradients are needed to get back to any parameters used in the current scope.

After the gradients are computed, they are passed to the optimizer specified for the params object, which can then apply gradient descent however it wishes. In this example we use simple SGD, but the API should support writing other optimizers such as Adam.
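
Putting those pieces together, here is a minimal sketch (using only the calls from the example above, with the same static imports) of a single parameter going through one forward pass and one optimize step:

tidy(() -> {
    // params wraps a mutable array and registers the SGD optimizer for it.
    var w = params(() -> randu(F32, shape(3)), SGD.create());
    // Each API call (mul, sum) records itself in the implicit graph.
    var loss = sum(mul(w, w));
    // optimize walks the graph backwards from loss and lets SGD update w.
    optimize(loss);
});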

General design challenges #

Memory management #

GPUs are often memory limited, and you have to be very careful about managing their memory. The foreign memory segments created for interacting with the ArrayFire C API must also be managed carefully. For both of these I decided it would be wise to abandon the garbage collector.

The solution here is memory scopes. There is just one function called tidy, and it can be used liberally - you can’t do anything without being in a memory scope. It is a wrapper around the Foreign Memory API concepts and ArrayFire-specific methods for managing device memory.

af.tidy(() -> {
    var data = af.zeros(F32, af.shape(2, 2));
});

Functions that generate intermediate outputs use tidy to release intermediate arrays, and if a memory-managed object is returned from the lambda passed in, it won’t be tidied up, but instead moved up to the outer memory scope. For example, the body of a function to generate a covariance matrix:


return tidy(() -> {
    var centered = sub(tensor, mean(tensor, D1).tile());
    var matrix = matmul(centered, centered.transpose());
    return div(matrix, constant(tensor.d1().size() - 1.0f));
});

Fighting Java’s verbosity #

I think that most of these examples above are terse (for Java at least), and I avoided putting my 2010 Java hat on as much as possible in the design of the API. There are few classes, very rarely should you need to write new, and nearly the entire API surface area is under a single class called… af.

This is wildly against the Java naming conventions, but it's worth it. Putting the following at the top of a file is enough to get the entire API imported:

import arrayfire.*; // Types
import static arrayfire.af.*; // Methods & Values

Kotlin #

Kotlin makes things start to feel a bit more magical, given its support for operator overloads and infix functions. Code like the following is possible and clean, yet it is still doing all the type checking above for you, and given Kotlin’s interoperability with Java it’s easy to support.

val a = create(1f, 2f, 3f)
val b = create(4f, 5f, 6f)
val c = (a + 1) / b
val d = a.t() matmul b // 3x3
val e = a matmul b // Error

Final thoughts #

This was really intended as an experiment and a learning experience, particularly in understanding how to work with GPUs and how to design sane tensor APIs. The above is all real working code (give or take), although since it depends on JDK 21 and Java’s Foreign Memory and Linker APIs, which are still in preview, it will likely continue to break with each six-monthly JDK release until they are finalized.

There are many important things missing (like even a semblance of documentation), but if you are interested you can check out the library here: https://github.com/lewish/arrayfire-java-fla.
