Why is Google’s JAX so popular?

Despite JAX’s reputation for speed, it isn’t fully optimised for per-operation dispatch on CPUs.

Apart from TensorFlow and PyTorch, Google’s JAX has become an increasingly popular framework, and with good reason. JAX was developed to accelerate machine learning tasks and to offer a NumPy-like interface that also runs on accelerators. Although deep learning is only a subset of what JAX can do, JAX gained ground after Google used it for its Vision Transformer (ViT) and DeepMind engineers published a blog post explaining why it suited several of their projects. Put simply, JAX is a high-performance Python library for numerical computing, especially in research.

There are good reasons both to use JAX and to avoid it. Let’s weigh them up:

Why JAX? 

Speed: All JAX operations are built on XLA (Accelerated Linear Algebra), which is the main source of JAX’s speed. Also developed by Google, XLA is a domain-specific compiler for linear algebra that uses whole-program optimisations to accelerate computation. XLA speeds up BERT training by roughly 7.3 times. More importantly, XLA lowers memory usage, which makes gradient accumulation possible and results in a roughly 12-fold increase in computational throughput.

Source: AssemblyAI 

JAX also lets users transform their functions into just-in-time (JIT) compiled versions. Adding a simple function decorator compiles the function on its first call, which speeds up subsequent executions. However, not every function can be JIT-compiled; the JAX documentation describes the exceptions.

Source: TensorFlow
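A minimal sketch of the decorator pattern, assuming a recent JAX release. The `selu` function is similar to the example in JAX’s quickstart; `relu_branchy` is a hypothetical helper illustrating one common exception: Python control flow that depends on the value of a traced argument.

```python
import jax
import jax.numpy as jnp

# @jax.jit traces the function once per input shape/dtype and compiles the
# trace with XLA; later calls with matching inputs reuse the compiled code.
@jax.jit
def selu(x, alpha=1.67, lam=1.05):
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
y = selu(x)  # first call: trace + compile

# Hypothetical helper: the `if` needs a concrete Python bool, but under jit
# `x > 0` is an abstract tracer, so compilation fails.
def relu_branchy(x):
    if x > 0:
        return x
    return 0.0

try:
    jax.jit(relu_branchy)(1.0)
    branch_failed = False
except jax.errors.ConcretizationTypeError:  # raised when a tracer is used as a bool
    branch_failed = True

# One documented workaround: mark the argument as static, so JAX compiles a
# separate version for each distinct concrete value instead of tracing it.
relu_static = jax.jit(relu_branchy, static_argnums=0)
print(relu_static(2.0), relu_static(-3.0))
```

Marking an argument as static means one compilation per distinct value, which suits a few configuration flags but not data; `jax.lax.cond` is the other documented route for value-dependent branches.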

Compatibility with GPUs: NumPy runs only on CPUs, whereas JAX runs on CPUs, GPUs and TPUs and offers an API that closely mirrors NumPy’s. Because XLA compiles the same program for whichever backend is available, JAX code can run on accelerators without changes. A user can write code once in NumPy-like syntax, try it out on a CPU and then move it to a GPU cluster smoothly.
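A small sketch of the NumPy-to-JAX path. The hypothetical `normalise` helper takes the array module as a parameter, so the same body runs through NumPy on the CPU or through `jax.numpy` on whatever device JAX selects. Which devices `jax.devices()` reports depends on the installed JAX build and the hardware.

```python
import numpy as np
import jax
import jax.numpy as jnp

def normalise(xp, data):
    # `xp` is whichever array module is passed in: numpy or jax.numpy.
    data = xp.asarray(data, dtype=xp.float32)
    return (data - data.mean()) / data.std()

data = [1.0, 2.0, 3.0, 4.0]
out_np = normalise(np, data)    # NumPy, always on the CPU
out_jax = normalise(jnp, data)  # JAX default device: GPU/TPU if present, else CPU

# The backends JAX found, and an explicit transfer to the first one.
print(jax.devices())
x = jax.device_put(jnp.ones(3), jax.devices()[0])
```

Not all NumPy code ports unchanged: JAX arrays are immutable, so an in-place update such as `a[0] = 1` becomes `a = a.at[0].set(1)`.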

Automatic differentiation: JAX can automatically differentiate native Python and NumPy functions. Most optimisation algorithms in machine learning use the gradients of a loss function to minimise it, and JAX simplifies computing those gradients with an updated version of Autograd.
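A minimal sketch of `jax.grad` on a toy linear-regression loss. The model, data and learning rate are illustrative assumptions, not taken from any particular project.

```python
import jax
import jax.numpy as jnp

# Mean-squared-error loss of a toy linear model: pred = w * x + b.
def loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function that computes d(loss)/d(params),
# i.e. the gradient with respect to the first argument.
grad_loss = jax.grad(loss)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])  # generated by w = 2, b = 0
params = (1.0, 0.0)

loss_before = loss(params, x, y)
dw, db = grad_loss(params, x, y)

# One step of plain gradient descent.
lr = 0.1
params = (params[0] - lr * dw, params[1] - lr * db)
loss_after = loss(params, x, y)
print(dw, db, loss_before, loss_after)
```

By hand, at `w = 1, b = 0` the gradients are `-28/3` for `w` and `-4` for `b`, and the single update lowers the loss.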

Vectorisation: JAX offers automatic vectorisation through the vmap transformation, which makes life easier for developers. In ML research, a single function is often applied to a lot of data, for example to calculate losses across a batch or to evaluate per-example gradients for differentially private learning. When the workload is too large for a single accelerator, the related pmap transformation lets JAX perform data parallelism at scale across multiple devices.

Source: Developpaper.com
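A sketch of per-example gradients with `vmap`, assuming a toy squared-error loss on a linear model. `in_axes=(None, 0, 0)` tells `vmap` to share the weights across the batch while mapping over the leading axis of the inputs and targets.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error on a SINGLE example: x is a feature vector, y a scalar.
    return (jnp.dot(w, x) - y) ** 2

grad_one = jax.grad(loss)  # gradient for one example

# Map grad_one over the batch: w is broadcast (None), xs and ys are split on axis 0.
per_example_grads = jax.vmap(grad_one, in_axes=(None, 0, 0))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 1.0])

g = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient row per example
print(g)
```

`jax.pmap` uses a similar calling pattern but maps over devices instead of array elements, so the mapped axis cannot exceed the number of available devices; it is left out because this sketch targets a single machine.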

Deep learning: While JAX is not solely a deep learning framework, it has proven to be a solid foundation for deep learning. Libraries such as Flax, Haiku and Elegy are built on top of JAX. Higher-order optimisation techniques in deep learning rely on Hessians, which JAX computes efficiently; thanks to XLA, JAX can compute Hessians much faster than PyTorch.
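A minimal sketch of `jax.hessian` on small hand-picked functions, chosen so the results can be checked by hand. The Newton step at the end is one example of a higher-order optimisation method; for a quadratic objective it reaches the minimiser in a single step.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x0^2 * x1 + 3 * x1^3
    return x[0] ** 2 * x[1] + 3.0 * x[1] ** 3

H = jax.hessian(f)(jnp.array([1.0, 2.0]))  # [[2*x1, 2*x0], [2*x0, 18*x1]]

# Newton's method on a quadratic q(x) = 0.5 x^T A x - b^T x (A symmetric).
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 1.0])

def q(x):
    return 0.5 * x @ A @ x - b @ x

x0 = jnp.zeros(2)
g = jax.grad(q)(x0)      # gradient at x0
Hq = jax.hessian(q)(x0)  # equals A for this quadratic
x1 = x0 - jnp.linalg.solve(Hq, g)  # Newton step: x - H^{-1} g
print(H, x1)
```

A full Hessian grows quadratically with the number of parameters, so large models usually rely on Hessian-vector products or curvature approximations instead of the explicit matrix.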

Why not JAX? 

  • Despite JAX’s reputation for speed, it isn’t fully optimised for per-operation dispatch on CPUs. NumPy may, in fact, be faster than JAX in some cases, especially in small programs, because of JAX’s dispatch overhead.
  • JAX is a relatively new framework and lacks the complete ecosystem that TensorFlow has built over the years. TensorFlow offers more portable deployment and a wider range of resources, such as open-source projects, pre-trained models, tutorials and higher-level abstractions through Keras. JAX is still a research project, and Google does not promote it as a fully formed, final product.
  • The time and cost of debugging, and the risk of untracked side effects, may make JAX a risky bet for new developers. JAX’s transformations assume pure functions; side effects in an impure function, such as modifying a variable in an outer scope, may not show up when the transformed function runs, which makes them hard to track. In industries like healthcare, this can have serious consequences.
  • At the time of writing, JAX does not officially support Windows. Using it on a Windows system generally requires a virtual machine.
  • JAX doesn’t include a data loader. Users need to write their own or borrow one from TensorFlow or PyTorch.
  • JAX is adept at working with lower-level functions that are appropriate for research projects. However, JAX isn’t built for higher-level model abstractions.

Source: Developpaper.com
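A small sketch of why side effects are easy to miss, assuming default `jax.jit` caching behaviour: Python code inside a jitted function runs while JAX traces it, and later calls with the same input shapes and dtypes reuse the compiled version without re-running that Python code. The `calls` list is a hypothetical stand-in for outer-scope state.

```python
import jax
import jax.numpy as jnp

calls = []  # hypothetical outer-scope state mutated by an impure function

@jax.jit
def scaled(x):
    calls.append("called")     # side effect: runs only while JAX traces the function
    print("tracing with", x)   # also trace-time only; prints a tracer, not values
    return 2.0 * x

scaled(jnp.ones(3))
scaled(jnp.ones(3))  # same shape/dtype: compiled code reused, no Python execution
scaled(jnp.ones(3))
after_three = len(calls)  # one trace, so one append despite three calls

scaled(jnp.ones(4))  # new shape: JAX retraces, so the side effect runs again
print(after_three, len(calls))
```

For debugging, `jax.debug.print` prints runtime values from inside transformed functions, unlike the plain `print` above.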

Although JAX’s adoption is still growing, it is already used in a range of projects beyond deep learning, such as Bayesian methods and robotics. Last week, DeepMind announced four new libraries joining its ecosystem: Mctx, which provides AlphaZero and MuZero Monte Carlo tree search; KFAC-JAX, a library for second-order optimisation of neural networks and for computing scalable curvature approximations; DM_AUX, for audio signal processing in JAX, with tools for spectrogram extraction and SpecAugment augmentation; and TF2JAX, a library for converting TensorFlow functions and graphs into JAX functions.

Poulomi Chatterjee
Poulomi is a Technology Journalist with Analytics India Magazine. Her fascination with tech and eagerness to dive into new areas led her to the dynamic world of AI and data analytics.
