Runtime Code Generation for Convolutions

Summary

Our work MARLIN (Matrix Multiplication through Reduced Load Instructions) is now available on GitHub. MARLIN is a runtime code generation library for convolution kernels. The paper won first place in the graduate student category at CGO 2021 (International Symposium on Code Generation and Optimization).

Here are the links to the 10-minute presentation talk and the 3-minute pitch video.
1. Long presentation (10 minutes)
2. Short presentation (3 minutes)


Introduction

JIT (just-in-time) code generation in C++ execution environments has gained traction recently due to its performance benefits. There is a limit to the optimizations that can be performed at compile time, and one of those limitations is that the problem dimensions are not yet available. For example, if the compiler knows the dimensions of a matrix multiplication ahead of time, the code can be optimized further.
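
As a simple illustration (not taken from MARLIN), compare a matrix multiply whose dimensions are runtime variables with one whose dimensions are template parameters; in the latter case the compiler knows every loop trip count and can unroll and vectorize accordingly.

#include <cstddef>

// Dimensions known only at run time: the compiler must keep generic loop
// bounds and cannot specialize the kernel for a particular size.
void matmul_dynamic(const float* A, const float* B, float* C,
                    std::size_t m, std::size_t n, std::size_t k) {
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = acc;
    }
}

// Dimensions known at compile time: the trip counts are constants, so the
// compiler is free to fully unroll and vectorize the whole kernel.
template <std::size_t M, std::size_t N, std::size_t K>
void matmul_static(const float* A, const float* B, float* C) {
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      C[i * N + j] = acc;
    }
}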

In high-performance applications such as scientific simulations and machine learning, matrix multiplication is computed using a GEMM (General Matrix Multiplication) library that is highly optimized for the underlying target architecture. However, libraries encounter the problem dimensions only at runtime, so compile-time optimizations that rely on problem dimensions are not possible. In recent years, a number of libraries such as LIBXSMM, Intel MKL, FBGEMM and oneDNN have adopted JIT APIs to allow runtime code generation. All of these libraries target the x86 architecture and extensively utilize new vector extensions such as AVX-512.


Motivation

For matrices with large dimensions, traditional BLAS (Basic Linear Algebra Subprograms) libraries perform well by maximally utilizing the cache hierarchy. For compute-bound workloads such as large matrix multiplication, it also makes sense to reorder data to obtain better cache access patterns. In contrast, most JIT code generation techniques target predominantly memory-bound workloads such as small and medium matrix multiplication. With small and medium matrices, it is not worthwhile to spend compute cycles reordering the data.

Existing JIT libraries focus only on optimizing the instruction flow, employing techniques such as loop unrolling and address compression (using x86 variable-length instructions). In contrast, we devise a unique way of code generation by embedding data values as immediate values within instructions. This improves performance (geometric mean) by ~12% over traditional BLAS for small and medium matrix multiplication, and by ~5% over the Intel MKL JIT API for convolutions with 32 channels.


JIT Code Generation

MARLIN generates instructions by looking at the data values. The motivation lies in the fact that in machine learning applications, the weight matrix remains constant throughout inference, so it makes sense to improve access to the weight matrix. Code generation has two prominent concerns: the efficiency of the generated code and the ability to amortize the JIT cost. Modern vector extensions such as AVX-512 make this possible with compact instruction sequences. For example, using AVX-512, a memory load can be transformed into an immediate access with just two machine instructions.

mov          eax, 0x41000000       // 0xb8, 0x00, 0x00, 0x00, 0x41
vpbroadcastd zmm0, eax             // 0x62, 0xf2, 0x7d, 0x48, 0x7c, 0xc0
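
Here, 0x41000000 is simply the IEEE-754 bit pattern of the float value 8.0f. As a rough sketch (using a hypothetical helper, not MARLIN's actual emitter), a weight value could be turned into that two-instruction byte sequence as follows:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: append the two-instruction sequence above, for a
// given weight value, to a buffer of machine code bytes.
void emit_broadcast_immediate(std::vector<uint8_t>& code, float weight) {
  uint32_t bits;
  std::memcpy(&bits, &weight, sizeof(bits));  // e.g. 8.0f -> 0x41000000

  // mov eax, imm32: opcode 0xB8 followed by the 4-byte little-endian immediate
  code.push_back(0xB8);
  for (int i = 0; i < 4; ++i) code.push_back((bits >> (8 * i)) & 0xFF);

  // vpbroadcastd zmm0, eax: fixed EVEX encoding
  const uint8_t vpbroadcastd[] = {0x62, 0xF2, 0x7D, 0x48, 0x7C, 0xC0};
  code.insert(code.end(), vpbroadcastd, vpbroadcastd + sizeof(vpbroadcastd));
}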

Since machine code is generated directly rather than relying on an LLVM runtime, we observe a much lower code generation overhead. All existing high-performance JIT GEMM libraries have made the same choice for performance reasons.
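
For readers curious about the mechanics, here is a minimal sketch of the general technique on Linux/x86-64 (not MARLIN's internals): directly generated machine code is placed in an executable page and called through a function pointer.

#include <cstdint>
#include <cstring>
#include <sys/mman.h>
#include <vector>

// Copy raw machine code into an anonymous page and make it executable.
void* make_executable(const std::vector<uint8_t>& code) {
  void* page = mmap(nullptr, code.size(), PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (page == MAP_FAILED) return nullptr;
  std::memcpy(page, code.data(), code.size());
  // Flip the page to read/execute once the bytes are in place.
  if (mprotect(page, code.size(), PROT_READ | PROT_EXEC) != 0) {
    munmap(page, code.size());
    return nullptr;
  }
  return page;
}

int main() {
  // 0xC3 is the x86-64 `ret` opcode: the smallest possible JIT-ed function.
  std::vector<uint8_t> code = {0xC3};
  auto fn = reinterpret_cast<void (*)()>(make_executable(code));
  if (fn) fn();
}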

With GEMM, one of the main concerns is choosing the best tiling strategy. We considered a number of tiling strategies and observed the best results when a block of matrix C (in standard notation) is fully computed before being written to memory. The reason is that repeated writes to the same location can be expensive for small to medium matrix multiplication. Our tiling strategy is explained in detail in our paper and in the long presentation video.
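
The idea can be sketched in plain C++ (an illustrative kernel, not MARLIN's generated code): a small tile of C is accumulated locally across the full K dimension and written to memory exactly once.

#include <cstddef>

// Compute one TM x TN tile of C completely in a local accumulator, then
// write it out in a single pass.
template <std::size_t TM, std::size_t TN>
void compute_c_tile(const float* A, const float* B, float* C,
                    std::size_t k, std::size_t n,
                    std::size_t row0, std::size_t col0) {
  float acc[TM][TN] = {};  // stays in registers for small tile sizes

  // Walk the full K dimension before touching C.
  for (std::size_t p = 0; p < k; ++p)
    for (std::size_t i = 0; i < TM; ++i)
      for (std::size_t j = 0; j < TN; ++j)
        acc[i][j] += A[(row0 + i) * k + p] * B[p * n + (col0 + j)];

  // Single write of the finished tile; no repeated read-modify-write of C.
  for (std::size_t i = 0; i < TM; ++i)
    for (std::size_t j = 0; j < TN; ++j)
      C[(row0 + i) * n + (col0 + j)] = acc[i][j];
}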


Automated Code Generation

A core principle of the library design was to have fully automated code generation with lazy initialization. The user has control over triggering JIT code generation but does not have to worry about the implementation details.

Here’s an example code snippet to demonstrate how the library works.

#include <cstdint>
#include <cstdlib>
#include <marlin>
#include <memory>

using namespace MARLIN;
typedef uint64_t index_t;

int main(/*int argc, char* argv[]*/) {
  index_t m = 3;
  index_t n = 5;
  index_t k = 2;

  float *A = static_cast<float *>(std::malloc(m * k * sizeof(float)));
  float *B = static_cast<float *>(std::malloc(n * k * sizeof(float)));
  float *C = static_cast<float *>(std::malloc(m * n * sizeof(float)));
  
  // initialize input
  for (index_t i = 0; i < m; ++i) {
    for (index_t j = 0; j < k; ++j) {
      index_t idx = i * k + j;
      A[i * k + j] = idx + 1;
    }
  }
  for (index_t i = 0; i < k; ++i) {
    for (index_t j = 0; j < n; ++j) {
      index_t idx = i * n + j;
      B[i * n + j] = idx + 1;
    }
  }
  
  // initiate the jitter and generate code
  std::shared_ptr<Jitter<float>> jitter = std::make_shared<Jitter<float>>();
  jitter->generate_code(B, m, k, n);

  // perform sgemm
  sgemm('N', 'N', m, n, k, 1.0, A, k, B, n, 0, C, n, jitter);

  // free
  std::free(A);
  std::free(B);
  std::free(C);
}

The most important lines are shown in the code block below.

std::shared_ptr<Jitter<float>> jitter = std::make_shared<Jitter<float>>();
jitter->generate_code(B, m, k, n);

// perform sgemm
sgemm('N', 'N', m, n, k, 1.0, A, k, B, n, 0, C, n, jitter);

Conclusion

Feel free to try out MARLIN and stay tuned for another blog post explaining the internal implementation of C++ JIT code generation. We will dive into page-level allocations and how we can write GEMM kernels that execute JIT-generated instructions in a page buffer.
