
What Is Trax and How Is It A Better Framework For Advanced Deep Learning?



TensorFlow, PyTorch, Caffe, Keras, Theano, and many more: there’s already an abundance of deep learning frameworks, so why should you care about Trax? Well, most deep learning libraries have two major drawbacks:

  • They require verbose code, even for simple tasks.
  • Their APIs can be complex and hard to follow, especially for complicated architectures.

PyTorch Lightning and Keras solve this issue to a great extent, but they are just high-level wrapper APIs over complicated packages. Trax, on the other hand, is built from the ground up for speed and clear, concise code, even when dealing with large, complex models. As the developers put it, Trax is “Your path to advanced deep learning”. It is also actively used and maintained by the Google Brain team.

The codebase is organized around SOLID architecture and design principles, and it provides well-formatted logging. Trax is built on the JAX library, which provides high-performance code acceleration through Autograd and XLA: Autograd lets JAX automatically differentiate native Python and NumPy functions, while XLA just-in-time compiles and executes programs on GPUs and Cloud TPUs. Trax can be used as a library in Python scripts and notebooks, or as a binary from the shell, which makes training larger models more convenient. One thing to note is that Trax is oriented more towards natural language models than computer vision.
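
To make this concrete, here is a minimal, illustrative sketch of what JAX provides (plain JAX, not Trax code; the toy loss function is an assumption for demonstration): jax.grad differentiates a NumPy-style Python function automatically, and jax.jit compiles it with XLA for GPU/TPU execution.

 import jax
 import jax.numpy as jnp

 # A toy loss written with NumPy-style operations.
 def loss(w, x):
     return jnp.sum((x @ w) ** 2)

 # Differentiate with respect to the first argument, then JIT-compile via XLA.
 grad_loss = jax.jit(jax.grad(loss))

 w = jnp.ones((3, 2))
 x = jnp.ones((4, 3))
 print(grad_loss(w, x).shape)  # (3, 2): gradient with respect to w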

A brief introduction to Trax’s high-level syntax

  1. Install Trax from PyPI

!pip install trax

  2. To work with layers in Trax, you’ll need to import the layers module. A basic sigmoid layer can be instantiated using activation_fns.Sigmoid(); you can find the details of all layers in the Trax documentation.
 # Make a sigmoid activation layer
 from trax import layers as ly
 sigmoid = ly.activation_fns.Sigmoid()

 # Some attributes
 print("name :", sigmoid.name)
 print("weights :", sigmoid.weights)
 print("# of inputs :", sigmoid.n_in)
 print("# of outputs :", sigmoid.n_out) 
Sigmoid layer in Trax

Trax also lets you turn a plain Python function into a custom layer using trax.layers.base.Fn:

 # Define a custom layer
 import numpy as np

 def Custom_layer():
     # Set a name
     layer_name = "custom_layer"
     # Custom function: element-wise x + x squared
     def func(x):
         return x + x**2
     return ly.base.Fn(layer_name, func)

 # Create the layer object
 custom_layer = Custom_layer()

 # Check properties
 print("name :", custom_layer.name)
 print("expected inputs :", custom_layer.n_in)
 print("promised outputs :", custom_layer.n_out)

 # Inputs
 x = np.array([0, -1, 1])
 # Outputs
 print("outputs :", custom_layer(x))
custom layer in Trax
  3. Models are built from layers using combinators like trax.layers.combinators.Serial, trax.layers.combinators.Parallel, and trax.layers.combinators.Branch. Here’s a simple classification model built with Serial (a short sketch of Branch follows this example):
 model = ly.Serial(
     ly.Embedding(vocab_size=8192, d_feature=256),
     ly.Mean(axis=1),  # Average on axis 1 (length of sentence).
     ly.Dense(2),      # Classify 2 classes.
 )
 # Print model structure.
 print(model) 
Classification model in Trax
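
The snippet above only uses Serial. As a rough sketch (not from the original article) of the other combinators, Branch applies several layers to copies of the same input, after which the results can be merged back into one tensor; the layer sizes and input shape below are arbitrary choices for illustration:

 import numpy as np
 import trax
 from trax import layers as ly

 # Branch runs both Dense layers on copies of the same input;
 # Concatenate then merges their outputs along the last axis.
 branched = ly.Serial(
     ly.Branch(ly.Dense(16), ly.Dense(16)),
     ly.Concatenate(),
     ly.Dense(2),
 )

 x = np.ones((4, 8), dtype=np.float32)
 branched.init(trax.shapes.signature(x))
 print(branched(x).shape)  # (4, 2)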
  4. Trax has access to a large number of datasets, including Tensor2Tensor and TensorFlow datasets. Data streams in Trax are represented as Python iterators; here’s the code to import the TFDS IMDb reviews dataset using trax.data (the pipeline sketched after this snippet turns these streams into padded batches):
 train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
 eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)() 
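
The raw streams still need to be tokenized, shuffled, and batched before training. The pipeline below follows the pattern used in the Trax documentation’s walkthrough (the en_8k.subword vocabulary and the bucket sizes are taken from there); it also produces the train_batches_stream and eval_batches_stream used in the training example that follows:

 import trax

 # Tokenize, shuffle, filter over-long examples, bucket into padded batches,
 # and add per-example loss weights.
 data_pipeline = trax.data.Serial(
     trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
     trax.data.Shuffle(),
     trax.data.FilterByLength(max_length=2048, length_keys=[0]),
     trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                              batch_sizes=[256, 64, 16, 4, 1],
                              length_keys=[0]),
     trax.data.AddLossWeights()
 )
 train_batches_stream = data_pipeline(train_stream)
 eval_batches_stream = data_pipeline(eval_stream)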
  5. You can train supervised and reinforcement learning models in Trax using trax.supervised.training and trax.rl, respectively. Here’s an example of training a supervised learning model (the classification model and batched streams defined above):
 import os
 from trax import layers as ly
 from trax.supervised import training

 # Training task
 train_task = training.TrainTask(
     labeled_data=train_batches_stream,
     loss_layer=ly.WeightedCategoryCrossEntropy(),
     optimizer=trax.optimizers.Adam(0.01),
     n_steps_per_checkpoint=500,
 )

 # Evaluation task
 eval_task = training.EvalTask(
     labeled_data=eval_batches_stream,
     metrics=[ly.WeightedCategoryCrossEntropy(), ly.WeightedCategoryAccuracy()],
     n_eval_batches=20  # For less variance in eval numbers.
 )

 # Training loop saves checkpoints to output_dir.
 output_dir = os.path.expanduser('~/output_dir/')
 !rm -rf {output_dir}
 training_loop = training.Loop(model,
                               train_task,
                               eval_tasks=[eval_task],
                               output_dir=output_dir)

 # Run 2000 steps (batches).
 training_loop.run(2000) 
Training log

After training, the models can be run like any function:

 example_input = next(eval_batches_stream)[0][0]
 example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
 print(f'example input_str: {example_input_str}')
 sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
 print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}') 
  6. Running a pre-trained transformer-based English-German translation model:

A Transformer model is created with trax.models.Transformer, and initialized using model.init_from_file. The input is tokenized with trax.data.tokenize and passed to the model. The output from the Transformer model is decoded using trax.supervised.decoding.autoregressive_sample, and finally de-tokenized with trax.data.detokenize.

 # Create a Transformer model.
 # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
 model = trax.models.Transformer(
     input_vocab_size=33300,
     d_model=512, d_ff=2048,
     n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
     max_len=2048, mode='predict')
 # Initialize using pre-trained weights.
 model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                      weights_only=True)

 # Tokenize a sentence.
 sentence = 'It is nice to learn new things today!'
 tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                     vocab_dir='gs://trax-ml/vocabs/',
                                     vocab_file='ende_32k.subword'))[0]

 # Decode from the Transformer.
 tokenized = tokenized[None, :]  # Add batch dimension.
 tokenized_translation = trax.supervised.decoding.autoregressive_sample(
     model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

 # De-tokenize.
 tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
 translation = trax.data.detokenize(tokenized_translation,
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword')
 print(translation) 
Output of the transformer-based translation model created in Trax

Last Epoch (Endnote)

This post briefly introduced Trax, the intention behind its development, and some of its advantages. We also illustrated its simple high-level syntax for various tasks in a deep learning pipeline. For more information, code, and examples, see the official Trax repository and documentation.


Aditya Singh

A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.