What Is Trax and How Is It A Better Framework For Advanced Deep Learning?


TensorFlow, PyTorch, Caffe, Keras, Theano, and many more. There’s already an abundance of deep learning frameworks, so why should you care about Trax? Well, most deep learning libraries have two major drawbacks:

  • They require you to write verbose code, even for simple tasks.
  • Their language/API can be quite complex and hard to understand, especially for complicated architectures.

PyTorch Lightning and Keras solve this issue to a great extent, but they are just high-level wrappers around complicated packages. Trax, on the other hand, is built from the ground up for speed and clear, concise code, even when dealing with large, complex models. As the developers put it, Trax is “Your path to advanced deep learning”. It’s also actively used and maintained by the Google Brain team.

The codebase is organized according to SOLID architecture and design principles, and it provides well-formatted logging. Trax is built on the JAX library, which provides high-performance code acceleration by combining Autograd and XLA. Autograd lets JAX automatically differentiate native Python and NumPy code, and XLA is used to just-in-time compile and execute programs on GPU and Cloud TPU accelerators. Trax can be used as a library in Python scripts and notebooks, or as a binary from the shell, which makes training larger models more convenient. One thing to note is that Trax is oriented more towards natural language models than computer vision.
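To get a feel for what "differentiating native Python code" means, here is a toy sketch of automatic differentiation using dual numbers. It is only an illustration and not how JAX is implemented; the names `Dual` and `derivative` are made up for this example (in JAX itself, the corresponding transformation is `jax.grad`).

```python
import math


class Dual:
    """A number a + b*eps with eps**2 == 0; `deriv` carries the derivative."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def sin(x):
    """sin that also accepts Dual numbers (chain rule)."""
    if isinstance(x, Dual):
        return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)
    return math.sin(x)


def derivative(f):
    """Return a function that evaluates df/dx (toy analogue of a grad transform)."""
    def df(x):
        return f(Dual(x, 1.0)).deriv
    return df


def f(x):
    # Ordinary-looking Python: f(x) = x*x + 3*x + sin(x)
    return x * x + 3 * x + sin(x)


df = derivative(f)
print(df(0.0))  # f'(x) = 2x + 3 + cos(x), so f'(0) = 4.0
```

The point is that `f` is written as plain Python, and the derivative falls out of how the arithmetic operators are interpreted rather than from any hand-written gradient code.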

A brief introduction to Trax’s high-level syntax

  1. Install Trax from PyPI

!pip install trax

  2. To work with layers in Trax, you’ll need to import the layers module. A basic Sigmoid layer can be instantiated using activation_fns.Sigmoid(); the details of all layers are in the Trax layers documentation.
 # Make a sigmoid activation layer
 from trax import layers as ly
 sigmoid = ly.activation_fns.Sigmoid()

 # Some attributes
 print("name :", sigmoid.name)
 print("weights :", sigmoid.weights)
 print("# of inputs :", sigmoid.n_in)
 print("# of outputs :", sigmoid.n_out) 
Sigmoid layer in Trax
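To make those attributes concrete, here is a minimal pure-Python sketch of what a stateless activation layer amounts to. `SigmoidSketch` is a hypothetical stand-in written for this post, not the Trax class; it only mirrors the idea that an activation has a name, no weights, one input and one output.

```python
import math


class SigmoidSketch:
    """Hypothetical stand-in for a stateless activation layer."""

    name = "Sigmoid"
    weights = ()   # An activation has no trainable parameters.
    n_in = 1       # Consumes one input...
    n_out = 1      # ...and produces one output.

    def __call__(self, xs):
        # Apply 1 / (1 + e^-x) element-wise.
        return [1.0 / (1.0 + math.exp(-x)) for x in xs]


sigmoid_sketch = SigmoidSketch()
print("name :", sigmoid_sketch.name)
print("weights :", sigmoid_sketch.weights)
print("outputs :", sigmoid_sketch([0.0, 2.0, -2.0]))
```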

Trax also lets you create custom layers from plain Python functions: Fn wraps a function and returns a layer object.

 import numpy as np

 # Define a custom layer
 def Custom_layer():
     # Set a name
     layer_name = "custom_layer"
     # Custom function
     def func(x):
         return x + x**2  # ** is exponentiation; ^ would be bitwise XOR
     return ly.base.Fn(layer_name, func)

 # Create the layer object
 custom_layer = Custom_layer()

 # Check properties
 print("name :", custom_layer.name)
 print("expected inputs :", custom_layer.n_in)
 print("promised outputs :", custom_layer.n_out)

 # Inputs
 x = np.array([0, -1, 1])
 # Outputs
 print("outputs :", custom_layer(x))
custom layer in Trax
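If you’re wondering where the "expected inputs" number comes from, the sketch below shows one plausible mechanism, assuming the input count is derived from the wrapped function’s positional parameters. `FnSketch` is a hypothetical, simplified class written for this post, not the Trax implementation.

```python
import inspect


class FnSketch:
    """Hypothetical, simplified analogue of a function-backed layer."""

    def __init__(self, name, f, n_out=1):
        self.name = name
        self._f = f
        # Assumption: number of inputs = number of parameters of f.
        self.n_in = len(inspect.signature(f).parameters)
        self.n_out = n_out

    def __call__(self, *inputs):
        if len(inputs) != self.n_in:
            raise ValueError(f"{self.name} expects {self.n_in} input(s)")
        return self._f(*inputs)


def poly(x):
    return x + x ** 2   # ** is exponentiation; ^ would be bitwise XOR.


layer = FnSketch("custom_layer", poly)
print("name :", layer.name)
print("expected inputs :", layer.n_in)
print("promised outputs :", layer.n_out)
print("outputs :", [layer(x) for x in (0, -1, 1)])
```

Keep in mind that Python’s `^` operator is bitwise XOR, so a polynomial such as x + x² has to be written with `**`.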
  3. Models are built from layers using combinators like trax.layers.combinators.Serial, trax.layers.combinators.Parallel, and trax.layers.combinators.Branch. Here’s a simple sentiment classification model built with Serial:
 model = ly.Serial(
     ly.Embedding(vocab_size=8192, d_feature=256),
     ly.Mean(axis=1),  # Average on axis 1 (length of sentence).
     ly.Dense(2),      # Classify 2 classes.
     ly.LogSoftmax()   # Produce log-probabilities.
 )
 # Print model structure.
 print(model)
Sentiment classification model in Trax
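The three combinators differ in how they route data between layers. Here is a deliberately simplified sketch using plain functions as "layers"; the real Trax combinators work on a data stack and handle layers with several inputs and outputs, so treat `serial`, `parallel` and `branch` below as illustrative names only.

```python
def serial(*layers):
    """Feed the output of each layer into the next one."""
    def run(x):
        for layer in layers:
            x = layer(x)
        return x
    return run


def parallel(*layers):
    """Apply the i-th layer to the i-th input."""
    def run(*xs):
        return tuple(layer(x) for layer, x in zip(layers, xs))
    return run


def branch(*layers):
    """Apply every layer to a copy of the same input."""
    def run(x):
        return tuple(layer(x) for layer in layers)
    return run


def double(x):
    return 2 * x


def inc(x):
    return x + 1


print(serial(double, inc)(3))       # inc(double(3)) -> 7
print(parallel(double, inc)(3, 4))  # (double(3), inc(4)) -> (6, 5)
print(branch(double, inc)(3))       # (double(3), inc(3)) -> (6, 4)
```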
  4. Trax has access to a large number of datasets, including Tensor2Tensor and TensorFlow Datasets (TFDS). Data streams in Trax are represented as Python iterators. Here’s the code to import the TFDS IMDb reviews dataset using trax.data.TFDS:
 import trax

 train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
 eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
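Because a stream is just a Python iterator, you can wrap it in ordinary generator functions. The sketch below uses made-up data and hypothetical helper names (`toy_stream`, `batch`) to show the idea of turning a stream of examples into a stream of batches; it does not use Trax’s own data pipeline functions.

```python
import itertools


def toy_stream():
    """Endless stream of (text, label) pairs; the data is made up."""
    examples = [("great movie", 1), ("boring plot", 0), ("loved it", 1)]
    return itertools.cycle(examples)


def batch(stream, batch_size):
    """Group consecutive examples into (texts, labels) batches."""
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return
        texts, labels = zip(*chunk)
        yield list(texts), list(labels)


batches = batch(toy_stream(), batch_size=2)
print(next(batches))
print(next(batches))
```

Each call to `next` pulls only as many examples as one batch needs, which is why the same pattern works for datasets too large to fit in memory.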
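  5. You can train supervised and reinforcement learning models in Trax using trax.supervised and trax.rl respectively. Here’s an example of training a supervised learning model:
 import os
 import trax
 from trax.supervised import training

 # Assumption: train_batches_stream and eval_batches_stream are tokenized,
 # batched versions of train_stream and eval_stream.

 # Training task
 train_task = training.TrainTask(
     labeled_data=train_batches_stream,
     loss_layer=ly.WeightedCategoryCrossEntropy(),
     optimizer=trax.optimizers.Adam(0.01),  # Assumed optimizer and learning rate.
 )

 # Evaluation task
 eval_task = training.EvalTask(
     labeled_data=eval_batches_stream,
     metrics=[ly.WeightedCategoryCrossEntropy(), ly.WeightedCategoryAccuracy()],
     n_eval_batches=20  # For less variance in eval numbers.
 )

 # Training loop saves checkpoints to output_dir.
 output_dir = os.path.expanduser('~/output_dir/')
 !rm -rf {output_dir}
 training_loop = training.Loop(model, train_task,
                               eval_tasks=[eval_task], output_dir=output_dir)

 # Run 2000 steps (batches).
 training_loop.run(2000)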
Training log
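Conceptually, a supervised training loop pulls examples from a training stream, takes an optimizer step on each, and periodically evaluates on a separate stream. The toy sketch below does this for a one-weight model with plain SGD. Everything in it (the data, `run_loop`, the learning rate, the evaluation interval) is made up for illustration and is far simpler than what Trax does.

```python
import itertools


def data_stream():
    """Endless stream of (x, y) pairs with y = 3 * x (made-up data)."""
    for x in itertools.cycle([1.0, 2.0, 3.0, 4.0]):
        yield x, 3.0 * x


def mean_squared_error(w, stream, n_batches):
    """Average squared error of the model y = w * x over n_batches examples."""
    pairs = itertools.islice(stream, n_batches)
    return sum((w * x - y) ** 2 for x, y in pairs) / n_batches


def run_loop(n_steps, lr=0.01, eval_every=50):
    """Plain SGD on a single weight, evaluating at fixed intervals."""
    w = 0.0
    train, evaluation = data_stream(), data_stream()
    history = []
    for step in range(1, n_steps + 1):
        x, y = next(train)
        grad = 2 * (w * x - y) * x      # d/dw of (w*x - y)**2
        w -= lr * grad                  # Optimizer step.
        if step % eval_every == 0:      # Periodic evaluation.
            loss = mean_squared_error(w, evaluation, n_batches=4)
            history.append((step, loss))
            print(f"step {step}: eval loss {loss:.6f}")
    return w, history


final_w, history = run_loop(200)
print("learned weight:", final_w)
```

The split between a training task and an evaluation task in Trax plays the same role as the two streams here: one drives the weight updates, the other only measures progress.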

After training, the models can be run like any function:

 example_input = next(eval_batches_stream)[0][0]
 example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
 print(f'example input_str: {example_input_str}')
 sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
 print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}') 
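The model returns log-probabilities, which is why the snippet exponentiates the output before printing it. Here is a small standalone sketch of that relationship, with made-up scores for two classes and a hand-written `log_softmax` helper:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


logits = [1.2, -0.3]                 # Made-up scores for two classes.
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]

print("log-probabilities:", log_probs)
print("probabilities:", probs)
print("sum of probabilities:", sum(probs))
```

Working in log space avoids underflow when probabilities get very small; exponentiating at the end recovers values that sum to one.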
  6. Running a pre-trained transformer-based English-German translation model:

A Transformer model is created with trax.models.Transformer and initialized using model.init_from_file. The input is tokenized with trax.data.tokenize and passed to the model. The output from the Transformer model is decoded using trax.supervised.decoding.autoregressive_sample, and finally de-tokenized with trax.data.detokenize.

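 # Create a Transformer model.
 # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
 # The vocabulary size, checkpoint path and vocabulary files below follow
 # the Trax quick-start example.
 model = trax.models.Transformer(
     input_vocab_size=33300,
     d_model=512, d_ff=2048,
     n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
     max_len=2048, mode='predict')
 # Initialize using pre-trained weights.
 model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                      weights_only=True)

 # Tokenize a sentence.
 sentence = 'It is nice to learn new things today!'
 tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                     vocab_dir='gs://trax-ml/vocabs/',
                                     vocab_file='ende_32k.subword'))[0]

 # Decode from the Transformer.
 tokenized = tokenized[None, :]  # Add batch dimension.
 tokenized_translation = trax.supervised.decoding.autoregressive_sample(
     model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

 # De-tokenize.
 tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
 translation = trax.data.detokenize(tokenized_translation,
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword')
 print(translation)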
Output of the transformer-based translation model created in Trax
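Autoregressive sampling means generating one token at a time and feeding each token back in as context for the next. The sketch below shows the idea with a made-up "model" that just counts to three. The function names, the token ids and the `EOS` value are all hypothetical, and the real Trax decoder is considerably more involved.

```python
import math
import random

EOS = 0  # Hypothetical end-of-sequence token id.


def toy_next_token_logits(prefix):
    """Made-up 'model': prefers counting up from 1 to 3, then EOS."""
    target = len(prefix) + 1 if len(prefix) < 3 else EOS
    return [5.0 if tok == target else 0.0 for tok in range(4)]


def sample_token(logits, temperature, rng):
    if temperature == 0.0:
        # Greedy decoding: always take the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Higher temperature flattens the distribution: more diverse results.
    weights = [math.exp(v / temperature) for v in logits]
    return rng.choices(range(len(logits)), weights=weights)[0]


def autoregressive_sample_sketch(next_logits, temperature=0.0, max_len=10, seed=0):
    rng = random.Random(seed)
    output = []
    while len(output) < max_len:
        token = sample_token(next_logits(output), temperature, rng)
        output.append(token)        # Feed the token back in on the next step.
        if token == EOS:
            break
    return output


print(autoregressive_sample_sketch(toy_next_token_logits))  # [1, 2, 3, 0]
```

With a temperature of 0.0 the output is deterministic, which is usually what you want for translation; raising the temperature trades accuracy for variety.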

Last Epoch (Endnote)

This post briefly introduced Trax, the motivation behind its development, and some of its advantages. We also illustrated its simple high-level syntax for various tasks in a deep learning pipeline. For more information, code, and examples, see the official Trax documentation and GitHub repository.

Copyright Analytics India Magazine Pvt Ltd
