GPU optimized matrix math with Rust and OpenCL

This post is terse and is meant for folks who are familiar with both Rust and OpenCL, though maybe not the two together.

I was genuinely surprised by how easy this is with Rust's OpenCL bindings.

I was writing a linear regression model that needs to train on small datasets in near real time, and I went with the closed-form linear algebra approach (solving the normal equation) over gradient descent, since my training sets are small enough for it to stay performant.
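For context, the linear algebra approach to linear regression solves the normal equation theta = (X^T X)^-1 X^T y in one shot. Below is a minimal CPU-only sketch in plain Rust, assuming a row-major design matrix X whose first column is all ones (the intercept term). It solves (X^T X) theta = X^T y by Gaussian elimination with partial pivoting instead of forming the inverse explicitly; that substitution and every function name here are my own illustration, not the post's code.

```rust
/// X^T v, where x is rows x cols (row-major) and v has length rows.
fn xt_times(x: &[f64], rows: usize, cols: usize, v: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0; cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c] += x[r * cols + c] * v[r];
        }
    }
    out
}

/// X^T X as a cols x cols row-major matrix.
fn xt_x(x: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut out = vec![0.0; cols * cols];
    for r in 0..rows {
        for i in 0..cols {
            for j in 0..cols {
                out[i * cols + j] += x[r * cols + i] * x[r * cols + j];
            }
        }
    }
    out
}

/// Solve A t = b (A is k x k, row-major). Returns None if A is (near) singular.
fn solve(mut a: Vec<f64>, mut b: Vec<f64>, k: usize) -> Option<Vec<f64>> {
    for col in 0..k {
        // Partial pivoting: use the row with the largest entry in this column.
        let pivot = (col..k).max_by(|&p, &q| a[p * k + col].abs().total_cmp(&a[q * k + col].abs()))?;
        if a[pivot * k + col].abs() < 1e-12 {
            return None;
        }
        for j in 0..k {
            a.swap(col * k + j, pivot * k + j);
        }
        b.swap(col, pivot);
        // Eliminate this column from every row below the pivot.
        for row in (col + 1)..k {
            let f = a[row * k + col] / a[col * k + col];
            for j in col..k {
                a[row * k + j] -= f * a[col * k + j];
            }
            b[row] -= f * b[col];
        }
    }
    // Back substitution on the upper-triangular system.
    let mut t = vec![0.0; k];
    for row in (0..k).rev() {
        let mut s = b[row];
        for j in (row + 1)..k {
            s -= a[row * k + j] * t[j];
        }
        t[row] = s / a[row * k + row];
    }
    Some(t)
}

/// theta = (X^T X)^-1 X^T y, computed by solving the normal equation.
fn normal_equation(x: &[f64], rows: usize, cols: usize, y: &[f64]) -> Option<Vec<f64>> {
    solve(xt_x(x, rows, cols), xt_times(x, rows, cols, y), cols)
}

fn main() {
    // Four points on y = 1 + 2x; the leading 1.0 in each row is the intercept column.
    let x = [1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0];
    let y = [1.0, 3.0, 5.0, 7.0];
    let theta = normal_equation(&x, 4, 2, &y).expect("X^T X is singular");
    println!("{:?}", theta);
}
```

The cost is dominated by forming X^T X (rows * cols^2) and the cols^3 solve, which is why this stays cheap when the number of features is small.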

The matrix math still wasn't quite fast enough on the CPU, but offloading it to the GPU was far less painful than I expected.

Kernel

I'm not including my full kernel code because it's long and I'm bad enough at numerical computing to be embarrassed by it, but the signature looks like this:

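// double precision is an optional OpenCL feature (cl_khr_fp64)
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
__kernel void vector_transpose(
  __global double *input,
  __global double *output)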
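Since the kernel body isn't shown, here is a hedged CPU reference for what a one-work-item-per-element transpose computes, written in plain Rust. It assumes row-major storage and a 1-D NDRange of size m * n, with the loop variable standing in for `get_global_id(0)`. The signature above has no dimension arguments, so this sketch simply passes m and n explicitly; the function names are hypothetical. A reference like this is handy for checking GPU output on small inputs.

```rust
/// One work-item: element `gid` of the m x n row-major input moves to
/// row `col`, column `row` of the n x m row-major output.
fn transpose_work_item(gid: usize, input: &[f64], output: &mut [f64], m: usize, n: usize) {
    let row = gid / n;
    let col = gid % n;
    output[col * m + row] = input[row * n + col];
}

/// Run every work-item sequentially, like a 1-D NDRange with global size m * n.
fn transpose_cpu(input: &[f64], m: usize, n: usize) -> Vec<f64> {
    assert_eq!(input.len(), m * n, "input length must be m * n");
    let mut output = vec![0.0; m * n];
    for gid in 0..m * n {
        transpose_work_item(gid, input, &mut output, m, n);
    }
    output
}

fn main() {
    // 2 x 3 matrix [[1, 2, 3], [4, 5, 6]] in row-major order.
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Its 3 x 2 transpose [[1, 4], [2, 5], [3, 6]].
    let t = transpose_cpu(&a, 2, 3);
    assert_eq!(t, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    println!("{:?}", t);
}
```

Each work-item writes exactly one distinct output slot, so the work-items need no synchronization with each other.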

Allocating device buffers

Let's put together a domain-level matrix type that allocates its GPU buffer in a given OpenCL context:

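use opencl::mem::CLBuffer;

pub struct GpuMatrix {
  m: usize,           // rows
  n: usize,           // columns
  buf: CLBuffer<f64>, // m * n doubles on the device
}

impl GpuMatrix {
  pub fn new(ctx: &Context, m: usize, n: usize) -> GpuMatrix {
    // Size is an element count. READ_WRITE because this type is used both
    // for kernel inputs and for the output buffer the kernel writes to.
    let buf = ctx.create_buffer(m * n, opencl::cl::CL_MEM_READ_WRITE);
    GpuMatrix { m, n, buf }
  }
}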
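The device buffer is just m * n doubles with no shape information attached, so the host side has to agree with the kernel on a layout. A small sketch of a host-side counterpart, assuming row-major order (element (i, j) at index i * n + j); `HostMatrix`, `from_rows`, `get` and `byte_len` are hypothetical names, not part of the post or the OpenCL bindings.

```rust
/// Host-side mirror of GpuMatrix: same shape, data flattened row-major.
#[derive(Debug, PartialEq)]
struct HostMatrix {
    m: usize,
    n: usize,
    data: Vec<f64>,
}

impl HostMatrix {
    /// Flatten nested rows into one row-major Vec; panics on ragged rows.
    fn from_rows(rows: &[Vec<f64>]) -> HostMatrix {
        let m = rows.len();
        let n = rows.first().map_or(0, |r| r.len());
        let mut data = Vec::with_capacity(m * n);
        for r in rows {
            assert_eq!(r.len(), n, "all rows must have the same length");
            data.extend_from_slice(r);
        }
        HostMatrix { m, n, data }
    }

    /// Element (i, j) under row-major indexing.
    fn get(&self, i: usize, j: usize) -> f64 {
        self.data[i * self.n + j]
    }

    /// Bytes the matching device buffer occupies (m * n doubles).
    fn byte_len(&self) -> usize {
        self.m * self.n * std::mem::size_of::<f64>()
    }
}

fn main() {
    let h = HostMatrix::from_rows(&[vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    println!("{:?}", h);
    println!("(1, 2) = {}, buffer bytes = {}", h.get(1, 2), h.byte_len());
}
```

Keeping `data` in exactly the order the kernel indexes means it can be copied to or from the device buffer without any reshuffling.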

We'll also need a method that sets the kernel's arguments and enqueues it on the device:

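impl GpuMatrix {
  // Enqueue the transpose without waiting for it. `output` must be an
  // n x m matrix, e.g. GpuMatrix::new(ctx, self.n, self.m).
  pub fn transpose(&self, ctx: &Context, output: &GpuMatrix) -> opencl::hl::Event {
    // Must match the __kernel name in the OpenCL source.
    let kernel = ctx.program.create_kernel("vector_transpose");

    kernel.set_arg(0, &self.buf);   // __global double *input
    kernel.set_arg(1, &output.buf); // __global double *output

    // One work-item per element, default local size, no events to wait on.
    ctx.queue.enqueue_async_kernel(&kernel, self.m * self.n, None, ())
  }
}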

This returns an OpenCL event rather than the result, so reading the output back takes a blocking call that waits on that event:

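impl GpuMatrix {
  // Block until `event` completes, then copy `target` (the n x m matrix the
  // kernel wrote into) back to the host as a flat Vec of n * m values.
  pub fn unpack_transpose(
    &self,
    ctx: &Context,
    event: &opencl::hl::Event,
    target: &GpuMatrix
  ) -> Vec<f64> {
    debug_assert_eq!((target.m, target.n), (self.n, self.m));
    ctx.queue.get(&target.buf, event)
  }
}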
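Whatever wrapper the result ends up in, the data read back from the output buffer is a flat run of n * m doubles. A hedged sketch of two host-side helpers for working with it, assuming the kernel writes row-major output: `to_rows` reshapes the flat values into nested rows, and `check_transpose` compares them element by element against the original input. Both names are hypothetical; the `from_device` values in `main` are a hard-coded stand-in, not real GPU output.

```rust
/// Reshape a flat row-major buffer into `rows` vectors of `cols` elements.
fn to_rows(flat: &[f64], rows: usize, cols: usize) -> Vec<Vec<f64>> {
    assert_eq!(flat.len(), rows * cols, "length must be rows * cols");
    if cols == 0 {
        return vec![Vec::new(); rows];
    }
    flat.chunks(cols).map(|c| c.to_vec()).collect()
}

/// True if `result` (n x m, row-major) is the transpose of `input` (m x n, row-major).
/// Exact equality is fine here: a transpose only moves values, it never rounds them.
fn check_transpose(input: &[f64], result: &[f64], m: usize, n: usize) -> bool {
    input.len() == m * n
        && result.len() == m * n
        && (0..m).all(|i| (0..n).all(|j| result[j * m + i] == input[i * n + j]))
}

fn main() {
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 x 3
    let from_device = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; // stand-in for the values read back
    assert!(check_transpose(&input, &from_device, 2, 3));
    println!("{:?}", to_rows(&from_device, 3, 2));
}
```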