The JAX libraries & frameworks for reinforcement learning

In this article, we’ll explore different libraries and frameworks for reinforcement learning using JAX.

JAX is a machine learning and deep learning library developed by Google. JAX compiles and runs its operations through XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra, also developed by Google, that uses whole-program optimisations to accelerate computation. XLA has been reported to make BERT training almost 7.3 times faster.

JAX is designed for high-performance numerical computing. It was first released in 2018 and is used by, among others, Alphabet subsidiary DeepMind. Its API for numerical functions is closely modelled on NumPy, the standard numerical computing library for Python.




JAX reinforcement learning agents

1. RLax

RLax (pronounced ‘relax’) is a simple library built on top of JAX that provides useful building blocks for implementing reinforcement learning agents.

It can be installed directly from GitHub or PyPI.

All RLax code may be compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.
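To make the jax.jit point concrete, here is a minimal sketch of one such building block, the Q-learning temporal-difference (TD) error, written in plain JAX. RLax ships its own functions for this kind of computation; the code below is an illustrative reimplementation, not RLax's API, and the function name q_learning_td_error is hypothetical.

```python
import jax
import jax.numpy as jnp


def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """TD error for one transition: r + discount * max_a Q(s', a) - Q(s, a)."""
    target = r_t + discount_t * jnp.max(q_t)
    # Treat the bootstrap target as a constant when differentiating.
    target = jax.lax.stop_gradient(target)
    return target - q_tm1[a_tm1]


# vmap adds a batch dimension; jit compiles the batched function with XLA
# for whichever backend (CPU, GPU or TPU) JAX is running on.
batched_td_error = jax.jit(jax.vmap(q_learning_td_error))

q_tm1 = jnp.array([[1.0, 2.0], [0.0, 0.0]])   # Q(s, .) for two transitions
a_tm1 = jnp.array([1, 0])                      # actions taken in s
r_t = jnp.array([1.0, 0.0])                    # rewards received
discount_t = jnp.array([0.9, 0.0])             # 0.0 marks a terminal step
q_t = jnp.array([[0.0, 3.0], [5.0, 5.0]])      # Q(s', .)

td = batched_td_error(q_tm1, a_tm1, r_t, discount_t, q_t)
print(td)  # approximately [1.7, 0.0]
```

The same compiled function runs unchanged on CPU, GPU or TPU; only the device JAX is configured to use changes.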


2. Haiku

Haiku is a neural network library for JAX. It lets users work with familiar object-oriented programming models while keeping full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

It is written in pure Python but depends on C++ code via JAX.
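hk.transform takes a function that creates parameters as it runs and turns it into a pair of pure functions, init and apply. The sketch below writes such a pair by hand for a single linear layer in plain JAX, so the pattern is visible without installing Haiku; it is not Haiku code, and the names are hypothetical. One difference: Haiku's own apply also takes an RNG key argument unless the transform is wrapped with hk.without_apply_rng.

```python
import jax
import jax.numpy as jnp


def init(rng, x, out_features=2):
    """Create parameters from a random key and an example input."""
    in_features = x.shape[-1]
    w = jax.random.normal(rng, (in_features, out_features)) * 0.1
    b = jnp.zeros((out_features,))
    return {"linear": {"w": w, "b": b}}


def apply(params, x):
    """Pure forward pass: the output depends only on params and x."""
    return x @ params["linear"]["w"] + params["linear"]["b"]


x = jnp.ones((3, 4))                          # batch of 3 inputs, 4 features each
params = init(jax.random.PRNGKey(42), x)
y = jax.jit(apply)(params, x)                 # pure functions compose with jit...
grads = jax.grad(lambda p: jnp.sum(apply(p, x) ** 2))(params)  # ...and with grad
print(y.shape)  # (3, 2)
```

Because the parameters live in an explicit dictionary rather than inside an object, they can be passed to jax.grad, updated and stored like any other array data.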


3. Gymnax

Gymnax provides JAX-compatible versions of environments from OpenAI's Gym. Gym is an open-source Python library for developing and comparing RL algorithms. It provides a standard API for communication between learning algorithms and environments, and since its release that API has become the field standard.
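JAX transformations such as jit and vmap need pure functions, so JAX-native environments keep their state outside the environment object: reset returns an initial state, and step takes the current state and returns the next one. The toy chain environment below is a hypothetical, minimal sketch of that functional style (it is not part of Gymnax); it shows the payoff, which is that jax.vmap can step a whole batch of environments in one compiled call. Gymnax's own environments follow a broadly similar shape: reset takes a PRNG key, and step takes a key, the current state and an action (plus environment parameters).

```python
import jax
import jax.numpy as jnp

CHAIN_LENGTH = 5  # hypothetical environment size


def reset(key):
    """Start every episode at position 0; the key is kept for API symmetry."""
    del key  # this toy environment is deterministic
    state = jnp.array(0, dtype=jnp.int32)
    return state.astype(jnp.float32), state


def step(key, state, action):
    """Move left (action 0) or right (action 1); reward 1 at the last cell."""
    del key
    new_state = jnp.clip(state + 2 * action - 1, 0, CHAIN_LENGTH - 1)
    done = new_state == CHAIN_LENGTH - 1
    reward = done.astype(jnp.float32)
    return new_state.astype(jnp.float32), new_state, reward, done


keys = jax.random.split(jax.random.PRNGKey(0), 3)
init_obs, init_states = jax.vmap(reset)(keys)      # three fresh environments

# Step a batch of three environments (in arbitrary states) in one compiled call.
batched_step = jax.jit(jax.vmap(step))
states = jnp.array([0, 3, 4], dtype=jnp.int32)
actions = jnp.array([1, 1, 0], dtype=jnp.int32)
obs, next_states, rewards, dones = batched_step(keys, states, actions)
# Expected values: next_states 1, 4, 3; rewards 0, 1, 0; dones False, True, False
```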


4. Dopamine

Dopamine is a research framework for fast prototyping of RL algorithms. It aims to fill the need for a small, easy-to-understand codebase in which users can freely experiment with wild ideas (speculative research).

The easiest way to use Dopamine is to install it from source and modify the source code directly. Since 2020, Dopamine has supported JAX agents, including an implementation of the Quantile Regression agent (QR-DQN).

It can also be installed with pip. Dopamine also supports Atari and MuJoCo environments.
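QR-DQN, one of Dopamine's JAX agents, represents the return distribution with N quantile estimates and trains them with a quantile Huber loss. The sketch below is a minimal pure-JAX version of that loss for a single state-action pair, assuming kappa = 1 by default; it is not Dopamine's implementation, and the function names are illustrative.

```python
import jax
import jax.numpy as jnp


def huber(u, kappa=1.0):
    """Huber loss: quadratic for |u| <= kappa, linear beyond it."""
    abs_u = jnp.abs(u)
    return jnp.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))


def quantile_huber_loss(pred_quantiles, target_samples, kappa=1.0):
    """Quantile Huber loss for one state-action pair.

    pred_quantiles: shape (N,), return estimates at the quantile midpoints
        tau_i = (2i + 1) / (2N) for i = 0..N-1.
    target_samples: shape (M,), Bellman-target samples.
    """
    target_samples = jax.lax.stop_gradient(target_samples)  # targets are constants
    n = pred_quantiles.shape[0]
    taus = (2.0 * jnp.arange(n) + 1.0) / (2.0 * n)            # (N,)
    # Pairwise errors u_ij = target_j - pred_i, shape (N, M).
    u = target_samples[None, :] - pred_quantiles[:, None]
    # Asymmetric weight |tau_i - 1{u_ij < 0}| turns the Huber loss into a quantile loss.
    weight = jnp.abs(taus[:, None] - (u < 0).astype(u.dtype))
    # Average over target samples, sum over predicted quantiles.
    return jnp.sum(jnp.mean(weight * huber(u, kappa), axis=1))


loss_fn = jax.jit(quantile_huber_loss)
print(loss_fn(jnp.array([0.0]), jnp.array([2.0])))  # 0.75

# Gradients flow only into the predicted quantiles.
grads = jax.grad(quantile_huber_loss)(jnp.array([0.0, 1.0]), jnp.array([0.5, 2.0]))
```

With a single quantile (tau = 0.5) and an error of 2, the Huber term is 1.5 and the weight is 0.5, giving 0.75; underestimates and overestimates are weighted differently once N > 1, which is what pushes each estimate towards its own quantile.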


5. Flax

Launched in 2020, Flax is a high-performance neural network library for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.

It is developed in close collaboration with the JAX team. The basic philosophy behind Flax is that library code should be easy to read and understand.

At its core, Flax is built around parameterised functions called Modules, which can be used as normal functions.

Flax's repository is on GitHub, under Google Research.

6. coax

coax is a Python package for solving OpenAI Gym environments with JAX-based function approximators.

It is designed to align with core RL concepts rather than with the high-level concept of an agent, which makes coax more modular and user-friendly for RL researchers and practitioners.
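To illustrate what organising code around RL concepts rather than agents can look like, the sketch below builds three separate, swappable pieces in plain JAX: a Q-function approximator, an epsilon-greedy policy derived from any Q-function, and a Q-learning update rule. It is not coax's API (coax provides its own classes for these roles), and every function name here is hypothetical.

```python
import jax
import jax.numpy as jnp


def linear_q_function(params, obs):
    """Function approximator: one Q-value per action from a feature vector."""
    return obs @ params["w"]


def make_epsilon_greedy(q_fn, epsilon):
    """Policy built from any Q-function; it knows nothing about learning."""
    def policy(params, obs, key):
        explore_key, action_key = jax.random.split(key)
        q_values = q_fn(params, obs)
        random_action = jax.random.randint(action_key, (), 0, q_values.shape[-1])
        greedy_action = jnp.argmax(q_values)
        explore = jax.random.uniform(explore_key) < epsilon
        return jnp.where(explore, random_action, greedy_action)
    return policy


def make_q_learning_update(q_fn, learning_rate, gamma):
    """Update rule for any Q-function: one SGD step on the squared TD error."""
    def loss(params, obs, action, reward, next_obs, done):
        target = reward + gamma * (1.0 - done) * jnp.max(q_fn(params, next_obs))
        target = jax.lax.stop_gradient(target)  # do not differentiate the target
        return 0.5 * (target - q_fn(params, obs)[action]) ** 2

    def update(params, obs, action, reward, next_obs, done):
        grads = jax.grad(loss)(params, obs, action, reward, next_obs, done)
        return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return update


# Wire the pieces together for a toy problem with 3 features and 2 actions.
params = {"w": jnp.zeros((3, 2))}
policy = make_epsilon_greedy(linear_q_function, epsilon=0.0)  # purely greedy
update = jax.jit(make_q_learning_update(linear_q_function, learning_rate=0.5, gamma=0.9))

obs = jnp.array([1.0, 0.0, 0.0])
params = update(params, obs, 1, 1.0, obs, 1.0)  # action 1 earned reward 1, episode ended
action = policy(params, obs, jax.random.PRNGKey(0))
print(int(action))  # 1
```

Swapping the linear approximator for a neural network, or the update rule for a different TD method, leaves the other pieces untouched; that independence is the design choice the paragraph above describes.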


7. Acme

Acme is a library of reinforcement learning (RL) building blocks that strives to expose simple agents. These agents serve both as reference implementations and as strong baselines for algorithm performance. Acme's building blocks are designed so that agents can be written at multiple scales.

It supports both TensorFlow v2 and JAX. 
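Acme describes agents in terms of components such as actors, which interact with the environment, and learners, which update parameters from experience; this separation is what allows the same agent to be written at different scales. The sketch below mimics that split in a single process on a toy two-armed bandit. It is not Acme code, and every class and method name is hypothetical.

```python
import jax
import jax.numpy as jnp


class Learner:
    """Owns the parameters (here, one value estimate per arm) and updates them."""

    def __init__(self, num_actions, learning_rate=0.1):
        self.values = jnp.zeros(num_actions)
        self.learning_rate = learning_rate

    def step(self, batch):
        # Move each arm's estimate towards the rewards observed for it.
        for action, reward in batch:
            error = reward - self.values[action]
            self.values = self.values.at[action].add(self.learning_rate * error)

    def get_variables(self):
        return self.values


class Actor:
    """Chooses actions with a local copy of the learner's parameters."""

    def __init__(self, learner, epsilon, seed=0):
        self.learner = learner
        self.epsilon = epsilon
        self.key = jax.random.PRNGKey(seed)
        self.values = learner.get_variables()

    def select_action(self):
        self.key, explore_key, action_key = jax.random.split(self.key, 3)
        if jax.random.uniform(explore_key) < self.epsilon:
            return int(jax.random.randint(action_key, (), 0, self.values.shape[0]))
        return int(jnp.argmax(self.values))

    def sync(self):
        # Pull the latest parameters; in a distributed setup this might be a remote call.
        self.values = self.learner.get_variables()


arm_rewards = jnp.array([0.0, 1.0])   # hypothetical bandit: arm 1 always pays 1
learner = Learner(num_actions=2)
actor = Actor(learner, epsilon=0.5)
replay = []                           # experience passed from actor to learner

for _ in range(50):
    action = actor.select_action()
    replay.append((action, float(arm_rewards[action])))
    if len(replay) == 10:             # learner trains on batches of 10 transitions
        learner.step(replay)
        replay = []
        actor.sync()

print(learner.get_variables())
```

Because the actor only talks to the learner through get_variables and the replay list, the same two classes could in principle be placed in separate processes without changing their logic.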



Tasmia Ansari
Tasmia is a tech journalist at AIM, looking to bring a fresh perspective to emerging technologies and trends in data science, analytics, and artificial intelligence.