Why DeepMind Is Bullish On This One Python Library

DeepMind has been leveraging JAX, a Python library used extensively in machine learning, for numerical computing across a number of its experiments. Through this article, we will understand how the company is building its own JAX Ecosystem, complete with different libraries, to push its research.

What Is JAX Ecosystem At DeepMind

JAX’s API for numerical computation is based on NumPy. On top of the NumPy API, JAX adds a system of composable function transformations that support machine learning research, including:

  • Forward- and reverse-mode automatic differentiation of arbitrary numerical functions, through function transformations such as grad, hessian, jacrev, and jacfwd.
  • Automatic vectorisation, which is central to ML research, through the vmap transformation. Code is written for a single example and applied across a whole batch, which speeds it up without explicit Python loops and simplifies programming.
  • JIT-compilation via XLA, which helps researchers with no prior experience in high-performance computing scale their code to one or many accelerators. (A short sketch of these transformations follows this list.)
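
As an illustration, here is a minimal sketch of these transformations applied to a toy squared-error loss; the function and data below are purely illustrative and not taken from any DeepMind codebase.

import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared-error loss of a linear model with weights w.
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_loss = jax.grad(loss)                          # reverse-mode gradient w.r.t. w
per_example = jax.vmap(loss, in_axes=(None, 0, 0))  # vectorise over a batch axis
fast_loss = jax.jit(loss)                           # JIT-compile via XLA

w = jnp.ones(3)
x = jnp.ones((8, 3))
y = jnp.zeros(8)

print(grad_loss(w, x, y))    # gradient of the loss with respect to w
print(per_example(w, x, y))  # one loss per example, with no Python loop
print(fast_loss(w, x, y))    # compiled execution of the same function

These transformations also compose: jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))) is itself a valid, compiled, batched gradient function.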

Leveraging these advantages, DeepMind has been using JAX extensively, and it now underpins many of the company’s experiments with novel algorithms and architectures.

One approach DeepMind has adopted to develop the JAX Ecosystem is modularisation: extracting the crucial building blocks developed in each project into tested and efficient components. This way, researchers can focus on their research and, at the same time, benefit from code reuse, bug fixes and other performance improvements implemented by the core libraries. Other considerations for giving researchers maximum flexibility include incremental buy-in and the ability to choose features without getting locked into others.

Further, DeepMind researchers also make sure that the JAX Ecosystem remains consistent, wherever possible, with existing TensorFlow libraries such as Sonnet and TRFL. They are now working towards building self-descriptive components.

Through a blog, DeepMind announced that it would be open-sourcing its libraries to share research outputs and make its JAX Ecosystem available to the broader community.

Libraries for JAX Ecosystem At DeepMind

Haiku: A neural network library for JAX that helps manage model parameters and other state. It allows users to write models in a familiar object-oriented style while harnessing the simplicity of JAX’s pure functional paradigm. Its API is based on Sonnet, DeepMind’s module-based neural network library for TensorFlow.

Haiku is used by several researchers across DeepMind and Google. Some of its more popular external use cases include projects such as Coax, DeepChem and NumPyro. The team is now working on simplifying porting from Sonnet to Haiku.
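
For context, a minimal sketch of the Haiku pattern is shown below, assuming the standard haiku API (hk.Linear, hk.Sequential, hk.transform); the network itself is a made-up example.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Modules are written in a familiar object-oriented style...
    mlp = hk.Sequential([hk.Linear(32), jax.nn.relu, hk.Linear(1)])
    return mlp(x)

# ...and hk.transform turns the function into a pure (init, apply) pair.
net = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones((4, 8))
params = net.init(rng, x)        # parameters are explicit state
out = net.apply(params, rng, x)  # a pure function of params and inputs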

Optax: A library of gradient transformations for optimisation. Along with composition operators such as chain, it allows many standard optimisers to be implemented in just a line of code. It also offers utilities for stochastic gradient estimation and second-order optimisation. External frameworks such as Elegy, Flax and Stax use Optax.
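
A minimal sketch of that composition pattern, with illustrative hyperparameters and a toy loss, might look as follows.

import jax
import jax.numpy as jnp
import optax

# Gradient clipping chained with an Adam optimiser.
optimiser = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

params = {"w": jnp.ones(3)}
opt_state = optimiser.init(params)

def loss(p, x, y):
    return jnp.mean((jnp.dot(x, p["w"]) - y) ** 2)

# One update step: compute gradients, transform them, apply them.
grads = jax.grad(loss)(params, jnp.ones((8, 3)), jnp.zeros(8))
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)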

RLax: Since many DeepMind research projects involve deep reinforcement learning (the combination of deep learning and reinforcement learning), RLax provides building blocks for constructing RL agents. Its components cover algorithms and operations such as TD-learning, policy gradients, actor-critic methods, MAP, non-linear value transformation, general value functions and proximal policy optimisation, among others.
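
As a small example of one such building block, the TD-learning error for a single transition can be computed as below; the transition values are made up for illustration.

import jax.numpy as jnp
import rlax

v_tm1 = jnp.array(1.0)        # value estimate at time t-1
r_t = jnp.array(0.5)          # reward received at time t
discount_t = jnp.array(0.99)  # discount applied at time t
v_t = jnp.array(1.2)          # value estimate at time t

# TD error: r_t + discount_t * v_t - v_tm1; batches are handled with jax.vmap.
td_error = rlax.td_learning(v_tm1, r_t, discount_t, v_t)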

Chex: A collection of testing utilities that help users verify the correctness of experimental code. It provides a range of tools, including assertions on the properties of JAX data types, multi-device test environments, and mocks and fakes. Some of its external use cases are Coax and MineRL.
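
A minimal sketch of how such assertions are used in a test, with an illustrative array and expected shape, is shown below.

import chex
import jax.numpy as jnp

def test_output_shape():
    out = jnp.ones((4, 8))
    chex.assert_shape(out, (4, 8))  # fails loudly if the shape drifts
    chex.assert_rank(out, 2)        # the output should stay two-dimensional
    chex.assert_type(out, float)    # and hold floating-point values

test_output_shape()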

Jraph: A lightweight library for working with graph neural networks. Jraph provides a standardised data structure for graphs and a ‘zoo’ of simple, extensible graph neural network models. Its features include JIT-compilation support via padding and masking, batched GraphsTuples that make full use of hardware accelerators, and utilities for defining losses over partitions of inputs.
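
A minimal sketch of that standardised data structure, using a toy three-node, two-edge graph, is given below; the feature sizes are illustrative.

import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),       # 3 nodes, 4 features each
    edges=jnp.ones((2, 4)),       # 2 edges, 4 features each
    senders=jnp.array([0, 1]),    # source node index of each edge
    receivers=jnp.array([1, 2]),  # destination node index of each edge
    globals=jnp.ones((1, 4)),     # one graph-level feature vector
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
)

# Padding the graph to fixed sizes (here assumed via jraph.pad_with_graphs)
# is what makes the structure friendly to JIT-compilation.
padded = jraph.pad_with_graphs(graph, n_node=8, n_edge=4)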

Wrapping Up

DeepMind engineers believe that the JAX ecosystem will greatly improve and accelerate their research efforts, from building tools and scaling algorithms to creating challenging environments for training and testing AI systems.
