How to use Torchbearer for fitting ML models with PyTorch?

Torchbearer is a Python library for fitting PyTorch models. It offers a high-level metric and callback API that can be used in a variety of applications.

For TensorFlow-based models, we have Keras, which exposes TensorFlow's functionality for preprocessing, modelling, and managing callbacks through a simple, high-level API. In the same space for PyTorch, there is a library called Torchbearer, a model fitting library that offers a high-level metric and callback API usable in a variety of applications. In this article, we discuss this library in detail with a hands-on implementation. We will also see how it helps in interpreting fitted machine learning models. Below are the major points that we are going to cover in this article.

Table of contents

  1. Need of Torchbearer
  2. Design of Torchbearer
  3. Training and visualizing an SVM

Let’s first understand the need for this library. 

Need of Torchbearer

Deep learning’s meteoric rise has spawned a slew of frameworks that enable hardware-accelerated tensor processing and automatic differentiation. The field has gradually come to be described by a broader term, differentiable programming. Fitting is the process of optimizing the parameters of a differentiable program using gradient descent. PyTorch is one library that has grown in popularity in recent years, thanks to its ease of use in creating models that execute non-standard tensor operations.

This makes PyTorch particularly suitable for research projects in which any aspect of a model definition may need to be changed or modified. However, because PyTorch focuses solely on tensor processing and automated differentiation, it lacks the high-level model-fitting API found in other frameworks like Keras. Furthermore, such libraries rarely allow differentiable programming in its broadest sense.
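To make the gap concrete, here is a minimal sketch of the hand-written loop that plain PyTorch requires even for a very simple model. The toy data, the single linear layer, and the hyperparameters are assumptions chosen for illustration; nothing in this block comes from Torchbearer.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy regression data (an assumption for this sketch): y = 3x + 1 plus small noise.
X = torch.linspace(-1, 1, 128).unsqueeze(1)
y = 3 * X + 1 + 0.05 * torch.randn_like(X)

model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

losses = []
for epoch in range(50):
    # Iterate over mini-batches of 32 samples by hand.
    for start in range(0, len(X), 32):
        xb, yb = X[start:start + 32], y[start:start + 32]
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    # Track the loss over the full data set once per epoch.
    with torch.no_grad():
        losses.append(criterion(model(X), y).item())

print(f"first epoch loss: {losses[0]:.4f}, last epoch loss: {losses[-1]:.4f}")
```

Batching, metric tracking, and any logging or checkpointing all have to be written by hand here, which is the boilerplate a model-fitting library aims to remove.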

The result is torchbearer, a Python library that supports research by speeding up the model fitting process while maintaining transparency and generality. A model fitting API is at the heart of torchbearer, allowing for easy customization of all aspects of the fitting process. It also offers a robust metric API that allows you to collect running statistics and averages. Finally, torchbearer has a number of advanced callbacks.

Design of Torchbearer

The torchbearer library is written in Python and uses PyTorch, torchvision, and tqdm, with some functionality provided by NumPy, scikit-learn, and TensorBoard. Torchbearer’s core abstractions are trials, callbacks, and metrics.

Torchbearer is different from other similar libraries like ignite or tnt because of these design concepts. Neither library, for example, has a large number of built-in callbacks for sophisticated functions. Furthermore, both ignite and tnt use an event-driven API for model fitting, which makes the code less clear and legible for humans.

Let’s briefly discuss its major APIs.

Trial API

The Trial class implements a PyTorch model fitting interface based on the run(…) method. There are also predict(…) and evaluate(…) methods for running inference and for assessing saved models. The Trial class also has a state_dict method that returns a dictionary comprising the model parameters, optimizer state, and callback states, which can be stored and then reloaded using load_state_dict.

Callback API

The callback API defines classes that can be used to perform a variety of tasks during the fitting procedure. Each callback is passed the mutable state dictionary, a torchbearer component that holds the intermediate variables required by the Trial. As a result, callbacks can change the nature of the fitting process in real time. Callbacks can also be implemented as decorated functions using the decorator API.
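The idea of callbacks sharing one mutable state dictionary can be illustrated without Torchbearer at all. The sketch below is a hypothetical, stripped-down fitting loop: the names `fit`, `every`, `record_loss`, and `stop_when_small`, and the string state keys, are invented for this example and are not Torchbearer's API.

```python
def fit(state, callbacks, steps):
    """Hypothetical fitting loop: every callback receives the same mutable state dict."""
    for step in range(steps):
        state["step"] = step
        # Stand-in for a real optimizer update: the loss shrinks in proportion to lr.
        state["loss"] *= (1.0 - state["lr"])
        for callback in callbacks:
            callback(state)
        if state.get("stop"):
            break
    return state


def every(n):
    """Decorator: only run the wrapped callback on every n-th step."""
    def decorator(func):
        def wrapper(state):
            if state["step"] % n == 0:
                func(state)
        return wrapper
    return decorator


@every(5)
def record_loss(state):
    # Reads from the state and appends to a log stored in the same dict.
    state.setdefault("log", []).append((state["step"], round(state["loss"], 4)))


def stop_when_small(state):
    # Writes to the state, which changes what the loop does next.
    if state["loss"] < 0.1:
        state["stop"] = True


final = fit({"loss": 1.0, "lr": 0.2}, [record_loss, stop_when_small], steps=100)
print(final["step"], final["log"])
```

Because `stop_when_small` writes a flag that the loop reads, the run ends long before the 100 requested steps, which is the sense in which a callback can change the fitting process while it is running.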

Metric API

The metric API makes use of a tree to allow data to flow from one metric to a set of children. This enables the computation of aggregates such as the running mean or standard deviation. Assembling these data structures can be difficult, so torchbearer includes a decorator API to make it easier. The default_for_key(…) decorator allows the metric to be referenced in the Trial definition with a string.
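To make the tree idea concrete, the following is a small, self-contained sketch in plain Python. It is not Torchbearer's implementation: the names `MetricTree`, `Mean`, `RunningMean`, and `batch_accuracy` are hypothetical and exist only to show a root metric feeding its per-batch value to several aggregating children.

```python
from collections import deque


class Mean:
    """Child metric: mean of every value seen so far."""
    def __init__(self, name):
        self.name, self.total, self.count = name, 0.0, 0

    def process(self, value):
        self.total += value
        self.count += 1
        return {self.name: self.total / self.count}


class RunningMean:
    """Child metric: mean over the last `window` values only."""
    def __init__(self, name, window):
        self.name, self.values = name, deque(maxlen=window)

    def process(self, value):
        self.values.append(value)
        return {self.name: sum(self.values) / len(self.values)}


class MetricTree:
    """Root metric whose output flows to each child."""
    def __init__(self, root, children):
        self.root, self.children = root, children

    def process(self, y_pred, y_true):
        value = self.root(y_pred, y_true)
        results = {}
        for child in self.children:
            results.update(child.process(value))
        return results


def batch_accuracy(y_pred, y_true):
    """Root computation: fraction of correct predictions in one batch."""
    correct = sum(p == t for p, t in zip(y_pred, y_true))
    return correct / len(y_true)


tree = MetricTree(batch_accuracy, [Mean("acc"), RunningMean("running_acc", window=2)])
batches = [([1, 0, 0, 0], [1, 1, 1, 1]),   # 1 of 4 correct
           ([0, 0, 1, 1], [1, 1, 1, 1]),   # 2 of 4 correct
           ([1, 1, 1, 1], [1, 1, 1, 1])]   # 4 of 4 correct
for y_pred, y_true in batches:
    results = tree.process(y_pred, y_true)
print(results)
```

The root value is computed once per batch and reused by every child, so adding another aggregate (a standard deviation, say) only means adding another child.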

Training and visualizing an SVM

In this section, we will implement an SVM in PyTorch and then train, evaluate, and visualize its hyperplane using the Torchbearer library. This example is taken from the official Torchbearer repository.

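An SVM seeks the hyperplane with the greatest margin of separation between the data classes. For a soft-margin SVM, where x_i are our data points and y_i ∈ {−1, +1} are their labels, we minimize the following:

$$\frac{1}{n}\sum_{i=1}^{n}\max\bigl(0,\; 1 - y_i(\mathbf{w}\cdot\mathbf{x}_i + b)\bigr) + \lambda\lVert\mathbf{w}\rVert^2$$

This can be expressed as an optimization over our weights w and bias b, where we minimize the hinge loss while accounting for an L2 weight decay term with strength λ.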

Now before modelling this in PyTorch let’s first install and import the Torchbearer library.

# install and import torchbearer
!pip install -q torchbearer
import torchbearer

After this let’s define the SVM and hinge loss function.

# define SVM
import torch
import torch.nn as nn
class LinearSVM(nn.Module):
    """Support Vector Machine"""
    def __init__(self):
        super(LinearSVM, self).__init__()
        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)
        self.b = nn.Parameter(torch.randn(1), requires_grad=True)
    def forward(self, x):
        h = x.matmul(self.w.t()) + self.b
        return h

# define the loss function
def hinge_loss(y_pred, y_true):
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))
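As a quick sanity check of the hinge loss, the sketch below evaluates it on three hand-picked scores (the numbers are assumptions for illustration). A point scored beyond the margin contributes nothing, a point inside the margin contributes a little, and a misclassified point contributes the most.

```python
import torch


def hinge_loss(y_pred, y_true):
    # Same definition as above: mean of max(0, 1 - y * f(x)).
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))


# Model outputs have shape (N, 1) and labels have shape (N,), as in the SVM above.
y_pred = torch.tensor([[2.0], [0.5], [-1.0]])
y_true = torch.tensor([1.0, 1.0, 1.0])

per_sample = torch.clamp(1 - y_pred.t() * y_true, min=0)
loss = hinge_loss(y_pred, y_true)
print(per_sample)  # per-sample penalties
print(loss)        # their mean
```

The transpose matters: multiplying an (N, 1) tensor by an (N,) tensor directly would broadcast to an (N, N) matrix, whereas transposing to (1, N) first keeps the product element-wise.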

Now we’ll create and normalize the synthetic data that will be separated by the hyperplanes.

# load data
import numpy as np
from sklearn.datasets import make_blobs
X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
# normalize the data
X = (X - X.mean()) / X.std()
# relabel class 0 as -1 so the targets are in {-1, +1}, as the hinge loss expects
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)
# build a grid of points covering the data, used later to draw the decision regions
delta = 0.01
x = np.arange(X[:, 0].min().item(), X[:, 0].max().item(), delta)
y = np.arange(X[:, 1].min().item(), X[:, 1].max().item(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))

Now we’ll define the callbacks and visualization function. 

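# visualization function
from torchbearer import callbacks
%matplotlib notebook
import matplotlib
import matplotlib.pyplot as plt
# on_step_training turns the function into a callback that runs after each training step;
# only_if restricts it to every 10th batch
@callbacks.on_step_training
@callbacks.only_if(lambda state: state[torchbearer.BATCH] % 10 == 0)
def draw_margin(state):
    w = state[torchbearer.MODEL].w[0].detach().to('cpu').numpy()
    b = state[torchbearer.MODEL].b[0].detach().to('cpu').numpy()
    # evaluate the decision function w.x + b at every grid point
    z = (w.dot(xy) + b).reshape(x.shape)
    # bucket the scores into four regions: outside and inside the margin, on each side
    z[np.where(z > 1.)] = 4
    z[np.where((z > 0.) & (z <= 1.))] = 3
    z[np.where((z > -1.) & (z <= 0.))] = 2
    z[np.where(z <= -1.)] = 1
    plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
    plt.contourf(x, y, z, alpha=0.5)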

We’d like to use a soft-margin SVM because we don’t know if our data is linearly separable. To accomplish this, we can use the L2WeightDecay callback in torchbearer, which adds the weight decay term to the loss. Because the hinge loss is not differentiable everywhere, and because we only use a mini-batch at each step to approximate the gradient over all of the data, this process is a form of stochastic subgradient descent. Now let’s train the model and see the result.
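For reference, one step of this procedure can be written out as follows. This is a sketch under a stated assumption: the weight decay term is taken to be the squared L2 norm of w scaled by λ, which is the textbook soft-margin form. The exact penalty that L2WeightDecay applies is not stated in this article. Here B is the current mini-batch, V is the subset of B that violates the margin, and η is the learning rate.

```latex
\begin{align*}
J(\mathbf{w}, b) &= \frac{1}{|B|}\sum_{i \in B} \max\bigl(0,\; 1 - y_i(\mathbf{w}\cdot\mathbf{x}_i + b)\bigr) + \lambda \lVert \mathbf{w} \rVert^2 \\
V &= \{\, i \in B : y_i(\mathbf{w}\cdot\mathbf{x}_i + b) < 1 \,\} \\
g_{\mathbf{w}} &= -\frac{1}{|B|}\sum_{i \in V} y_i\,\mathbf{x}_i + 2\lambda\,\mathbf{w},
\qquad
g_b = -\frac{1}{|B|}\sum_{i \in V} y_i \\
\mathbf{w} &\leftarrow \mathbf{w} - \eta\, g_{\mathbf{w}},
\qquad
b \leftarrow b - \eta\, g_b
\end{align*}
```

Points that sit exactly on the margin are at the kink of the hinge, where any value between the two one-sided slopes is a valid subgradient; the choice above takes zero for them, which matches what clamping at zero produces under automatic differentiation.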

# train the model
from torchbearer import Trial
from torchbearer.callbacks import L2WeightDecay, ExponentialLR
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fig = plt.figure(figsize=(5, 5))
svm = LinearSVM()
model = Trial(svm, optim.SGD(svm.parameters(), 0.1), hinge_loss, ['loss'],
              callbacks=[draw_margin, ExponentialLR(0.999, step_on_batch=True), L2WeightDecay(0.01, params=[svm.w])]).to(device)
model.with_train_data(X, Y, batch_size=32)
# the epoch count was lost from this listing; 50 is an assumed value
model.run(epochs=50, verbose=1)

The plot result can be saved in the working directory as below.

#fig will be saved in the working directory
fig.savefig('svm.png', bbox_inches='tight')

And here is the SVM’s hyperplane, with its margins, separating the above synthetic data.

Final words

In this article, we have discussed Torchbearer, a library that simplifies the training process and supports differentiable programming for PyTorch models. From the above practical implementation, we have seen how easily we can train a model using this framework. A complete set of built-in callbacks (such as logging, weight decay, and model checkpointing) and a strong metric API are two of torchbearer’s key features.


Vijaysinh Lendave
Vijaysinh is an enthusiast in machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, model building.
