
Top 8 JAX Libraries for Data Scientists in 2021

- EMLP - FedJAX - PIX - RLax - BRAX - EFAX - Jax-unirep - Sklearn-jax-kernels

JAX is a Python library designed for high-performance numerical computing. Developed by Google researchers and launched in 2018, it is now widely used at Alphabet subsidiary DeepMind. Its API for numerical functions is closely modelled on NumPy, the standard numerical computing library for Python.

It is also a framework for automatic differentiation, similar in spirit to TensorFlow or PyTorch.



Why JAX? 

JAX is promising for machine learning scientists because it makes machine learning programming intuitive, structured and clean. It includes an extensible system of composable function transformations that help machine learning researchers with the following:

  • Differentiation: JAX supports both forward- and reverse-mode automatic differentiation of arbitrary numerical functions. 
  • Vectorisation: JAX provides automatic vectorisation via the vmap transformation, and supports large-scale data parallelism via the related pmap transformation. 
  • JIT-compilation: Just-in-time (JIT) compilation, together with JAX's NumPy-consistent API, allows researchers to scale to one or many accelerators. 
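As a toy illustration of these three transformations (the function and names below are our own, not from any particular library), each one wraps an ordinary Python function:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # A toy quadratic loss: sum of (w * x)^2.
    return jnp.sum((w * x) ** 2)

grad_loss = jax.grad(loss)                    # reverse-mode gradient w.r.t. w
batched_square = jax.vmap(lambda x: x ** 2)   # vectorise over a leading batch axis
fast_loss = jax.jit(loss)                     # JIT-compile to XLA for CPU/GPU/TPU
```

The transformations compose freely; for example, jax.jit(jax.grad(loss)) yields a compiled gradient function.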

Today, we take a look at some of the recent JAX libraries:


EMLP

Equivariant-MLP (EMLP) is a JAX library for the automated construction of equivariant layers in deep learning via constraint solving. It is based on the ICML 2021 paper A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups. 

It can also be used as a tool for building larger models, with an EMLP block serving as just one component in a bigger system. 

To install the package, one needs to use: 

pip install emlp



FedJAX

FedJAX is a library for developing custom federated learning algorithms in JAX. It prioritises ease of use for anyone familiar with NumPy. 

FedJAX is built around the components of: 

  • Federated datasets: clients and a dataset for each client 
  • Models: CNN, ResNet 
  • Optimisers: SGD, Momentum 
  • Federated algorithms: Client updates and server aggregation 

FedJAX uses lightweight wrappers and containers to work with existing implementations such as Haiku, Stax, and Optax. 
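To make the client-update/server-aggregation split concrete, here is a minimal federated-averaging sketch in plain NumPy. This illustrates the general FedAvg idea only; it is not the actual FedJAX API:

```python
import numpy as np

def client_update(params, data, lr=0.1):
    # One local SGD step on a linear least-squares model: minimise ||X @ w - y||^2.
    X, y = data
    grad = 2 * X.T @ (X @ params - y) / len(y)
    return params - lr * grad

def server_aggregate(client_params, client_sizes):
    # FedAvg: average client models, weighted by each client's dataset size.
    weights = np.array(client_sizes) / sum(client_sizes)
    return sum(w * p for w, p in zip(weights, client_params))
```

A federated round then consists of running client_update on each client's local data and combining the results with server_aggregate.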

To install FedJAX, one needs Python 3.6+ and a working JAX installation. For a CPU-only build of JAX: 

pip install --upgrade pip

pip install --upgrade jax jaxlib  

(GPU and TPU machines require a matching accelerator build of jaxlib; see the JAX documentation.) FedJAX itself can then be installed with: 

pip install fedjax

FedJAX is still in its early stages and is yet to be officially supported by Google. 



PIX

PIX is an image processing library in JAX and for JAX. While it is written in pure Python, PIX depends on C++ code via JAX.

Built on top of JAX, PIX provides image processing functions and tools that can be optimised and parallelised. 
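Typical operations in such a library include flips and brightness adjustments. The sketch below reimplements two such functions in plain NumPy to show what they compute; the function names are ours, not PIX's API:

```python
import numpy as np

def flip_left_right(image):
    # Mirror a height x width x channels image along its width axis.
    return image[:, ::-1, :]

def adjust_brightness(image, delta):
    # Add a constant offset to every pixel and clip to the valid [0, 1] range.
    return np.clip(image + delta, 0.0, 1.0)
```

Written with jax.numpy instead, the same operations could be jit-compiled and vmapped over batches of images.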

To install PIX use: 

pip install git+



RLax

Built on top of JAX, RLax exposes useful building blocks for implementing reinforcement learning agents. 
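A representative building block of this kind is the one-step temporal-difference (TD) error. Below is a plain-Python sketch of the computation, not the RLax API itself:

```python
def td_error(reward, discount, v_s, v_next):
    # One-step TD error: r + gamma * V(s') - V(s).
    # Value-learning agents nudge V(s) toward the bootstrapped target
    # r + gamma * V(s'), so this quantity drives the update.
    return reward + discount * v_next - v_s
```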

It can be installed with pip directly from GitHub using: 

pip install git+git:// 

Or from PyPI: 

pip install rlax



BRAX

BRAX is a differentiable physics engine that simulates environments made of rigid bodies, joints and actuators. Written in JAX and designed for use on accelerator hardware, BRAX also ships with a suite of learning algorithms for training agents to operate in those environments. 

BRAX is not only efficient for single-core training but is also scalable to massively parallel simulation. 
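What "differentiable physics" buys you can be seen in miniature: if a rollout is written as a JAX function, gradients flow through the whole simulation. The toy example below is our own and far simpler than BRAX's rigid-body dynamics; it differentiates the final height of a projectile with respect to its launch velocity:

```python
import jax
import jax.numpy as jnp

def simulate(v0, steps=100, dt=0.01, g=9.81):
    # Euler-integrate a point mass launched upward; return its final height.
    def step(carry, _):
        pos, vel = carry
        return (pos + dt * vel, vel - dt * g), None
    (pos, _), _ = jax.lax.scan(step, (0.0, v0), None, length=steps)
    return pos

# Gradient of the final height with respect to the initial velocity.
# For Euler integration this is steps * dt, i.e. the elapsed time.
dheight_dv0 = jax.grad(simulate)(10.0)
```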

One can install BRAX from source using: 

python3 -m venv env

source env/bin/activate

pip install --upgrade pip

pip install -e .

To train a model, use the bundled learn script. 



EFAX

EFAX provides tools for working with the exponential family of probability distributions in JAX, a class that includes the normal, gamma, beta, exponential, Poisson, binomial and Bernoulli distributions. 

EFAX exposes both the natural and the expectation parametrisation of each distribution, which is one reason some developers prefer it over a library such as TensorFlow Probability. 
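For the Bernoulli distribution, for example, the expectation parameter is the success probability p, while the natural parameter is the log-odds. The conversion below sketches what these two parametrisations mean in plain Python; it is not EFAX's API:

```python
import math

def natural_from_expectation(p):
    # Bernoulli natural parameter: the log-odds log(p / (1 - p)).
    return math.log(p / (1 - p))

def expectation_from_natural(theta):
    # Inverse map: the sigmoid of the natural parameter recovers p.
    return 1.0 / (1.0 + math.exp(-theta))
```

Some operations (e.g. maximum likelihood from sufficient statistics) are most natural in one parametrisation, others (e.g. sampling) in the other, which is why having both is convenient.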



Jax-unirep

Jax-unirep is a performant reimplementation of the UniRep protein featurisation model in JAX. Developed in George Church's lab, it is a self-contained and easily customisable version of the UniRep model, with additional utility APIs that support protein engineering workflows. 

To install Jax-unirep, one has to ensure that their compute environment can run JAX code. Additionally, a modern Linux or macOS system with glibc >= 2.23 is necessary. 

Jax-unirep can be installed from PyPI using the following code:

pip install jax-unirep

Or directly from the source using: 

pip install git+



Sklearn-jax-kernels

Sklearn-jax-kernels provides scikit-learn-compatible kernels implemented on top of JAX, enabling accelerated kernel computations through XLA optimisation, computation on GPUs, and the computation of gradients through kernels. 

It offers the same flexibility and ease of use as scikit-learn kernels while improving speed and allowing faster design of new kernels through automatic differentiation. 
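As an illustration of the kind of kernel such a library accelerates, here is a vectorised RBF (Gaussian) kernel in plain NumPy; this is our own sketch, not the library's code:

```python
import numpy as np

def rbf_kernel(X, Y, gamma=1.0):
    # K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2), computed via broadcasting
    # rather than explicit loops over the rows of X and Y.
    sq_dists = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * sq_dists)
```

Writing the same kernel with jax.numpy would let XLA fuse the computation and jax.grad differentiate through it, which is precisely the speed-up this library targets.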

To install via pip use: 

pip install sklearn-jax-kernels



Debolina Biswas
After diving deep into the Indian startup ecosystem, Debolina is now a technology journalist. When not writing, she is found reading or playing with paint brushes and palette knives.
