DeepMind has been using JAX, a Python library for numerical computing that is used extensively in machine learning, in a number of its experiments. In this article, we look at how the company is building its own JAX Ecosystem, complete with libraries that support its research.
What Is The JAX Ecosystem At DeepMind
JAX’s API for numerical functions is based on NumPy. In addition to the NumPy API, JAX includes a system of composable function transformations that support machine learning research:
- JAX supports both forward-mode and reverse-mode automatic differentiation of arbitrary numerical functions, through function transformations such as grad, hessian, jacfwd and jacrev.
- JAX provides automatic vectorisation, which is central to ML research, through function transformations such as vmap. A function written for a single example can be applied to a whole batch without an explicit Python loop, which simplifies the code and speeds it up.
- JAX uses XLA to just-in-time (JIT) compile and run programs on accelerators such as GPUs and TPUs, which helps researchers with no prior experience in high-performance computing scale to many accelerators.
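The three transformations above compose freely. The following is a minimal sketch, assuming a toy linear model; the names predict, loss, w, xs and ys are illustrative and do not come from DeepMind's code.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Linear model for a single example x of shape (3,).
    return jnp.dot(x, w)


def loss(w, xs, ys):
    # vmap maps predict over the leading (batch) axis of xs and keeps w fixed.
    preds = jax.vmap(predict, in_axes=(None, 0))(w, xs)
    return jnp.mean((preds - ys) ** 2)


w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones(4)

grad_fn = jax.grad(loss)          # reverse-mode gradient with respect to w
fast_grad_fn = jax.jit(grad_fn)   # the same function, compiled with XLA

g = grad_fn(w, xs, ys)
g_jit = fast_grad_fn(w, xs, ys)

# The Jacobian of the gradient is the Hessian, whichever mode computes it.
h = jax.hessian(loss)(w, xs, ys)
h_fwd = jax.jacfwd(grad_fn)(w, xs, ys)
h_rev = jax.jacrev(grad_fn)(w, xs, ys)
```

Here g and g_jit hold the same gradient, and h, h_fwd and h_rev are three routes to the same 3x3 Hessian.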
Building on these advantages, DeepMind now uses JAX for many of its experiments with novel algorithms and architectures.
One way DeepMind develops the JAX Ecosystem is through modularisation: extracting the crucial building blocks developed in each project into well-tested and efficient components. This lets researchers focus on their research while benefiting from code reuse, bug fixes and performance improvements in the core libraries. To give researchers maximum flexibility, the libraries also support incremental buy-in, meaning the ability to choose some features without getting locked into others.
Further, DeepMind researchers try to keep the JAX Ecosystem consistent, where possible, with the design of existing TensorFlow libraries such as Sonnet and TRFL. They are also working towards building self-descriptive components.
In a blog post, DeepMind announced that it would be open-sourcing its libraries to share research outputs and make its JAX Ecosystem available to the broader community.
Libraries For The JAX Ecosystem At DeepMind
Haiku: It is a neural network library for JAX that helps in managing model parameters and other state. It allows users to use familiar object-oriented programming models while harnessing the simplicity of JAX’s pure functional paradigm. Its API and programming model are based on Sonnet, DeepMind's module-based neural network library for TensorFlow.
Haiku is used by researchers across DeepMind and Google, and has been adopted by external projects such as Coax, DeepChem and NumPyro. The team is now working on simplifying the porting from Sonnet to Haiku.
Optax: It is a library of gradient transformations. Together with composition operators such as chain, these allow many standard optimisers to be implemented in just a line of code. It also offers facilities for stochastic gradient estimation and second-order optimisation. Optax works with any library that represents parameters as JAX tree structures, such as Elegy, Flax and Stax.
RLax: Many of the research projects at DeepMind are based on deep reinforcement learning, which combines deep learning and reinforcement learning. RLax is a library that provides building blocks for creating RL agents. Its components cover several algorithms and ideas, including TD-learning, policy gradients, MAP, non-linear value transformation, actor critics, general value functions and proximal policy optimisation, among others.
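To give a feel for what such a building block computes, here is a plain-JAX sketch of the one-step TD error. The helper td_error is hypothetical and written for this article; it is not RLax's API, and the numbers are made up.

```python
import jax
import jax.numpy as jnp


def td_error(v_tm1, r_t, discount_t, v_t):
    # Hypothetical helper, not RLax's API: TD error for a single transition.
    # The bootstrap target is treated as a constant when differentiating.
    target = r_t + discount_t * jax.lax.stop_gradient(v_t)
    return target - v_tm1


# Written for one transition, then vectorised over a batch and compiled.
batched_td_error = jax.jit(jax.vmap(td_error))

v_tm1 = jnp.array([1.0, 2.0])        # value estimates at the previous step
r_t = jnp.array([0.5, 0.0])          # rewards
discount_t = jnp.array([0.9, 0.0])   # a zero discount marks a terminal step
v_t = jnp.array([2.0, 5.0])          # value estimates at the current step

td_errors = batched_td_error(v_tm1, r_t, discount_t, v_t)
td_loss = 0.5 * jnp.mean(td_errors ** 2)
```

Writing the function for a single transition and batching it with vmap keeps the code close to the underlying mathematics.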
Chex: Chex is a collection of testing utilities that help users verify the correctness of experimental code. It provides JAX-aware unit testing, assertions on the properties of JAX data types, mocks and fakes, and multi-device test environments. External projects that use it include Coax and MineRL.
Jraph: It is a lightweight library that supports working with graph neural networks. Jraph provides a standardised data structure for graphs, a set of utilities for working with graphs, and a ‘zoo’ of easy and extensible neural networks. Its features include utilities for batching GraphsTuples efficiently to leverage hardware accelerators, JIT-compilation support for variable-shaped graphs through padding and masking, and support for defining losses over input partitions.
Wrapping Up
DeepMind's engineers support research by building tools, scaling algorithms, and creating challenging environments for training and testing AI systems. They believe that JAX will greatly improve and accelerate these efforts.