What makes JAX so Awesome

JAX is a high-performance numerical computation Python library.

For high-performance machine learning research, Just After eXecution (JAX) is NumPy on the CPU, GPU, and TPU, with excellent automatic differentiation. It is a Python library for high-performance numerical computation, aimed particularly at machine learning research. Its numerical API is based on NumPy, the standard library for scientific computing in Python. Because Python and NumPy are both widely known and used, JAX is straightforward, versatile, and easy to adopt. This article focuses on JAX's features and on using it to implement a deep learning model. Following are the topics to be covered.

Table of contents

  1. Reason to use JAX
  2. What is XLA?
  3. What’s there in the ecosystem of JAX?
  4. Building ML model with JAX

JAX is not an official Google product, but its popularity is growing; let's look at the reasons behind it.

Reason to use JAX

Although JAX provides a straightforward and powerful API for developing accelerated numerical code, working efficiently with JAX occasionally requires extra thought. JAX is essentially a just-in-time (JIT) compiler that focuses on generating efficient code while retaining the simplicity of pure Python. Aside from the NumPy API, JAX contains an extensible set of composable function transformations that aid machine learning research (a short sketch follows the list below), such as:

  • Differentiation: Gradient-based optimisation is essential to machine learning. JAX natively supports automatic differentiation of arbitrary numerical functions in both forward and reverse mode using function transformations such as grad, hessian, jacfwd, and jacrev.
  • Vectorisation: In machine learning research, a single function is frequently applied to large amounts of data, such as computing the loss across a batch or evaluating per-example gradients for differentially private learning. The vmap transformation in JAX enables automatic vectorisation, which simplifies this kind of programming; when developing new algorithms, for example, researchers do not need to think about batching. JAX also supports large-scale data parallelism with the related pmap transformation, which elegantly distributes data that is too large for a single accelerator's memory.
  • Just-in-time (JIT) compilation: XLA is used to JIT-compile and run JAX applications on GPU and Cloud TPU accelerators. JIT compilation, in conjunction with JAX’s NumPy-consistent API, enables researchers with no prior experience in high-performance computing to readily scale to one or more accelerators.
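
To make these transformations concrete, here is a minimal sketch showing grad, vmap, and jit composed on the same function; the toy linear-model loss, shapes, and variable names are illustrative assumptions, not code from the article.

import jax
import jax.numpy as jnp

def loss(w, x, y):
  # Squared error of a linear model for a single example or a full batch.
  return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_fn = jax.grad(loss)                                     # reverse-mode gradient w.r.t. w
per_example_grads = jax.vmap(grad_fn, in_axes=(None, 0, 0))  # vectorised over examples
fast_grad_fn = jax.jit(grad_fn)                              # XLA-compiled version

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(fast_grad_fn(w, x, y))
print(per_example_grads(w, x, y).shape)  # (8, 3): one gradient per example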


What is XLA?

XLA (Accelerated Linear Algebra) is a domain-specific linear algebra compiler that can accelerate TensorFlow models with minimal source code changes.

When a TensorFlow program runs, the TensorFlow executor executes each operation independently, dispatching to a pre-compiled GPU kernel implementation for each TensorFlow operation. XLA offers an additional mode of execution by compiling the TensorFlow graph into a sequence of computation kernels generated specifically for the given model. Because these kernels are model-specific, they can exploit model-specific information to optimise.
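
As a quick illustration of what this looks like from user code, recent TensorFlow versions let you request XLA compilation for a single function with the jit_compile flag; the function and shapes below are arbitrary examples.

import tensorflow as tf

@tf.function(jit_compile=True)  # ask TensorFlow to compile this function with XLA
def dense_layer(x, w, b):
  # matmul + add + relu can be fused by XLA into fewer kernels
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((8, 32))
w = tf.random.normal((32, 16))
b = tf.zeros((16,))
print(dense_layer(x, w, b).shape)  # (8, 16)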

Architecture of XLA

The input language to XLA is called High-Level Operations (HLO). It is most convenient to think of HLO as a compiler intermediate representation. So, HLO represents a program “between” the source and target languages.

XLA translates graphs described in HLO into machine instructions for multiple platforms. XLA is modular in the sense that an alternative backend can easily be slotted in to target a novel hardware architecture. After the target-independent phase, XLA hands the HLO computation to a backend, which can perform further HLO-level optimisations, this time with target-specific information and requirements in mind.

The next step is target-specific code generation. The CPU and GPU backends bundled with XLA use LLVM for low-level intermediate representation, optimisation, and code generation. These backends emit the LLVM IR needed to represent the XLA HLO computation efficiently, and then invoke LLVM to produce native code from that LLVM intermediate representation.
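
From the JAX side, the intermediate representations involved can be inspected directly. The sketch below uses an arbitrary function and prints JAX's own jaxpr, which is what gets lowered to HLO for XLA; the exact lowering API varies slightly between JAX versions.

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.tanh(x) ** 2)

# The jaxpr is JAX's trace of the computation before it is handed to XLA.
print(jax.make_jaxpr(f)(jnp.ones((3,))))

# In recent JAX versions, the lowered HLO/StableHLO text can also be viewed,
# e.g. via jax.jit(f).lower(jnp.ones((3,))).as_text().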

Reason to use XLA

There are four major reasons to use XLA.

  • Translation inherently involves analysis followed by synthesis; a direct, word-for-word translation from source to target is ineffective.
  • It splits the complex problem of translation into two simpler, more manageable halves.
  • A new backend can be built for an existing frontend to give a retargetable compiler, and vice versa.
  • Machine-independent optimisations can be carried out once, on the intermediate representation.

What’s there in the ecosystem of JAX?

The ecosystem consists of five different libraries.

Haiku

Dealing with stateful objects, such as neural networks with trainable parameters, might be difficult with the JAX programming paradigm of composable function transformations. Haiku is a neural network library that enables users to use traditional object-oriented programming paradigms while making use of the power and simplicity of JAX’s pure functional paradigm.

Several external projects, including Coax, DeepChem, and NumPyro, actively use Haiku. It extends the API of Sonnet, DeepMind's module-based neural network programming model in TensorFlow.
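
A minimal Haiku sketch (the network, shapes, and seed are arbitrary illustrations) shows the core pattern: write a function that builds modules, transform it into a pure (init, apply) pair, and then use it like any other JAX function.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # Modules such as hk.Linear create and track parameters behind the scenes.
  return hk.Linear(10)(x)

net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(42), jnp.ones((1, 28 * 28)))  # pure init
logits = net.apply(params, jnp.ones((1, 28 * 28)))                 # pure apply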

Optax

Gradient-based optimization is important to machine learning. Optax includes a gradient transformation library as well as composition operators (such as chain) that allow the development of numerous common optimisers (such as RMSProp or Adam) in a single line of code. Optax’s compositional structure lends itself readily to recombining the same fundamental elements in bespoke optimisers. It also includes utilities for stochastic gradient estimation and second-order optimization.
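
For example, a common Optax pattern chains a gradient clip with Adam and then applies the resulting updates; the parameters and gradients below are stand-in placeholders, not from a real model.

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))
opt_state = optimizer.init(params)

grads = {"w": jnp.ones(3)}  # stand-in gradients
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)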

RLax

RLax is a library that provides useful building blocks for constructing reinforcement learning (RL) agents, including deep RL agents. RLax's components include TD-learning, policy gradients, actor critics, MPO, proximal policy optimisation, non-linear value transformations, general value functions, and numerous exploration methods.

RLax is not meant to be a framework for developing and deploying full-fledged RL agent systems. Acme is one example of a fully-featured agent architecture built on RLax components.
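
As a small example of the kind of building block RLax provides, the sketch below computes a one-step TD error for a single transition using RLax's TD-learning utility; the numbers are arbitrary placeholders.

import jax.numpy as jnp
import rlax

v_tm1 = jnp.array(1.0)        # value estimate at time t-1
r_t = jnp.array(0.5)          # reward received at time t
discount_t = jnp.array(0.99)  # discount factor
v_t = jnp.array(1.2)          # value estimate at time t

td_error = rlax.td_learning(v_tm1, r_t, discount_t, v_t)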

Chex

Testing is essential for software reliability, and research code is no exception. Drawing scientific conclusions from research experiments requires confidence in your code's correctness. Chex is a collection of testing utilities used by library authors to ensure that common building blocks are correct and robust, and by end-users to validate their experimental programs.

Chex includes a number of tools, such as JAX-aware unit testing, assertions on JAX data type attributes, mocks and fakes, and multi-device test environments.
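
A few of these assertions in action (the array and expected shapes are arbitrary examples):

import chex
import jax.numpy as jnp

x = jnp.ones((4, 3))
chex.assert_shape(x, (4, 3))  # exact shape check
chex.assert_rank(x, 2)        # number of dimensions
chex.assert_type(x, float)    # floating-point dtype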

Jraph

Jraph is a lightweight library for working with graph neural networks (GNNs) in JAX. Jraph provides a standardised data structure for graphs, a set of utilities for operating on graphs, and a collection of graph neural network models that are easy to fork and extend. Other key features include GraphsTuple batching that takes advantage of hardware accelerators, JIT-compilation support for variable-shaped graphs through padding and masking, and losses defined over input partitions. Like Optax and the other libraries above, Jraph places no restrictions on the user's choice of neural network library.
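
The central data structure is jraph.GraphsTuple; the sketch below builds a single small graph with made-up node, edge, and global features.

import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),    # 3 nodes with 4 features each
    edges=jnp.ones((2, 5)),    # 2 directed edges with 5 features each
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=jnp.ones((1, 6)),  # graph-level features
)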

Building ML model with JAX

For this article, we build a Generative Adversarial Network (GAN) in Haiku and Optax, trained on the MNIST dataset loaded through TensorFlow Datasets.

Let's start by installing Haiku and Optax.

!pip install dm-haiku
!pip install optax

Import necessary libraries

import functools
from typing import Any, NamedTuple
 
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

Reading the dataset

mnist_dataset = tfds.load("mnist")

def make_dataset(batch_size, seed=1):
  def _preprocess(sample):
    # Convert images to float32 and rescale from [0, 1] to [-1, 1].
    image = tf.image.convert_image_dtype(sample["image"], tf.float32)
    return 2.0 * image - 1.0

  ds = mnist_dataset["train"]
  ds = ds.map(map_func=_preprocess,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  ds = ds.shuffle(10 * batch_size, seed=seed).repeat().batch(batch_size)
  return iter(tfds.as_numpy(ds))

Creating generator and discriminator

The generator model produces new plausible examples from the problem domain, whereas the discriminator model classifies an example as real (drawn from the dataset) or generated.

class Generator(hk.Module):
  def __init__(self, output_channels=(32, 1), name=None):
    super().__init__(name=name)
    self.output_channels = output_channels
 
  def __call__(self, x):
    x = hk.Linear(7 * 7 * 64)(x)
    x = jnp.reshape(x, x.shape[:1] + (7, 7, 64)) 
    for output_channels in self.output_channels:
      x = jax.nn.relu(x)
      x = hk.Conv2DTranspose(output_channels=output_channels,
                             kernel_shape=[5, 5],
                             stride=2,
                             padding="SAME")(x)
    return jnp.tanh(x)

class Discriminator(hk.Module):
 
  def __init__(self,
               output_channels=(8, 16, 32, 64, 128),
               strides=(2, 1, 2, 1, 2),
               name=None):   
    super().__init__(name=name)
    self.output_channels = output_channels
    self.strides = strides
 
  def __call__(self, x):
    for output_channels, stride in zip(self.output_channels, self.strides):
      x = hk.Conv2D(output_channels=output_channels,
                    kernel_shape=[5, 5],
                    stride=stride,
                    padding="SAME")(x)
      x = jax.nn.leaky_relu(x, negative_slope=0.2)
    x = hk.Flatten()(x)    
    logits = hk.Linear(2)(x)
    return logits

Creating the GAN algorithm
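
The class below relies on a few helpers (GANTuple, GANState, tree_shape, and sparse_softmax_cross_entropy) that are not shown in the article. The following is a minimal sketch of definitions consistent with how they are used; these are assumptions reconstructed from the calls below, not the article's original code.

import collections

import jax
import jax.numpy as jnp

# Named pairs for (generator, discriminator) values and the overall GAN state.
GANTuple = collections.namedtuple("GANTuple", ["gen", "disc"])
GANState = collections.namedtuple("GANState", ["params", "opt_state"])

def tree_shape(xs):
  # Replace every array leaf of a pytree with its shape, for readable printing.
  return jax.tree_util.tree_map(lambda x: x.shape, xs)

def sparse_softmax_cross_entropy(logits, labels):
  # Cross-entropy between logits and integer class labels.
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1)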

import optax
class GAN_algo_basic:
  def __init__(self, num_latents):
    self.num_latents = num_latents
    self.gen_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Generator()(*args)))
    self.disc_transform = hk.without_apply_rng(
        hk.transform(lambda *args: Discriminator()(*args)))
    self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),
                               disc=optax.adam(1e-4, b1=0.5, b2=0.9))
 
  @functools.partial(jax.jit, static_argnums=0)
  def initial_state(self, rng, batch):
    dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))
    rng_gen, rng_disc = jax.random.split(rng)
    params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),
                      disc=self.disc_transform.init(rng_disc, batch))
    print("Generator: \n\n{}\n".format(tree_shape(params.gen)))
    print("Discriminator: \n\n{}\n".format(tree_shape(params.disc)))
    opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),
                         disc=self.optimizers.disc.init(params.disc))
    
    return GANState(params=params, opt_state=opt_state)
 
  def sample(self, rng, gen_params, num_samples):
    """Generates images from noise latents."""
    latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))
    return self.gen_transform.apply(gen_params, latents)
 
  def gen_loss(self, gen_params, rng, disc_params, batch):
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
    fake_logits = self.disc_transform.apply(disc_params, fake_batch)
    fake_probs = jax.nn.softmax(fake_logits)[:, 1]
    loss = -jnp.log(fake_probs)
    
    return jnp.mean(loss)
 
  def disc_loss(self, disc_params, rng, gen_params, batch):
    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
    real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)
    real_and_fake_logits = self.disc_transform.apply(disc_params, 
                                                     real_and_fake_batch)
    real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)
    real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)
    real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)
    fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)
    fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)
 
    return jnp.mean(real_loss + fake_loss)

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, rng, gan_state, batch):
    rng, rng_gen, rng_disc = jax.random.split(rng, 3)
    disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(
        gan_state.params.disc,
        rng_disc, 
        gan_state.params.gen,
        batch)
    disc_update, disc_opt_state = self.optimizers.disc.update(
        disc_grads, gan_state.opt_state.disc)
    disc_params = optax.apply_updates(gan_state.params.disc, disc_update)
    gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(
        gan_state.params.gen,
        rng_gen, 
        gan_state.params.disc,
        batch)
    gen_update, gen_opt_state = self.optimizers.gen.update(
        gen_grads, gan_state.opt_state.gen)
    gen_params = optax.apply_updates(gan_state.params.gen, gen_update)
    
    params = GANTuple(gen=gen_params, disc=disc_params)
    opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)
    gan_state = GANState(params=params, opt_state=opt_state)
    log = {
        "gen_loss": gen_loss,
        "disc_loss": disc_loss,
    }
 
    return rng, gan_state, log

Training the model
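
The loop below references several names (num_steps, log_every, model, dataset, rng, gan_state, and the logging lists) whose setup is not shown in the article. A minimal setup sketch, with the specific values chosen as assumptions, could look like this:

num_latents = 20   # size of the generator's noise vector (assumption)
num_steps = 5001   # the article trains for roughly 5,000 steps
log_every = 500    # how often to record and print losses (assumption)
batch_size = 64    # training batch size (assumption)

model = GAN_algo_basic(num_latents)
dataset = make_dataset(batch_size=batch_size)

rng = jax.random.PRNGKey(42)
gan_state = model.initial_state(rng, next(dataset))

steps, gen_losses, disc_losses = [], [], []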

for step in range(num_steps):
  rng, gan_state, log = model.update(rng, gan_state, next(dataset))
  if step % log_every == 0:   
    log = jax.device_get(log)
    gen_loss = log["gen_loss"]
    disc_loss = log["disc_loss"]
    print(f"Step {step}: "
          f"gen_loss = {gen_loss:.3f}, disc_loss = {disc_loss:.3f}")
    steps.append(step)
    gen_losses.append(gen_loss)
    disc_losses.append(disc_loss)

The model is trained for 5,000 steps here due to time constraints; the number of steps is up to the user. Training for 5,000 steps took approximately 60 minutes.


Analyzing the losses for the generator and discriminator

fig, axes = plt.subplots(1, 2, figsize=(20, 6))
 
# Plot the discriminator loss.
axes[0].plot(steps, disc_losses, "-")
axes[0].set_title("Discriminator loss", fontsize=20)
 
# Plot the generator loss.
axes[1].plot(steps, gen_losses, '-')
axes[1].set_title("Generator loss", fontsize=20);
[Plot: discriminator loss (left) and generator loss (right) over training steps]

We can observe that the generator loss was quite high during the first 2,000 steps, and after about 3,000 steps the discriminator and generator losses remained approximately constant on average.

Conclusion

Just After eXecution (JAX) is a high-performance numerical computation library, used particularly in machine learning research. Its numerical API is based on NumPy, a library of functions used in scientific computing. With this article, we have explored the JAX ecosystem and the use of Optax and Haiku, which are part of that ecosystem, to implement a model.


