For high-performance machine learning research, Just After eXecution (JAX) is NumPy on the CPU, GPU, and TPU, with excellent automatic differentiation. It is a Python library for high-performance numerical computation, particularly machine learning research. Its numerical API is based on NumPy, a library of functions used in scientific computing. Python and NumPy are both well known and widely used, which makes JAX straightforward to pick up, versatile, and simple to adopt. This article covers the main JAX features and uses them to build a deep learning model. The following topics are covered.
Table of contents
- Reason to use JAX
- What is XLA?
- What’s there in the ecosystem of JAX?
- Building ML model with JAX
JAX is not an official Google product, but its popularity is increasing. Let's look at the reasons behind that popularity.
Reason to use JAX
Although JAX provides a straightforward and strong API for developing accelerated numerical code, working efficiently with JAX occasionally necessitates extra thought. JAX is essentially a Just-In-Time (JIT) compiler that focuses on generating efficient code while utilising the simplicity of pure Python. Aside from the NumPy API, JAX contains an extendable set of composable function transformations that aid in machine learning research, such as:
- Differentiation: Gradient-based optimization is essential to machine learning. JAX natively supports automatic differentiation of arbitrary numerical functions in both forward and reverse mode, using function transformations such as grad, hessian, jacfwd and jacrev.
- Vectorisation: In machine learning research, a single function is frequently applied to large amounts of data, such as computing the loss across a batch or assessing per-example gradients for differentially private learning. The vmap transformation in JAX enables automated vectorisation, which simplifies this type of programming. When developing new algorithms, for example, researchers do not need to consider batching. JAX also allows large-scale data parallelism with the related pmap transformation, which elegantly distributes data that is too vast for a single accelerator’s memory.
- Just-in-time (JIT) compilation: XLA is used to JIT-compile and run JAX applications on GPU and Cloud TPU accelerators. JIT compilation, in conjunction with JAX’s NumPy-consistent API, enables researchers with no prior experience in high-performance computing to readily scale to one or more accelerators.
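The three transformations listed above compose with one another. The following is a minimal sketch, assuming a toy linear model with a squared-error loss (the function `loss` and the sample data are illustrative and not part of the article's model); it computes per-example gradients by stacking `jax.grad`, `jax.vmap` and `jax.jit`.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single example (illustrative)."""
    return (jnp.dot(w, x) - y) ** 2


# Differentiation: gradient of the loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Vectorisation: map grad_loss over a batch of (x, y) pairs while sharing w,
# which yields one gradient per example.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT compilation: compile the batched gradient function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.zeros(3)

grads = fast_per_example_grads(w, xs, ys)
print(grads.shape)  # (3, 2): one gradient of size 2 for each of the 3 examples
```

The loss is written for a single example only; batching is added afterwards by `vmap`, which is the point made in the vectorisation bullet.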
What is XLA?
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with little or no source code modification.
When a TensorFlow program is run, the TensorFlow executor executes each operation independently. For each TensorFlow operation, the executor dispatches to a pre-compiled GPU kernel implementation. XLA offers an alternative mode of execution: it compiles the TensorFlow graph into a sequence of computation kernels generated specifically for the given model. Because these kernels are model-specific, they can exploit model-specific information for optimisation.
Architecture of XLA
The input language to XLA is called High-Level Operations (HLO). It is most convenient to think of HLO as a compiler intermediate representation. So, HLO represents a program “between” the source and target languages.
XLA translates graphs described in HLO into machine instructions for multiple platforms. XLA is modular in the sense that an alternative backend can easily be slotted in to target a novel hardware architecture. XLA first runs target-independent optimisation and analysis passes, such as common subexpression elimination and operation fusion. After this target-independent phase, XLA hands the HLO computation to a backend. The backend can perform further HLO-level optimisations, this time with target-specific information and requirements in mind.
The next step is to generate target-specific code. The CPU and GPU backends bundled with XLA use LLVM for low-level intermediate representation optimisation and code generation. These backends emit the LLVM IR needed to represent the XLA HLO computation efficiently and then invoke LLVM to emit native code from that LLVM IR.
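To make the pipeline above more concrete from the JAX side, the sketch below lowers a small function to the compiler's input representation, prints it, and then compiles and runs it. The function `scaled_sum` is illustrative. The exact textual form printed (HLO or a related dialect such as StableHLO) depends on the installed JAX version, so no particular output is assumed here.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x):
    """A small numerical function to hand to the compiler (illustrative)."""
    return jnp.sum(2.0 * x + 1.0)


x = jnp.arange(4.0)

# Trace the function for this input shape and dtype, and lower it to the
# compiler's input representation without executing it.
lowered = jax.jit(scaled_sum).lower(x)
ir_text = lowered.as_text()
print(ir_text)

# Compile the lowered computation for the local backend, then run it.
compiled = lowered.compile()
result = compiled(x)
print(result)
```

In everyday use `jax.jit(scaled_sum)(x)` performs the same lowering and compilation implicitly on the first call; splitting the steps is only useful for inspection.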
Reason to use XLA
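There are four major reasons why a compiler such as XLA works through an intermediate representation like HLO instead of translating the source program directly into machine code.
- Translation by definition involves both analysis and synthesis; word-for-word translation is ineffective.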
- To divide the complex challenge of translation into two simpler, more manageable halves.
- A new back end might be constructed for an existing front end to provide retargetable compilers and vice versa.
- To carry out machine-independent optimizations.
What’s there in the ecosystem of JAX?
This article looks at five libraries from the JAX ecosystem.
Haiku
Dealing with stateful objects, such as neural networks with trainable parameters, might be difficult with the JAX programming paradigm of composable function transformations. Haiku is a neural network library that enables users to use traditional object-oriented programming paradigms while making use of the power and simplicity of JAX’s pure functional paradigm.
Several external projects, including Coax, DeepChem, and NumPyro, actively use Haiku. It builds on the API of Sonnet, DeepMind's module-based neural network programming model for TensorFlow.
Optax
Gradient-based optimization is important to machine learning. Optax includes a gradient transformation library as well as composition operators (such as chain) that allow the development of numerous common optimisers (such as RMSProp or Adam) in a single line of code. Optax’s compositional structure lends itself readily to recombining the same fundamental elements in bespoke optimisers. It also includes utilities for stochastic gradient estimation and second-order optimization.
RLax
RLax is a library that provides useful building blocks for constructing reinforcement learning (RL) agents. RLax's components include TD-learning, policy gradients, actor critics, MAP, proximal policy optimisation, non-linear value transformation, general value functions, and a number of exploration methods.
RLax is not meant to be a framework for developing and deploying full-fledged RL agent systems. Acme is one example of a fully-featured agent architecture built on RLax components.
Chex
Testing is essential for the reliability of software, and research code is no exception. Drawing scientific findings from research trials necessitates faith in your code’s accuracy. Chex is a collection of testing utilities used by library writers to ensure that the common building blocks are correct and resilient, as well as by end-users to validate their experimental programmes.
Chex includes a number of tools, such as JAX-aware unit testing, assertions on JAX data type attributes, mocks and fakes, and multi-device test environments.
Jraph
Jraph is a lightweight library for working with graph neural networks (GNNs) in JAX. Jraph provides a standardised data structure for graphs, a set of utilities for working with graphs, and a collection of graph neural network models that are easy to fork and extend. Other major features include batching of GraphsTuples that takes advantage of hardware accelerators, JIT-compilation support for variable-shaped graphs through padding and masking, and losses defined over input partitions. Like Optax and the other libraries above, Jraph places no restrictions on the user's choice of neural network library.
Building ML model with JAX
In this article we build a Generative Adversarial Network (GAN) with Haiku and Optax and train it on the MNIST dataset, which is loaded through TensorFlow Datasets.
Let's start by installing Haiku and Optax.
!pip install dm-haiku
!pip install optax
Import necessary libraries
import functools
from typing import Any, NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
Reading the dataset
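mnist_dataset = tfds.load("mnist")


def make_dataset(batch_size, seed=1):
    def _preprocess(sample):
        # Convert uint8 pixels to floats in [0, 1], then rescale to [-1, 1]
        # to match the tanh output range of the generator.
        image = tf.image.convert_image_dtype(sample["image"], tf.float32)
        return 2.0 * image - 1.0

    # Use the training split of the dataset loaded above.
    ds = mnist_dataset["train"]
    ds = ds.map(map_func=_preprocess,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.cache()
    ds = ds.shuffle(10 * batch_size, seed=seed).repeat().batch(batch_size)
    # Return an iterator that yields batches as NumPy arrays.
    return iter(tfds.as_numpy(ds))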
Creating generator and discriminator
A GAN consists of two models. The generator produces new, plausible examples from the problem domain, whereas the discriminator classifies an example as either real (from the domain) or generated.
class Generator(hk.Module):
    def __init__(self, output_channels=(32, 1), name=None):
        super().__init__(name=name)
        self.output_channels = output_channels

    def __call__(self, x):
        # Project the latent vector and reshape it into a 7x7 feature map.
        x = hk.Linear(7 * 7 * 64)(x)
        x = jnp.reshape(x, x.shape[:1] + (7, 7, 64))
        # Each transposed convolution doubles the spatial size: 7 -> 14 -> 28.
        for output_channels in self.output_channels:
            x = jax.nn.relu(x)
            x = hk.Conv2DTranspose(output_channels=output_channels,
                                   kernel_shape=[5, 5],
                                   stride=2,
                                   padding="SAME")(x)
        # tanh keeps the generated pixels in [-1, 1].
        return jnp.tanh(x)


class Discriminator(hk.Module):
    def __init__(self,
                 output_channels=(8, 16, 32, 64, 128),
                 strides=(2, 1, 2, 1, 2),
                 name=None):
        super().__init__(name=name)
        self.output_channels = output_channels
        self.strides = strides

    def __call__(self, x):
        for output_channels, stride in zip(self.output_channels, self.strides):
            x = hk.Conv2D(output_channels=output_channels,
                          kernel_shape=[5, 5],
                          stride=stride,
                          padding="SAME")(x)
            x = jax.nn.leaky_relu(x, negative_slope=0.2)
        x = hk.Flatten()(x)
        # Two logits: class 0 is "generated", class 1 is "real".
        logits = hk.Linear(2)(x)
        return logits
Creating the GAN algorithm
import optax


# The article does not show the four helpers below. These are minimal
# versions assumed here so that the class can run.
class GANTuple(NamedTuple):
    gen: Any
    disc: Any


class GANState(NamedTuple):
    params: GANTuple
    opt_state: GANTuple


def tree_shape(xs):
    # Replace every array in a pytree with its shape, for printing.
    return jax.tree_util.tree_map(lambda x: x.shape, xs)


def sparse_softmax_cross_entropy(logits, labels):
    # Cross-entropy between integer labels and softmax(logits), per example.
    one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)


class GAN_algo_basic:
    def __init__(self, num_latents):
        self.num_latents = num_latents
        # Turn the Haiku modules into pure init/apply function pairs.
        self.gen_transform = hk.without_apply_rng(
            hk.transform(lambda *args: Generator()(*args)))
        self.disc_transform = hk.without_apply_rng(
            hk.transform(lambda *args: Discriminator()(*args)))
        self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),
                                   disc=optax.adam(1e-4, b1=0.5, b2=0.9))

    @functools.partial(jax.jit, static_argnums=0)
    def initial_state(self, rng, batch):
        dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))
        rng_gen, rng_disc = jax.random.split(rng)
        params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),
                          disc=self.disc_transform.init(rng_disc, batch))
        # These prints run once, while the function is being traced.
        print("Generator: \n\n{}\n".format(tree_shape(params.gen)))
        print("Discriminator: \n\n{}\n".format(tree_shape(params.disc)))
        opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),
                             disc=self.optimizers.disc.init(params.disc))
        return GANState(params=params, opt_state=opt_state)

    def sample(self, rng, gen_params, num_samples):
        """Generates images from noise latents."""
        latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))
        return self.gen_transform.apply(gen_params, latents)

    def gen_loss(self, gen_params, rng, disc_params, batch):
        # The generator is rewarded when the discriminator labels its
        # samples as real (class 1).
        fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
        fake_logits = self.disc_transform.apply(disc_params, fake_batch)
        fake_probs = jax.nn.softmax(fake_logits)[:, 1]
        loss = -jnp.log(fake_probs)
        return jnp.mean(loss)

    def disc_loss(self, disc_params, rng, gen_params, batch):
        # Run the discriminator once on real and generated images together.
        fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
        real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)
        real_and_fake_logits = self.disc_transform.apply(disc_params,
                                                         real_and_fake_batch)
        real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)

        # Real images should be classified as class 1.
        real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)
        real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)

        # Generated images should be classified as class 0.
        fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)
        fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)

        return jnp.mean(real_loss + fake_loss)

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, rng, gan_state, batch):
        rng, rng_gen, rng_disc = jax.random.split(rng, 3)

        # Update the discriminator.
        disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(
            gan_state.params.disc, rng_disc, gan_state.params.gen, batch)
        disc_update, disc_opt_state = self.optimizers.disc.update(
            disc_grads, gan_state.opt_state.disc)
        disc_params = optax.apply_updates(gan_state.params.disc, disc_update)

        # Update the generator.
        gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(
            gan_state.params.gen, rng_gen, gan_state.params.disc, batch)
        gen_update, gen_opt_state = self.optimizers.gen.update(
            gen_grads, gan_state.opt_state.gen)
        gen_params = optax.apply_updates(gan_state.params.gen, gen_update)

        params = GANTuple(gen=gen_params, disc=disc_params)
        opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)
        gan_state = GANState(params=params, opt_state=opt_state)
        log = {
            "gen_loss": gen_loss,
            "disc_loss": disc_loss,
        }
        return rng, gan_state, log
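Read as equations, the two loss functions in the class above appear to correspond to the following objectives. This is one reading of the code rather than a statement from the article: $D(x)$ denotes the softmax probability that the discriminator assigns to class 1 ("real"), $G(z)$ is the generator output for a latent vector $z$, and the expectations are estimated by the mean over a batch.

```latex
\begin{aligned}
D(x) &= \operatorname{softmax}\big(\mathrm{logits}(x)\big)_{1} \\
\mathcal{L}_{D} &= -\,\mathbb{E}_{x}\big[\log D(x)\big] \;-\; \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big] \\
\mathcal{L}_{G} &= -\,\mathbb{E}_{z}\big[\log D(G(z))\big]
\end{aligned}
```

With only two classes, the softmax probability of class 0 equals $1 - D(\cdot)$, which is why the cross-entropy against label 0 in `disc_loss` gives the $\log(1 - D(G(z)))$ term. The generator loss is the non-saturating form: it maximises $\log D(G(z))$ rather than minimising $\log(1 - D(G(z)))$.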
Training the model
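# Setup that the article does not show. The batch size, latent size, seed
# and logging interval are assumed values.
num_steps = 5000
log_every = 100

dataset = make_dataset(batch_size=64)
model = GAN_algo_basic(num_latents=20)

rng = jax.random.PRNGKey(0)
rng, rng_init = jax.random.split(rng)
gan_state = model.initial_state(rng_init, next(dataset))

steps = []
gen_losses = []
disc_losses = []

for step in range(num_steps):
    rng, gan_state, log = model.update(rng, gan_state, next(dataset))

    # Record and print the losses every log_every steps.
    if step % log_every == 0:
        log = jax.device_get(log)
        gen_loss = log["gen_loss"]
        disc_loss = log["disc_loss"]
        print(f"Step {step}: "
              f"gen_loss = {gen_loss:.3f}, disc_loss = {disc_loss:.3f}")
        steps.append(step)
        gen_losses.append(gen_loss)
        disc_losses.append(disc_loss)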
The model is trained for 5000 steps here because of time constraints; the number of steps is up to the user. Training for 5000 steps took approximately 60 minutes.
Analyzing the losses for the generator and discriminator
fig, axes = plt.subplots(1, 2, figsize=(20, 6))

# Plot the discriminator loss.
axes[0].plot(steps, disc_losses, "-")
axes[0].set_title("Discriminator loss", fontsize=20)

# Plot the generator loss.
axes[1].plot(steps, gen_losses, "-")
axes[1].set_title("Generator loss", fontsize=20);
We can observe that the generator loss was fairly high during the first 2000 steps, and that after 3000 steps both the discriminator loss and the generator loss were approximately constant on average.
Conclusion
Just After eXecution (JAX) is a library for high-performance numerical computation, particularly in machine learning research. Its numerical API is based on NumPy, a library of functions used in scientific computing. In this article we have looked at the JAX ecosystem and used two of its libraries, Optax and Haiku, to implement and train a GAN.