What Is Trax and How Is It A Better Framework For Advanced Deep Learning?


TensorFlow, PyTorch, Caffe, Keras, Theano, and many more. There’s already an abundance of deep learning frameworks, so why should you care about Trax? Well, most deep learning libraries have two major drawbacks:

  • They require verbose code, even for simple tasks.
  • Their language/API can be quite complex and hard to understand, especially for complicated architectures.

PyTorch Lightning and Keras solve this issue to a great extent, but they are just high-level wrapper APIs over complicated packages. Trax, on the other hand, is built from the ground up for speed and clear, concise code, even when dealing with large, complex models. As the developers put it, Trax is “Your path to advanced deep learning”. It is also actively used and maintained by the Google Brain team.

The codebase is organized by SOLID architecture and design principles, and it provides well-formatted logging. Trax is built on the JAX library, which provides high-performance code acceleration by combining Autograd and XLA. Autograd lets JAX automatically differentiate native Python and NumPy functions, and XLA just-in-time compiles and runs programs on GPU and Cloud TPU accelerators. Trax can be used as a library in Python scripts and notebooks or as a binary from the shell, which makes training larger models more convenient. One thing to note is that Trax is oriented more towards natural language models than computer vision.
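To make "differentiating native Python" concrete, here is a toy forward-mode automatic differentiation sketch in plain Python. It is not how JAX is implemented (JAX traces functions, and jax.grad uses reverse mode); the Dual class and the derivative helper are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative (hypothetical helper, not a JAX type)."""
    value: float
    deriv: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants (derivative 0)."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def derivative(f, x):
    """Evaluate df/dx at x by seeding the input's derivative with 1."""
    return f(Dual(float(x), 1.0)).deriv


def f(x):
    # Ordinary Python arithmetic; no derivative code is written by hand.
    return x * x + 3 * x


slope = derivative(f, 2.0)  # d/dx (x^2 + 3x) = 2x + 3, evaluated at x = 2
print(slope)
```

The point of the sketch is that the function f stays ordinary Python: the derivative falls out of how the arithmetic operators are interpreted, which is the idea that Autograd-style tools build on.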


A brief introduction to Trax’s high-level syntax

  1. Install Trax from PyPI

!pip install trax

  1. To work with layers in Trax, import the layers module. A basic Sigmoid layer can be instantiated with activation_fns.Sigmoid(); the Trax documentation describes all available layers.
 # Make a sigmoid activation layer
 from trax import layers as ly
 sigmoid = ly.activation_fns.Sigmoid()

 # Some attributes
 print("name :", sigmoid.name)
 print("weights :", sigmoid.weights)
 print("# of inputs :", sigmoid.n_in)
 print("# of outputs :", sigmoid.n_out) 
Sigmoid layer in Trax
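
For intuition, a sigmoid activation applies 1 / (1 + e^(-x)) to each element of its input and has no trainable weights. A minimal plain-Python sketch of that computation follows; the sigmoid helper here is illustrative and is not the Trax implementation.

```python
import math


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)) for a single float."""
    return 1.0 / (1.0 + math.exp(-x))


inputs = [-2.0, 0.0, 2.0]
outputs = [sigmoid(v) for v in inputs]
print(outputs)
```

Every output lies strictly between 0 and 1, and sigmoid(0) is exactly 0.5.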

Trax also provides Fn, a helper that turns a plain Python function into a neural network layer:

 import numpy as np

 # Define a custom layer
 def Custom_layer():
     # Set a name
     layer_name = "custom_layer"
     # Custom function: x plus x squared
     # (** is exponentiation; ^ would be bitwise XOR)
     def func(x):
         return x + x**2
     return ly.base.Fn(layer_name, func)

 # Create the layer object
 custom_layer = Custom_layer()

 # Check properties
 print("name :", custom_layer.name)
 print("expected inputs :", custom_layer.n_in)
 print("promised outputs :", custom_layer.n_out)

 # Inputs
 x = np.array([0, -1, 1])
 # Outputs
 print("outputs :", custom_layer(x))
custom layer in Trax
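
One detail worth checking when writing such a function: in Python, exponentiation is written `**`, while `^` is bitwise XOR and binds more loosely than `+`. The plain-Python sketch below (no Trax involved, list inputs instead of arrays) compares the two spellings on the same inputs.

```python
def poly(v):
    """x + x squared, using the exponentiation operator."""
    return v + v ** 2


def xor_version(v):
    """Same expression spelled with ^, which parses as (v + v) XOR 2."""
    return v + v ^ 2


inputs = [0, -1, 1]
outputs = [poly(v) for v in inputs]
xor_outputs = [xor_version(v) for v in inputs]
print(outputs)      # [0, 0, 2]
print(xor_outputs)  # [2, -4, 0]
```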
  1. Models are built from layers using combinators such as trax.layers.combinators.Serial, trax.layers.combinators.Parallel, and trax.layers.combinators.Branch. Here’s a simple sentiment classifier built with Serial:
 model = ly.Serial(
     ly.Embedding(vocab_size=8192, d_feature=256),
     ly.Mean(axis=1),  # Average on axis 1 (length of sentence).
     ly.Dense(2),      # Classify 2 classes.
 )
 # Print model structure.
 print(model)
Sentiment classifier in Trax
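
As a rough mental model of how the three combinators route data, here is a toy sketch using plain Python functions. The serial, branch, and parallel helpers are hypothetical stand-ins written for this article; the real Trax combinators also manage weights and a data stack, which this sketch ignores.

```python
def serial(*fns):
    """Chain functions: the output of each one feeds the next."""
    def run(x):
        for fn in fns:
            x = fn(x)
        return x
    return run


def branch(*fns):
    """Apply every function to a copy of the same input."""
    def run(x):
        return tuple(fn(x) for fn in fns)
    return run


def parallel(*fns):
    """Apply the i-th function to the i-th input."""
    def run(*xs):
        return tuple(fn(x) for fn, x in zip(fns, xs))
    return run


def double(x):
    return 2 * x


def add_one(x):
    return x + 1


serial_out = serial(double, add_one)(3)          # add_one(double(3))
branch_out = branch(double, add_one)(3)          # (double(3), add_one(3))
parallel_out = parallel(double, add_one)(3, 10)  # (double(3), add_one(10))
print(serial_out, branch_out, parallel_out)
```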
  1. Trax has access to a large number of datasets, including Tensor2Tensor and TensorFlow Datasets (TFDS). Data streams in Trax are represented as Python iterators. Here’s the code to import the TFDS IMDb reviews dataset using trax.data:
 import trax

 train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
 eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
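
Because streams are ordinary Python iterators, they can be consumed with next() and transformed with generators. The sketch below illustrates that idea with made-up data and standard-library tools only; review_stream and batch are hypothetical helpers, not part of trax.data.

```python
import itertools


def review_stream():
    """Endless toy stream of (text, label) pairs (made-up examples)."""
    examples = [("great movie", 1), ("dull plot", 0), ("loved it", 1)]
    return itertools.cycle(examples)


def batch(stream, batch_size):
    """Group consecutive examples into (texts, labels) batches."""
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if len(chunk) < batch_size:
            return  # Stop when the stream runs out of full batches.
        texts, labels = zip(*chunk)
        yield list(texts), list(labels)


batches = batch(review_stream(), batch_size=2)
first_texts, first_labels = next(batches)
second_texts, second_labels = next(batches)
print(first_texts, first_labels)
print(second_texts, second_labels)
```

Each stage takes an iterator and returns a new one, so preprocessing steps can be chained without loading the whole dataset into memory.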
  1. You can train supervised and reinforcement learning models in Trax using trax.supervised.training and trax.rl respectively. Here’s an example of training a supervised learning model:
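 import os
 import trax
 from trax.supervised import training

 # Training task. The arguments were elided in the original listing; the
 # ones below are assumptions, and train_batches_stream is an assumed name
 # for a batched training stream.
 train_task = training.TrainTask(
     labeled_data=train_batches_stream,
     loss_layer=ly.WeightedCategoryCrossEntropy(),
     optimizer=trax.optimizers.Adam(0.01),
 )

 # Evaluation task (labeled_data is likewise assumed).
 eval_task = training.EvalTask(
     labeled_data=eval_batches_stream,
     metrics=[ly.WeightedCategoryCrossEntropy(), ly.WeightedCategoryAccuracy()],
     n_eval_batches=20  # For less variance in eval numbers.
 )

 # Training loop saves checkpoints to output_dir.
 output_dir = os.path.expanduser('~/output_dir/')
 !rm -rf {output_dir}
 training_loop = training.Loop(model,
                               train_task,
                               eval_tasks=[eval_task],
                               output_dir=output_dir)

 # Run 2000 steps (batches).
 training_loop.run(2000)
Training log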

After training, the models can be run like any function:

 example_input = next(eval_batches_stream)[0][0]
 example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
 print(f'example input_str: {example_input_str}')
 sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
 print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}') 
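
Exponentiating turns log-probabilities back into probabilities that sum to 1. The plain-Python sketch below shows that relationship on made-up scores, assuming the outputs are normalized with a log-softmax; log_softmax here is an illustrative helper, not a Trax call.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


logits = [1.0, 3.0]  # Made-up scores for (negative, positive).
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted_class)
```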
  1. Running a pre-trained transformer-based English-German translation model:

A Transformer model is created with trax.models.Transformer, and initialized using model.init_from_file. The input is tokenized with trax.data.tokenize and passed to the model. The output from the Transformer model is decoded using trax.supervised.decoding.autoregressive_sample, and finally de-tokenized with trax.data.detokenize.

 # Create a Transformer model.
 # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
 model = trax.models.Transformer(
     d_model=512, d_ff=2048,
     n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
     max_len=2048, mode='predict')
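 # Initialize using pre-trained weights.
 # CHECKPOINT_PATH, VOCAB_DIR, and VOCAB_FILE are placeholder names: the
 # original listing elides these locations, so set them before running.
 model.init_from_file(CHECKPOINT_PATH, weights_only=True)

 # Tokenize a sentence.
 sentence = 'It is nice to learn new things today!'
 tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                     vocab_dir=VOCAB_DIR,
                                     vocab_file=VOCAB_FILE))[0]

 # Decode from the Transformer.
 tokenized = tokenized[None, :]  # Add batch dimension.
 tokenized_translation = trax.supervised.decoding.autoregressive_sample(
     model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

 # De-tokenize.
 tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
 translation = trax.data.detokenize(tokenized_translation,
                                    vocab_dir=VOCAB_DIR,
                                    vocab_file=VOCAB_FILE)
 print(translation)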
Output of the transformer-based translation model created in Trax
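
Autoregressive decoding produces one token at a time, feeding each chosen token back in as input for the next step; with temperature=0.0 the highest-scoring token is picked at every step. Here is a toy plain-Python sketch of that loop. The scoring function is a made-up stand-in for a model, and EOS, toy_next_token_scores, and greedy_decode are hypothetical names rather than Trax APIs.

```python
EOS = 0  # Hypothetical end-of-sequence token id.


def toy_next_token_scores(prefix):
    """Stand-in for a model: score each token id given the tokens so far.

    This toy rule prefers (last token - 1), so sequences count down to EOS.
    """
    target = max(prefix[-1] - 1, EOS)
    return [1.0 if token == target else 0.0 for token in range(5)]


def greedy_decode(next_token_scores, start_tokens, max_len=10):
    """Repeatedly append the highest-scoring token until EOS or max_len."""
    tokens = list(start_tokens)
    generated = []
    while len(generated) < max_len:
        scores = next_token_scores(tokens)
        best = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(best)
        generated.append(best)
        if best == EOS:
            break
    return generated


decoded = greedy_decode(toy_next_token_scores, start_tokens=[4])
without_eos = decoded[:-1]  # Drop the trailing EOS before de-tokenizing.
print(decoded, without_eos)
```

A temperature above zero would instead sample from the score distribution, trading determinism for more varied output.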

Last Epoch (Endnote)

This post briefly introduced Trax, the motivation behind its development, and some of its advantages. We also illustrated its simple high-level syntax for the main tasks in a deep learning pipeline. For more information, code, and examples, see the official Trax documentation and GitHub repository.


Aditya Singh
A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.
