JAX Vs TensorFlow Vs PyTorch: A Comparative Analysis

JAX is a Python library designed for high-performance numerical computing.

Deep learning owes much of its success to automatic differentiation. Popular libraries such as TensorFlow and PyTorch track gradients over neural network parameters during training, and both provide high-level APIs implementing the neural network building blocks commonly used in deep learning. JAX, by contrast, is essentially NumPy on the CPU, GPU and TPU, with powerful automatic differentiation for high-performance machine learning research. More than a deep learning framework, JAX is a highly polished linear algebra library with automatic differentiation and XLA support.

About JAX

JAX is a machine learning library from Google designed for high-performance numerical computing. Like the Autograd library it builds on, JAX can differentiate through native Python and NumPy code.

JAX describes itself as “Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more”. The library’s grad transformation converts a function into a new function that returns the original function’s gradient. JAX also offers the jit transformation for just-in-time compilation of existing functions, and vmap and pmap for vectorisation and parallelisation, respectively.
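To make these transformations concrete, here is a minimal sketch; the function f and the values below are illustrative examples, not code from the article:

```python
import jax
import jax.numpy as jnp

# grad turns f(x) = x**2 into a function returning df/dx = 2x
f = lambda x: x ** 2
df = jax.grad(f)
print(df(3.0))                        # 6.0

# vmap vectorises the gradient over a batch of inputs
batch_df = jax.vmap(df)
print(batch_df(jnp.arange(4.0)))      # [0. 2. 4. 6.]

# jit compiles the vectorised gradient with XLA
fast_batch_df = jax.jit(batch_df)
print(fast_batch_df(jnp.arange(4.0))) # [0. 2. 4. 6.]
```

Because the transformations compose, jit(vmap(grad(f))) is itself an ordinary JAX function.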

The move from PyTorch or TensorFlow 2 to JAX is nothing short of tectonic. PyTorch builds up a graph during the forward pass and computes gradients during the backward pass. JAX instead lets the user express a computation as a plain Python function; transforming it with grad() yields a gradient function that can be evaluated just like the original, except that instead of the output it returns the gradient of the output with respect to the function’s first parameter.
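The first-parameter convention can be illustrated with a made-up loss function (the names here are hypothetical, not from the article):

```python
import jax

# grad differentiates with respect to the FIRST argument (w) by default
def loss(w, x):
    return (w * x) ** 2

dloss_dw = jax.grad(loss)             # analytically: 2 * w * x**2
print(dloss_dw(2.0, 3.0))             # 36.0

# argnums selects a different argument instead
dloss_dx = jax.grad(loss, argnums=1)  # analytically: 2 * w**2 * x
print(dloss_dx(2.0, 3.0))             # 24.0
```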

JAX vs TensorFlow vs PyTorch

While TensorFlow and PyTorch have compiled execution modes, these modes were added later on and have left their scars. For instance, TensorFlow’s eager mode is not 100% compatible with its graph mode, making for a poor developer experience. PyTorch carries a history of being forced into less intuitive tensor formats because its operations ran eagerly. JAX ships with both modes from the start, focusing on eager execution for debugging and on JIT for heavy computations. Because the two modes are cleanly separated, they can be mixed and matched whenever needed.

PyTorch and TensorFlow are deep learning libraries with high-level APIs for modern deep learning methods. JAX, in comparison, is a more functionally minded library for arbitrary differentiable programming.

DZone conducted a mini-experiment to study how JAX stacks up against other libraries. A single-hidden-layer MLP, along with a training loop to “fit” a classification problem of random noise, was implemented in JAX, Autograd, TensorFlow 2.0 and PyTorch. This baseline was used to compare the performance efficiency of each library.

While JAX uses just-in-time compilation for library calls, the jit function transformation can also be applied as a decorator on custom Python functions. For instance, here is a snippet:

# use jit as a decorator on a function definition
@jit
def get_loss(x, w, y_tgts):
    y_pred = forward(x, w)
    return ce_loss(y_tgts, y_pred)

# use jit as a function for transforming an already defined
# function into a just-in-time compiled function
get_grad = grad(get_loss, argnums=1)
jit_grad = jit(get_grad)

JAX runtimes

The experimenters implemented a simple multi-layer perceptron in each library: a sequence of weighted connections, equivalent to matrix multiplications of input tensors with weight matrices. The results showed JAX dominating the experiment. It had the fastest CPU execution time of any library and the shortest execution time for implementations restricted to matrix multiplication.

The experiment also found that while JAX leads on raw matmul, PyTorch leads with linear layers. PyTorch executed quickly on the GPU: with linear layers and a batch size of 16,384 it took 9.9 seconds, which corresponds to JAX running with JIT at a batch size of 1,024. When taking advantage of higher-level neural network APIs, PyTorch was fastest, followed by JAX and TensorFlow; for fully connected layers, PyTorch’s execution was faster than TensorFlow’s. JAX, meanwhile, offered impressive speed-ups of an order of magnitude or more over the comparable Autograd library, and was the fastest when the MLP implementation was limited to matrix multiplication operations.
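The matmul-only MLP described above can be sketched roughly as follows; the layer sizes, names and MSE loss are illustrative assumptions, not the experimenters’ exact code:

```python
import jax.numpy as jnp
from jax import grad, jit, random

def forward(x, params):
    # single hidden layer: two matrix multiplications with a tanh nonlinearity
    w1, w2 = params
    h = jnp.tanh(x @ w1)
    return h @ w2

def mse_loss(params, x, y):
    return jnp.mean((forward(x, params) - y) ** 2)

# jit-compiled gradient of the loss with respect to the parameters
grad_fn = jit(grad(mse_loss))

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
x = random.normal(k1, (16, 8))               # batch of 16 inputs, 8 features
params = (random.normal(k2, (8, 32)) * 0.1,  # input -> hidden weights
          random.normal(k3, (32, 1)) * 0.1)  # hidden -> output weights
y = jnp.zeros((16, 1))

grads = grad_fn(params, x, y)
print(grads[0].shape, grads[1].shape)        # (8, 32) (32, 1)
```

grad_fn returns a tuple of gradient arrays matching the parameter shapes, which a training loop would use for a gradient-descent update.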

Avi Gopani
Avi Gopani is a technology journalist that seeks to analyse industry trends and developments from an interdisciplinary perspective at Analytics India Magazine. Her articles chronicle cultural, political and social stories that are curated with a focus on the evolving technologies of artificial intelligence and data analytics.
