Reverse Engineering DNNs with JIT GEMM Libraries

What is JAXED?

JAXED [JIT + AXED] is a security attack that reverse engineers the architectures of DNNs (Deep Neural Networks) running on JIT-optimized GEMM libraries. A DNN’s model architecture is specified by its hyperparameters, e.g., the number of layers or the number of channels in a filter. In our latest work, we exploit a novel side channel exposed during JIT-optimized GEMM execution. A YouTube video is now available at: https://youtu.be/-dsBADRPFhk. The paper is available here.


Why Are DNNs Important?

Recently, DNNs have been incorporated into a wide variety of applications. The iPhone 13 release, for instance, is a great example of how companies are developing customized DNN models.

These DNN models deliver superior performance and enable interesting applications, e.g., increasing image resolution from 720p to 4K. The growth of new social media applications has also promoted the use of DNNs. Companies invest significant resources to identify the best model hyperparameters. Because these hyperparameters give them a competitive advantage over rivals in the industry, it is important that companies keep them secure.

Convolution Operations

Next, let’s look at a fundamental operation behind most DNNs. DNNs consist of convolution layers and fully connected layers, both of which are computed using GEMM (General Matrix Multiply) libraries. Because convolutions consume the majority of the execution time, there has been a trend toward optimizing these operations with JIT-optimized GEMM libraries.
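To see why GEMM dimensions can reveal hyperparameters, here is a minimal sketch of how one common lowering, im2col, turns a convolution layer into a single GEMM. The layout convention below (weights as an M x K matrix, batch size 1) is an assumption; libraries and frameworks differ in layout, but in each case M, N, and K are fixed by the layer's hyperparameters. ConvLayer, conv_to_gemm_dims, and fc_to_gemm_dims are hypothetical helpers, not part of any GEMM library.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ConvLayer:
    """Hyperparameters of a 2-D convolution with a square kernel."""
    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int = 1
    padding: int = 0


def conv_output_size(size: int, layer: ConvLayer) -> int:
    """Spatial output size along one axis."""
    return (size + 2 * layer.padding - layer.kernel_size) // layer.stride + 1


def conv_to_gemm_dims(layer: ConvLayer, height: int, width: int) -> tuple[int, int, int]:
    """(M, N, K) of the im2col GEMM for one input image (batch size 1).

    Assumed convention: the weights form an M x K matrix (one row per output
    channel) and the im2col buffer a K x N matrix (one column per output pixel).
    """
    m = layer.out_channels
    n = conv_output_size(height, layer) * conv_output_size(width, layer)
    k = layer.in_channels * layer.kernel_size * layer.kernel_size
    return m, n, k


def fc_to_gemm_dims(batch: int, in_features: int, out_features: int) -> tuple[int, int, int]:
    """(M, N, K) for Y = X @ W^T with X: batch x in_features."""
    return batch, out_features, in_features


print(conv_to_gemm_dims(ConvLayer(3, 64, 3, padding=1), 224, 224))  # (64, 50176, 27)
print(fc_to_gemm_dims(1, 4096, 1000))                                 # (1, 1000, 4096)
```

Changing any hyperparameter (channels, kernel size, stride, padding) changes at least one of M, N, or K, which is why a leak of the GEMM problem size is a leak of the architecture.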

JIT Optimizations Expose a Side Channel!

JIT-optimized GEMM libraries generate an instruction sequence at runtime, applying aggressive optimizations tailored to the problem size (e.g., loop unrolling and address compression at the machine-code level). This code generation is performed only during the first inference. In subsequent inferences, the instruction sequence is retrieved from a software code cache for faster execution. The cost of JIT code generation is thus amortized over many iterations, providing faster inference on CPUs.

Only the initial inference triggers JIT code generation. If the same problem dimensions are used again, the instruction sequence is retrieved from the code cache for faster inference.
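The code cache behavior can be modeled with a toy class. This is a sketch under stated assumptions, not how LIBXSMM or any real library is implemented: "code generation" is represented by building a Python closure specialized to (M, N, K), and the cache is a dictionary keyed by those dimensions. JitCodeCache is a hypothetical name.

```python
from __future__ import annotations

from typing import Callable

Matrix = list[list[float]]
Kernel = Callable[[Matrix, Matrix], Matrix]


class JitCodeCache:
    """Toy software code cache for JIT-generated GEMM kernels (hypothetical)."""

    def __init__(self) -> None:
        # One generated kernel per problem size (M, N, K).
        self._kernels: dict[tuple[int, int, int], Kernel] = {}
        self.hits = 0
        self.misses = 0

    @staticmethod
    def _generate(m: int, n: int, k: int) -> Kernel:
        # Stand-in for runtime code generation: a real JIT library would emit
        # machine code specialized for these exact dimensions.
        def kernel(a: Matrix, b: Matrix) -> Matrix:
            return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
                    for i in range(m)]
        return kernel

    def gemm(self, a: Matrix, b: Matrix) -> Matrix:
        """Compute C = A @ B, generating a kernel only on the first call per size."""
        m, k, n = len(a), len(b), len(b[0])
        key = (m, n, k)
        kernel = self._kernels.get(key)
        if kernel is None:
            self.misses += 1          # first inference: JIT code generation
            kernel = self._generate(m, n, k)
            self._kernels[key] = kernel
        else:
            self.hits += 1            # later inferences: served from the code cache
        return kernel(a, b)


cache = JitCodeCache()
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
cache.gemm(a, b)                      # miss: kernel generated for (2, 2, 2)
print(cache.gemm(a, b), cache.hits, cache.misses)  # [[19.0, 22.0], [43.0, 50.0]] 1 1
```

The key detail for the attack is that the cache is indexed only by problem size, so whoever calls the library first with a given (M, N, K) pays the generation cost for everyone sharing that cache.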

Reverse Engineering DNNs

How will this work as an exploit?

[1] For now, assume that the ML framework is shared. The victim first executes the model, which checks whether an instruction sequence is available for the matrix dimensions (M1, N1, K1) in standard GEMM notation. Since it is not, JIT compilation is triggered.

[2] Afterwards, if an adversary uses the same parameters as the victim (M1, N1, K1), the adversary will observe faster execution because the instruction sequence is retrieved from the code cache.
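The two steps can be simulated end to end. In this sketch, SharedJitGemm is a hypothetical stand-in for a shared JIT GEMM library, with code generation modeled as a fixed time.sleep delay. The attacker calibrates a threshold from an isolated run on dimensions assumed unused by the victim, then flags a probe as a cache hit when it completes well under that time.

```python
from __future__ import annotations

import time


class SharedJitGemm:
    """Toy JIT GEMM library shared by victim and attacker (hypothetical)."""

    def __init__(self, jit_cost_s: float = 0.02) -> None:
        self.jit_cost_s = jit_cost_s
        self._code_cache: set[tuple[int, int, int]] = set()

    def run(self, m: int, n: int, k: int) -> None:
        if (m, n, k) not in self._code_cache:
            time.sleep(self.jit_cost_s)        # simulated JIT code generation
            self._code_cache.add((m, n, k))
        # The matrix multiply itself is omitted; only cache behavior matters here.


def time_gemm(lib: SharedJitGemm, m: int, n: int, k: int) -> float:
    start = time.perf_counter()
    lib.run(m, n, k)
    return time.perf_counter() - start


def victim_used(lib: SharedJitGemm, dims: tuple[int, int, int], threshold_s: float) -> bool:
    """A fast first call means the kernel was already in the shared code cache."""
    return time_gemm(lib, *dims) < threshold_s


lib = SharedJitGemm()
victim_dims = (64, 50176, 27)            # (M1, N1, K1), unknown to the attacker
lib.run(*victim_dims)                    # [1] victim's first inference triggers JIT

isolated = time_gemm(lib, 1, 1, 1)       # calibration on dims assumed unused by the victim
threshold = isolated / 2

print(victim_used(lib, victim_dims, threshold))      # [2] same dims: fast, prints True
print(victim_used(lib, (128, 784, 576), threshold))  # different dims: slow, prints False
```

In this model each probe also populates the cache, so a given set of dimensions can only be tested once per cache lifetime. The half-of-isolated-time threshold is an arbitrary choice for illustration; real timings are noisier and would call for repeated trials.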


Is this Timing Difference Noticeable?

In the diagram below, we show the timing difference between LIBXSMM (JIT) and OpenBLAS (without JIT). An isolated run means that only the adversary is present in the compute environment. Shared means that the victim and the adversary are both present in the compute environment and use the same parameters, so both access the same code cache entry.

Timing difference when executing GEMM (with JIT and without JIT). Left: LIBXSMM & Right: OpenBLAS

Attack Scenarios

Now that we have a basic understanding of the exploit, let us look at two attack scenarios.

  • The attacker only has read-only access to a web API, which allows the attacker to detect whether the side channel is present.
    Attacker has read-only API access
  • The adversary has write access to her own model, and can therefore perform an exploratory attack to extract model hyperparameters and reverse engineer the DNN.
    Attacker has both read and write API access
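With write access, the exploratory attack amounts to a search over candidate hyperparameters. The sketch below assumes, for illustration only, that the attacker already knows the layer's input size, input channels, kernel size, and padding, and searches only over the number of output channels. Each candidate is mapped to (M, N, K) with an im2col-style layout (an assumption), and latency is reported in abstract cost units (JIT_COST and GEMM_COST are hypothetical values) to keep the example deterministic. ToyJitGemm and sweep_out_channels are hypothetical names.

```python
from __future__ import annotations

# Hypothetical cost units; real latencies depend on the library and hardware.
JIT_COST = 1000   # generating a kernel for a new (M, N, K)
GEMM_COST = 10    # running an already-cached kernel


class ToyJitGemm:
    """Shared JIT GEMM library that reports a simulated latency per call."""

    def __init__(self) -> None:
        self.code_cache: set[tuple[int, int, int]] = set()

    def run(self, dims: tuple[int, int, int]) -> int:
        if dims in self.code_cache:
            return GEMM_COST
        self.code_cache.add(dims)
        return JIT_COST + GEMM_COST


def conv_gemm_dims(in_ch: int, out_ch: int, kernel: int, height: int, width: int,
                   stride: int = 1, padding: int = 0) -> tuple[int, int, int]:
    """(M, N, K) of an im2col-lowered convolution, batch size 1 (assumed layout)."""
    h_out = (height + 2 * padding - kernel) // stride + 1
    w_out = (width + 2 * padding - kernel) // stride + 1
    return out_ch, h_out * w_out, in_ch * kernel * kernel


def sweep_out_channels(lib: ToyJitGemm, candidates: list[int], in_ch: int, kernel: int,
                       height: int, width: int, padding: int) -> list[int]:
    """Return the candidate output-channel counts whose GEMM hit the code cache."""
    matches = []
    for out_ch in candidates:
        dims = conv_gemm_dims(in_ch, out_ch, kernel, height, width, padding=padding)
        if lib.run(dims) < JIT_COST:   # fast => someone already compiled these dims
            matches.append(out_ch)
    return matches


lib = ToyJitGemm()
lib.run(conv_gemm_dims(3, 64, 3, 224, 224, padding=1))   # victim layer: 3 -> 64 channels

print(sweep_out_channels(lib, [16, 32, 64, 128], 3, 3, 224, 224, padding=1))  # [64]
```

Each probe inserts its own dimensions into the cache, so the sweep is destructive: rerunning it in this model would report every candidate as a hit. A fuller search would cover several hyperparameters per layer, one layer at a time.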

Detailed JAXED Attack: Reverse Engineering DNNs


The victim first loads all the libraries in step 1 and the appropriate weights in step 2. As the model executes, the GEMM driver is called in step 3.

The call is directed to the JIT GEMM library API in step 4. The code cache is first probed for a matching instruction sequence; if none is available, the JIT code is generated in step 6.

Next, consider the adversary’s execution. The attacker loads a simple model and changes its parameters on each run, using dummy weights only for completeness. The request is directed through the GEMM driver in step b to the library API in step d. This time, however, the execution is serviced from the code cache in step e.


Please refer to our YouTube video / presentation for detailed results. Thanks!
