JAX is a machine/deep learning library developed by Google Research and used heavily by Alphabet subsidiary DeepMind. All JAX operations are implemented in terms of XLA, or Accelerated Linear Algebra, Google's domain-specific compiler for linear algebra, which uses whole-program optimisations to accelerate computation; Google has reported that XLA makes BERT's training almost 7.3 times faster.
Launched in 2018, JAX is designed for high-performance numerical computing. Its API for numerical functions is based on NumPy, the standard numerical computing library for Python.
In this article, we’ll explore different libraries and frameworks for reinforcement learning using JAX.
JAX reinforcement learning agents
1. RLax
RLax (pronounced ‘relax’) is a library built on top of JAX that provides useful building blocks for implementing reinforcement learning agents.
It can be installed directly from GitHub or PyPI.
All RLax code may be compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.
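For illustration, here is a minimal sketch of a one-step Q-learning update built from RLax primitives and compiled with jax.jit; the array values are invented purely for the example:

import jax
import jax.numpy as jnp
import rlax

q_tm1 = jnp.array([1.0, 2.0, 3.0])  # action values in the previous state
a_tm1 = jnp.int32(1)                # action that was taken
r_t = jnp.float32(1.0)              # reward received
discount_t = jnp.float32(0.99)      # discount factor
q_t = jnp.array([1.5, 0.5, 2.5])    # action values in the current state

# rlax.q_learning returns the TD error; jax.jit compiles it for the
# available backend (CPU, GPU or TPU).
td_error = jax.jit(rlax.q_learning)(q_tm1, a_tm1, r_t, discount_t, q_t)
loss = rlax.l2_loss(td_error)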
2. Haiku
Haiku is a neural network library for JAX that lets users write models in a familiar object-oriented style while keeping JAX’s pure function transformations available.
Its two core tools are a module abstraction, hk.Module, and a simple function transformation, hk.transform.
It is written in pure Python but depends on C++ code via JAX.
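A minimal sketch of these two tools together (network sizes are arbitrary): a function that builds an hk.Module is made pure with hk.transform, after which parameters are created with init and used with apply:

import haiku as hk
import jax
import jax.numpy as jnp

def forward_fn(x):
    # hk.nets.MLP is a ready-made hk.Module; any module works here.
    return hk.nets.MLP([32, 1])(x)

# hk.transform turns the module-building function into a pair of pure
# functions: init (creates parameters) and apply (runs the network).
forward = hk.transform(forward_fn)

rng = jax.random.PRNGKey(42)
x = jnp.ones((8, 4))            # dummy batch of inputs
params = forward.init(rng, x)
out = forward.apply(params, rng, x)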
3. Gymnax
Gymnax provides JAX-compatible reimplementations of OpenAI’s Gym environments. Gym is an open-source Python library for developing and comparing RL algorithms, providing a standard API for communication between learning algorithms and environments; since its release, that API has become the de facto standard in the field.
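As a brief sketch of the Gym-like interface (following the pattern from the Gymnax README), an environment is created, reset, and stepped with explicit JAX PRNG keys:

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Create the JAX-compatible environment with its default parameters.
env, env_params = gymnax.make("CartPole-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
obs, state, reward, done, info = env.step(key_step, state, action, env_params)

Because reset and step are pure functions of the PRNG keys and state, whole rollouts can be compiled with jax.jit or vectorised with jax.vmap.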
4. Dopamine
Dopamine is a research framework for prototyping RL algorithms. It aims to fill the need for a small, easily modified codebase in which users can freely experiment with wild ideas (speculative research).
The easiest way to use Dopamine is to install it from source and modify the code directly. The version released in 2020 added support for JAX agents, including an implementation of the Quantile Regression agent (QR-DQN).
It can also be installed with pip, and it supports both Atari and MuJoCo environments.
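As a sketch of a typical workflow (the gin config path and base directory below are placeholders), an experiment is driven by a Runner configured through gin files:

from dopamine.discrete_domains import run_experiment

BASE_DIR = '/tmp/dopamine_run'   # placeholder: where logs and checkpoints go
GIN_FILES = ['path/to/dqn.gin']  # placeholder: Dopamine ships example configs

# Parse the gin configuration, build the agent and environment, and train.
run_experiment.load_gin_configs(GIN_FILES, [])
runner = run_experiment.create_runner(BASE_DIR)
runner.run_experiment()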
5. Flax
Launched in 2020, Flax is a high-performance neural network library for JAX designed for flexibility: try new forms of training by forking an example and modifying the training loop, rather than by adding features to a framework.
It is developed in close collaboration with the JAX team, and its basic philosophy is that library code should be easy to read and understand.
At its core, Flax is built around parameterised functions called Modules, which can be used as normal functions.
The Flax repository is available on GitHub under Google Research.
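A small sketch of the Module abstraction (layer sizes are arbitrary): a model is defined as an nn.Module, initialised to produce its parameters, and then applied like a normal function:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    hidden: int  # number of hidden units (arbitrary)

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)

model = MLP(hidden=16)
x = jnp.ones((4, 8))                           # dummy batch of inputs
params = model.init(jax.random.PRNGKey(0), x)  # create the parameters
y = model.apply(params, x)                     # run the module as a function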
6. coax
coax is a Python package for reinforcement learning that solves OpenAI Gym environments with JAX-based function approximators.
It is designed to align with core RL concepts rather than with the high-level concept of an agent, which makes coax more modular and user-friendly for RL practitioners.
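As a hedged sketch of how the pieces compose (network sizes are arbitrary), a Haiku-style forward pass defines a q-function, from which a policy and an update rule are built:

import coax
import gym
import haiku as hk
import jax

env = gym.make('CartPole-v0')

def func(S, is_training):
    # Forward pass returning one value per discrete action.
    net = hk.Sequential([
        hk.Linear(8), jax.nn.relu,
        hk.Linear(env.action_space.n),
    ])
    return net(S)

q = coax.Q(func, env)                    # q-function approximator
pi = coax.EpsilonGreedy(q, epsilon=0.1)  # exploration policy derived from q
updater = coax.td_learning.QLearning(q)  # one-step Q-learning update rule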
7. Acme
Acme is a library of reinforcement learning (RL) building blocks that strives to expose simple, efficient, and readable agents. These agents serve as reference implementations as well as strong baselines for algorithm performance, and Acme’s building blocks are designed so that agents can be written at multiple scales.
It supports both TensorFlow v2 and JAX.
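As a sketch of the core pattern, an agent interacts with an environment through Acme’s EnvironmentLoop; here environment and agent are assumed to be built elsewhere (a dm_env.Environment and one of Acme’s agents, respectively):

import acme
from acme import specs

# Assumed to exist: `environment` (a dm_env.Environment) and `agent`
# (an Acme agent/actor built for that environment's spec).
spec = specs.make_environment_spec(environment)  # describes observations/actions
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)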