Deep learning owes much of its success to automatic differentiation. Popular libraries such as TensorFlow and PyTorch keep track of gradients over neural network parameters during training, and both provide high-level APIs for the neural network functionality commonly used in deep learning. JAX is NumPy on the CPU, GPU, and TPU, with automatic differentiation for high-performance machine learning research. More than a deep learning framework, JAX is a polished linear algebra library with automatic differentiation and XLA support.
JAX is a new machine learning library from Google designed for high-performance numerical computing. Its updated version of the Autograd library can differentiate through native Python and NumPy code.
JAX is defined as “Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more”. The library uses the grad function transformation to convert a function into one that returns the original function’s gradient. JAX also offers the jit function transformation for just-in-time compilation of existing functions, and vmap and pmap for vectorization and parallelization, respectively.
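The transformations named above can be sketched in a few lines. This is a minimal illustration rather than code from any of the libraries' documentation: the function square_sum and the input values are made up for the example, and pmap is left out because it needs more than one device to be meaningful.

```python
import jax
import jax.numpy as jnp

def square_sum(x):
    # Hypothetical scalar-valued function: sum of squares of a vector.
    return jnp.sum(x ** 2)

# grad returns a new function that computes d(square_sum)/dx, which is 2x.
grad_square_sum = jax.grad(square_sum)

# jit compiles that gradient function with XLA; the results do not change.
fast_grad = jax.jit(grad_square_sum)

# vmap applies the gradient function to every row of a batch.
batched_grad = jax.vmap(grad_square_sum)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])

g = grad_square_sum(x)          # gradient for a single vector
g_fast = fast_grad(x)           # same gradient, compiled
g_batch = batched_grad(batch)   # one gradient per row, shape (2, 3)
```

Because each transformation takes a function and returns a function, they compose freely: jit(vmap(grad(f))) is as valid as any of them alone.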
The move from PyTorch or TensorFlow 2 to JAX is nothing short of tectonic. PyTorch builds up a graph during the forward pass and computes gradients during the backward pass. JAX, on the other hand, lets the user express a computation as a Python function. Transforming that function with grad() yields a gradient function that is called like the original, but instead of the output it returns the gradient of the output with respect to the first parameter the function took as input.
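As a small illustration of that behaviour, the sketch below differentiates a made-up two-argument function. The function loss and its inputs are hypothetical; the argnums parameter is the standard way to ask grad() for a different argument than the first.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss: squared dot product of weights and input.
    return jnp.dot(w, x) ** 2

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# By default grad differentiates with respect to the first argument (w).
dloss_dw = jax.grad(loss)(w, x)

# argnums selects another argument, here the second one (x).
dloss_dx = jax.grad(loss, argnums=1)(w, x)
```

The gradient function takes the same arguments as loss, so it can be dropped into a training loop wherever the loss itself would be evaluated.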
JAX vs TensorFlow vs PyTorch
While TensorFlow and PyTorch have compiled execution modes, these modes were added later on and have left their scars. For instance, TensorFlow’s eager mode is not 100% compatible with its graph mode, which makes for a bad developer experience. PyTorch, for its part, has a history of being forced into less intuitive tensor formats because it was built around eager execution. JAX arrives with both modes, focusing on eager execution for debugging and on JIT for heavy computations. The clean separation between the two allows for mixing and matching whenever needed.
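A rough sketch of what mixing the two modes can look like in JAX follows. The function normalize and its input are invented for the example; the point is only that the same Python function runs eagerly as written and compiled once wrapped in jit.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # Shift to zero mean and scale to unit standard deviation.
    return (x - jnp.mean(x)) / jnp.std(x)

x = jnp.arange(1.0, 6.0)

# Eager: each operation runs as it is called, so intermediate values
# can be printed or inspected in a debugger.
eager_out = normalize(x)

# Compiled: the same function, traced once and executed through XLA.
jit_out = jax.jit(normalize)(x)
```

One plausible workflow is to develop and debug with the plain function and wrap it in jit only once it behaves as expected.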
PyTorch and TensorFlow are deep learning libraries consisting of high-level APIs for modern methods in deep learning. In comparison, JAX is a more functionally minded library for arbitrary differentiable programming.
DZone conducted a mini-experiment to study how JAX stacks up against other libraries. A single-hidden-layer MLP, along with a training loop to “fit” a classification problem of random noise, was implemented in JAX, Autograd, TensorFlow 2.0 and PyTorch. This baseline was used to compare the performance of each library.
While JAX uses just-in-time compilation for library calls, the jit function transformation can also be used as a decorator for custom Python functions. For instance, here is a snippet:
from jax import grad, jit

# use jit as a decorator on a function definition
@jit
def get_loss(x, w, y_tgts):
    # forward and ce_loss are not shown in this snippet
    y_pred = forward(x, w)
    return ce_loss(y_tgts, y_pred)

# use jit as a function for transforming an already defined function into a just-in-time compiled function
# argnums=1 differentiates with respect to the weights w
get_grad = grad(get_loss, argnums=1)
jit_grad = jit(get_grad)
The experimenters implemented a simple multi-layer perceptron in each of the libraries: a sequence of weighted connections that amounts to matrix multiplication of input tensors and weight matrices. JAX dominated the experiment. It had a faster CPU execution time than any other library and the shortest execution time for implementations using only matrix multiplication.
The experiment also found that while JAX led the other libraries with matmul, PyTorch led with Linear layers. PyTorch had a quick execution time on the GPU: PyTorch with Linear layers took 9.9 seconds at a batch size of 16,384, which corresponds to JAX running with JIT at a batch size of 1024. When the implementations took advantage of higher-level neural network APIs, PyTorch was the fastest, followed by JAX and TensorFlow, so for fully connected layers PyTorch also outpaced TensorFlow. JAX, meanwhile, offered speed-ups of an order of magnitude or more over the comparable Autograd library.
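To make the matmul-only setup concrete, here is a minimal sketch of a single-hidden-layer MLP trained on random noise. It is not the experimenters' code. It reuses the names forward, ce_loss, get_loss and jit_grad from the snippet above, but their bodies, the ReLU activation, the layer sizes, the learning rate and the step count are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit

def forward(x, w):
    # Assumed architecture: one hidden layer built only from matmuls.
    w_hidden, w_out = w
    hidden = jax.nn.relu(jnp.matmul(x, w_hidden))
    logits = jnp.matmul(hidden, w_out)
    return jax.nn.softmax(logits, axis=-1)

def ce_loss(y_tgts, y_pred):
    # Mean cross-entropy between one-hot targets and predicted probabilities.
    return -jnp.mean(jnp.sum(y_tgts * jnp.log(y_pred + 1e-8), axis=-1))

def get_loss(x, w, y_tgts):
    y_pred = forward(x, w)
    return ce_loss(y_tgts, y_pred)

# Gradient with respect to the weights (argument 1), then compiled.
jit_grad = jit(grad(get_loss, argnums=1))

# Random-noise inputs and labels; all sizes are illustrative.
key_x, key_y, key_h, key_o = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(key_x, (64, 16))
labels = jax.random.randint(key_y, (64,), 0, 4)
y_tgts = jax.nn.one_hot(labels, 4)
w = (0.1 * jax.random.normal(key_h, (16, 32)),
     0.1 * jax.random.normal(key_o, (32, 4)))

lr = 0.1
loss_before = get_loss(x, w, y_tgts)
for _ in range(50):
    grads = jit_grad(x, w, y_tgts)
    # Plain gradient descent on each weight matrix.
    w = tuple(p - lr * g for p, g in zip(w, grads))
loss_after = get_loss(x, w, y_tgts)
```

The first call to jit_grad pays the compilation cost, and later calls reuse the compiled function, so a benchmark of this kind would typically time the steady-state iterations separately from the first one.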