The JAX libraries & frameworks for reinforcement learning

In this article, we’ll explore different libraries and frameworks for reinforcement learning using JAX.

JAX is a Python library for machine learning research developed by Google. JAX programs are compiled with XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra, also developed by Google, that uses whole-program optimisations to accelerate computation. Google reports that XLA made BERT training roughly seven times faster in an MLPerf submission.

It is designed for high-performance numerical computing. JAX was released in 2018 and is used by Alphabet subsidiary DeepMind. Its API for numerical functions, jax.numpy, closely follows NumPy, the standard numerical computing library for Python.
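As a small illustration of that NumPy-style API, the sketch below evaluates the same expression with NumPy and jax.numpy, then uses two of JAX's function transformations: jax.grad to differentiate a loss and jax.jit to compile the gradient with XLA. The function mse and the data are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The same array expression works in NumPy and in jax.numpy.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)
total_np = np.sum(x_np ** 2)
total_jnp = jnp.sum(x_jnp ** 2)


def mse(w, x, y):
    """Mean squared error of the one-parameter linear model y ~ w * x."""
    return jnp.mean((w * x - y) ** 2)


# jax.grad builds a new function computing d(mse)/dw with respect to the
# first argument; jax.jit compiles that function with XLA on its first call.
grad_mse = jax.jit(jax.grad(mse))

y = 2.0 * x_jnp
g = grad_mse(1.0, x_jnp, y)  # -2 * mean(x ** 2), i.e. about -0.75
```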


JAX libraries and frameworks for reinforcement learning

1. RLax

RLax (pronounced ‘relax’) is a library built on top of JAX that provides useful building blocks for implementing reinforcement learning agents.

It can be installed directly from GitHub or PyPI.

All RLax code may be compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.
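To show what such a building block looks like, here is a minimal plain-JAX sketch of a one-step Q-learning TD error, written by hand rather than calling RLax. As far as we know, RLax's own rlax.q_learning uses the same argument convention (q_tm1, a_tm1, r_t, discount_t, q_t); the function name q_learning_td_error and the batch of transitions are illustrative. Because the function is pure, jax.vmap can add a batch dimension and jax.jit can compile the result for whichever backend is available.

```python
import jax
import jax.numpy as jnp


def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """One-step Q-learning TD error for a single transition.

    q_tm1: action values at time t-1, shape [num_actions]
    a_tm1: integer action taken at time t-1
    r_t, discount_t: reward and discount received at time t
    q_t: action values at time t, shape [num_actions]
    """
    # Bootstrapped target; stop_gradient keeps gradients out of the target.
    target = jax.lax.stop_gradient(r_t + discount_t * jnp.max(q_t))
    return target - q_tm1[a_tm1]


# vmap maps over a leading batch axis of every argument; jit compiles with XLA.
batched_td_error = jax.jit(jax.vmap(q_learning_td_error))

q_tm1 = jnp.array([[1.0, 2.0], [0.5, 0.0]])
a_tm1 = jnp.array([1, 0])
r_t = jnp.array([1.0, 0.0])
discount_t = jnp.array([0.9, 0.0])  # discount 0 marks a terminal transition
q_t = jnp.array([[3.0, 1.0], [4.0, 4.0]])

td = batched_td_error(q_tm1, a_tm1, r_t, discount_t, q_t)  # about [1.7, -0.5]
```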

For more information, see the RLax repository on GitHub.

2. Haiku

Haiku is a neural network library for JAX, developed by DeepMind. It lets users work with familiar object-oriented programming models while keeping full access to JAX’s pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.
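hk.transform turns a function that uses hk.Module objects into a pair of pure functions, init and apply, with the parameters passed around explicitly. The sketch below is not Haiku code: it hand-writes such a pair in plain JAX for a single linear layer (the names init, apply, loss and OUT_FEATURES are illustrative) to show why that purity matters. Haiku's generated apply also takes an RNG key, which is omitted here for simplicity.

```python
import jax
import jax.numpy as jnp

OUT_FEATURES = 3  # hypothetical layer width


def init(rng, x):
    """Create parameters for a linear layer from an RNG key and an example input."""
    in_features = x.shape[-1]
    w = jax.random.normal(rng, (in_features, OUT_FEATURES)) / jnp.sqrt(in_features)
    b = jnp.zeros((OUT_FEATURES,))
    return {"w": w, "b": b}


def apply(params, x):
    """Pure forward pass: the output depends only on params and x."""
    return x @ params["w"] + params["b"]


def loss(params, x):
    return jnp.mean(apply(params, x) ** 2)


rng = jax.random.PRNGKey(42)
x = jnp.ones((4, 2))
params = init(rng, x)
y = jax.jit(apply)(params, x)      # shape (4, 3)
grads = jax.grad(loss)(params, x)  # a dict with the same structure as params
```

Because the parameters are an ordinary argument rather than hidden object state, jax.grad, jax.jit and jax.vmap can all be applied to apply directly.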

Haiku is written in pure Python, but depends on C++ code via JAX.

The source code is available in the Haiku repository on GitHub.

3. Gymnax

Gymnax reimplements classic OpenAI Gym environments in JAX, so that they can be compiled with jax.jit and batched with jax.vmap. Gym is an open-source Python library for developing and comparing RL algorithms. It provides a standard API for communication between learning algorithms and environments, and since its release that API has become the de facto standard in the field.
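The key difference from Gym is that a JAX environment must be a pure function: state and randomness are passed in and returned explicitly instead of being stored inside an object. The toy environment below is hypothetical and hand-written in plain JAX; it is not Gymnax code. It only mirrors the key-first reset/step style we understand Gymnax to use, and it omits the environment-parameter argument that Gymnax's reset and step also take. The names EnvState, GOAL and MAX_STEPS are illustrative.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

GOAL = 3        # hypothetical goal position
MAX_STEPS = 10  # hypothetical episode length limit


class EnvState(NamedTuple):
    pos: jnp.ndarray  # integer position on a 1-D line
    t: jnp.ndarray    # number of steps taken so far


def reset(key):
    """Start an episode at a random position in [-2, 2]."""
    pos = jax.random.randint(key, (), -2, 3)
    return pos, EnvState(pos=pos, t=jnp.array(0))


def step(key, state, action):
    """Move right (action 1) or left (action 0); returns obs, state, reward, done."""
    del key  # unused by this deterministic toy, kept to mirror a key-first API
    pos = state.pos + jnp.where(action == 1, 1, -1)
    t = state.t + 1
    reached = pos >= GOAL
    reward = jnp.where(reached, 1.0, 0.0)
    done = jnp.logical_or(reached, t >= MAX_STEPS)
    return pos, EnvState(pos=pos, t=t), reward, done


# Run 8 environments in parallel: vmap batches over keys and states, jit compiles.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
obs, states = jax.vmap(reset)(keys)
actions = jnp.ones(8, dtype=jnp.int32)  # every environment moves right
next_obs, next_states, rewards, dones = jax.jit(jax.vmap(step))(keys, states, actions)
```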

For more information, see the Gymnax repository on GitHub.

4. Dopamine

Dopamine is a research framework for fast prototyping of RL algorithms. It aims to fill the need for a small, easy-to-understand codebase in which users can freely experiment with wild ideas (speculative research).

The easiest way to use Dopamine is to install it from source and modify the source code directly. The version released in 2020 supports JAX agents, including an implementation of the Quantile Regression agent (QR-DQN).
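QR-DQN represents the return distribution of each action by N quantile estimates at the midpoints τ̂_i = (2i + 1) / (2N) (0-indexed) and trains them with a quantile Huber loss. Below is a minimal plain-JAX sketch of that loss for a single transition. It is not Dopamine's implementation; the function names and the κ = 1 default are illustrative choices.

```python
import jax
import jax.numpy as jnp


def huber(u, kappa=1.0):
    """Huber loss: quadratic for |u| <= kappa, linear beyond."""
    abs_u = jnp.abs(u)
    return jnp.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))


def quantile_huber_loss(pred_quantiles, target_quantiles, kappa=1.0):
    """Quantile regression loss for one transition.

    pred_quantiles: shape [N], quantile estimates theta_i for the taken action
    target_quantiles: shape [N'], samples of the Bellman target
    """
    n = pred_quantiles.shape[0]
    tau_hat = (2 * jnp.arange(n) + 1) / (2.0 * n)  # quantile midpoints
    target_quantiles = jax.lax.stop_gradient(target_quantiles)
    # Pairwise errors u[i, j] = target_j - pred_i.
    u = target_quantiles[None, :] - pred_quantiles[:, None]
    # Asymmetric weight |tau_hat_i - 1{u < 0}| makes this a quantile loss.
    weight = jnp.abs(tau_hat[:, None] - (u < 0).astype(u.dtype))
    rho = weight * huber(u, kappa) / kappa
    # Sum over predicted quantiles i, average over target samples j.
    return jnp.sum(jnp.mean(rho, axis=1))


pred = jnp.array([0.0, 0.0])
target = jnp.array([1.0, 1.0])
loss = quantile_huber_loss(pred, target)                  # 0.5
grads = jax.grad(quantile_huber_loss)(pred, target)       # [-0.25, -0.75]
```

The gradients show the asymmetry: when the target lies above both estimates, the higher quantile (τ̂ = 0.75) is pushed up three times harder than the lower one (τ̂ = 0.25).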

Dopamine can also be installed with pip, and it supports both Atari and MuJoCo environments.

For more information, see the Dopamine repository on GitHub.

5. Flax

Launched in 2020, Flax is a high-performance neural network library for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.

It is developed in close collaboration with the JAX team. The basic philosophy behind Flax is that library code should be easy to read and understand.

At its core, Flax is built around parameterised functions called Modules, which can be used as normal functions.

Flax’s source code is available in the google/flax repository on GitHub.

6. coax

coax is a modular RL Python package for solving OpenAI Gym environments with JAX-based function approximators.

It is organised around core RL concepts, such as value functions and policies, rather than the high-level concept of an agent. This makes coax more modular and user-friendly for RL researchers and practitioners.

For more information, see the coax repository on GitHub.

7. Acme

Acme is a library of reinforcement learning (RL) building blocks that strives to expose simple, efficient and readable agents. These agents serve both as reference implementations and as strong baselines for algorithm performance. Acme’s building blocks are designed so that agents can be written at multiple scales, from single-process to distributed agents.

It supports both TensorFlow v2 and JAX. 

For more information, see the Acme repository on GitHub.

Tasmia Ansari
Tasmia is a tech journalist at AIM, looking to bring a fresh perspective to emerging technologies and trends in data science, analytics, and artificial intelligence.
