How to use Torchbearer for fitting ML models with PyTorch?

Torchbearer is a Python model-fitting library for PyTorch that offers a high-level metric and callback API usable in a variety of applications.

For TensorFlow-based models, we have Keras, which exposes TensorFlow's complex functionality for preprocessing, modelling, and managing callbacks through a simple, high-level API. In the same space for PyTorch, there is a library called Torchbearer, a model-fitting library for PyTorch models that offers a high-level metric and callback API usable in a variety of applications. In this article, we discuss this library in detail with a hands-on implementation. We will also see how this framework helps in visualizing and interpreting a fitted machine learning model. Below are the major points that we are going to cover in this article.

Table of contents

  1. Need of Torchbearer
  2. Design of Torchbearer
  3. Training and visualizing an SVM

Let’s first understand the need for this library. 

Need of Torchbearer

Deep learning’s meteoric rise has spawned a slew of frameworks that enable hardware-accelerated tensor processing and automatic differentiation. The field has gradually come to be described by a broader term, differentiable programming. Fitting is the process of optimizing the parameters of a differentiable program using gradient descent. PyTorch is one library that has grown in popularity in recent years, thanks to its ease of use in creating models that execute non-standard tensor operations.


This makes PyTorch particularly suitable for research projects in which any aspect of a model definition may need to be changed. However, because PyTorch focuses solely on tensor processing and automatic differentiation, it lacks the high-level model-fitting API found in other frameworks like Keras. Furthermore, high-level fitting libraries of that kind rarely support differentiable programming in its broadest sense.

Torchbearer was created to fill this gap: it is a Python library that supports research by speeding up the model fitting process while maintaining transparency and generality. A model fitting API is at the heart of torchbearer, allowing for easy customization of all aspects of the fitting process. It also offers a robust metric API that lets you collect rolling statistics and averages. Finally, torchbearer ships with a number of advanced callbacks.

Design of Torchbearer

The torchbearer library is written in Python and uses PyTorch, torchvision, and tqdm, with some functionality provided by NumPy, scikit-learn, and TensorBoard. Torchbearer’s core abstractions are trials, callbacks, and metrics.

These design concepts set Torchbearer apart from similar libraries such as ignite and tnt. Neither of those libraries, for example, has a large number of built-in callbacks for sophisticated functions. Furthermore, both ignite and tnt use an event-driven API for model fitting, which makes the code harder to read.

Let’s briefly discuss its major APIs.

Trial API

The Trial class implements a PyTorch model fitting interface based on the run(…) method. There are also predict(…) and evaluate(…) methods for running inference and for assessing saved models. The Trial class also has a state_dict method that returns a dictionary comprising the model parameters, optimizer state, and callback states, which can be stored and later reloaded using load_state_dict.

Callback API

The callback API defines classes that can be used to perform a variety of tasks during the fitting procedure. Each callback is given torchbearer's mutable state dictionary, which holds the intermediate variables required by the Trial. Because the state is mutable, callbacks can change the fitting process while it is running. Callbacks can also be implemented as decorated functions using the decorator API.
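The idea of a shared, mutable state can be illustrated without torchbearer at all. The following is a minimal pure-Python sketch, not torchbearer's implementation: the `on_end_epoch` decorator, the `fit` loop, and the string keys are hypothetical stand-ins (torchbearer itself uses state keys such as `torchbearer.EPOCH`).

```python
def on_end_epoch(fn):
    """Hypothetical decorator: tag a plain function as an end-of-epoch callback."""
    fn.hook = "on_end_epoch"
    return fn


@on_end_epoch
def early_stop(state):
    # Mutating the shared state is how a callback influences the fitting loop.
    if state["loss"] < 0.3:
        state["stop_training"] = True


def fit(callbacks, losses):
    """Toy fitting loop: 'losses' stands in for the per-epoch training loss."""
    state = {"epoch": 0, "loss": None, "stop_training": False}
    for epoch, loss in enumerate(losses):
        state["epoch"], state["loss"] = epoch, loss
        for callback in callbacks:
            if callback.hook == "on_end_epoch":
                callback(state)
        if state["stop_training"]:
            break
    return state


final_state = fit([early_stop], losses=[0.9, 0.5, 0.2, 0.1])
print(final_state)  # the loop stops at the first epoch whose loss is below 0.3
```

The loop never inspects the callback's logic; it only reads what the callback wrote into the state, which is what keeps the fitting code generic.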

Metric API

The metric API makes use of a tree to allow data to flow from one metric to a set of children. This enables the computation of aggregates such as the running mean or standard deviation. Assembling these data structures can be difficult, so torchbearer includes a decorator API to make it easier. The default_for_key(…) decorator allows the metric to be referenced in the Trial definition with a string.
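To make the tree idea concrete, here is a small pure-Python sketch in which one root metric feeds its per-batch value to two child aggregators. The class names (`MetricNode`, `RunningMean`, `RunningStd`) are hypothetical and chosen for illustration only; they are not torchbearer classes.

```python
import math


class RunningMean:
    name = "running_mean"

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += value
        self.count += 1
        return self.total / self.count


class RunningStd:
    name = "running_std"

    def __init__(self):
        self.values = []

    def update(self, value):
        self.values.append(value)
        mean = sum(self.values) / len(self.values)
        variance = sum((v - mean) ** 2 for v in self.values) / len(self.values)
        return math.sqrt(variance)


class MetricNode:
    """Hypothetical tree node: computes one value and forwards it to every child."""

    def __init__(self, fn, children):
        self.fn = fn
        self.children = children

    def process(self, batch):
        value = self.fn(batch)
        return {child.name: child.update(value) for child in self.children}


def batch_accuracy(batch):
    predictions, targets = batch
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


tree = MetricNode(batch_accuracy, [RunningMean(), RunningStd()])
first = tree.process(([1, 0, 1, 1], [1, 0, 0, 1]))   # batch accuracy 0.75
second = tree.process(([1, 1, 0, 0], [1, 0, 0, 1]))  # batch accuracy 0.5
print(second)  # running mean 0.625, running std 0.125
```

The root computes accuracy once per batch, and each child keeps its own aggregate, so adding another statistic means adding another child rather than recomputing the metric.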

Training and visualizing an SVM

In this section, we will implement an SVM model in PyTorch and then train it, evaluate it, and visualize its hyperplane using the Torchbearer library. This example is taken from the official Torchbearer repository.

An SVM seeks the hyperplane with the greatest margin of separation between the data classes. For a soft-margin SVM, where x is our data, we minimize the mean hinge loss over the samples plus a penalty on the squared norm of the weights.

This can be expressed as an optimization over our weights w and bias b, where we minimize the hinge loss subject to an L2 weight decay term.
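Written out, the standard soft-margin objective and the hinge loss take the form below. The sign convention (w · x + b) is chosen to match the model code in this article, n is the number of samples, the labels y_i are in {-1, +1}, and λ is an assumed symbol for the weight decay rate.

```latex
\min_{\mathbf{w},\, b} \;
  \frac{1}{n} \sum_{i=1}^{n} \max\bigl(0,\; 1 - y_i(\mathbf{w} \cdot \mathbf{x}_i + b)\bigr)
  \;+\; \lambda \lVert \mathbf{w} \rVert^2

% hinge loss for a single model output z = w . x + b with target y
\ell(y, z) = \max(0,\; 1 - y z)
```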

Now before modelling this in PyTorch let’s first install and import the Torchbearer library.

# install and import torchbearer
!pip install -q torchbearer
import torchbearer

After this let’s define the SVM and hinge loss function.

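# define SVM
import torch
import torch.nn as nn

class LinearSVM(nn.Module):
    """Linear support vector machine: h(x) = x · w^T + b."""
    def __init__(self):
        super().__init__()
        # weights for 2-D inputs and a scalar bias, both learned
        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)
        self.b = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        # raw score per sample; its sign is the predicted class
        h = x.matmul(self.w.t()) + self.b
        return h

# define the loss function
def hinge_loss(y_pred, y_true):
    # y_pred has shape (N, 1); transposing to (1, N) lets it broadcast against y_true of shape (N,)
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))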
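As a quick sanity check on the hinge loss, the same computation can be worked through by hand on four samples. This sketch uses NumPy rather than PyTorch purely to stay lightweight; the function name `hinge_loss_np` and the sample values are made up for illustration.

```python
import numpy as np


def hinge_loss_np(scores, labels):
    """Mean hinge loss, max(0, 1 - y * z), averaged over the samples."""
    return np.mean(np.maximum(0.0, 1.0 - scores * labels))


scores = np.array([2.0, 0.5, -1.0, -0.2])  # model outputs z = w . x + b
labels = np.array([1.0, 1.0, -1.0, 1.0])   # targets in {-1, +1}

per_sample = np.maximum(0.0, 1.0 - scores * labels)
loss = hinge_loss_np(scores, labels)

# Sample 1: correct and outside the margin      -> 0.0
# Sample 2: correct but inside the margin       -> 0.5
# Sample 3: correct and exactly on the margin   -> 0.0
# Sample 4: misclassified                       -> 1.2
print(per_sample, loss)  # mean is (0.0 + 0.5 + 0.0 + 1.2) / 4 = 0.425
```

Only samples that are misclassified or fall inside the margin contribute to the loss, which is why the solution ends up determined by the points nearest the boundary.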

Now we’ll create and normalize the synthetic data that the hyperplane will separate, and build a grid over the input space for plotting the decision regions.

# create the data
import numpy as np
from sklearn.datasets import make_blobs
X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
# normalize the data and map the labels from {0, 1} to {-1, +1}
X = (X - X.mean()) / X.std()
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)
# build a grid covering the data range, used to draw the decision regions
delta = 0.01
x = np.arange(X[:, 0].min().item(), X[:, 0].max().item(), delta)
y = np.arange(X[:, 1].min().item(), X[:, 1].max().item(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))

Now we’ll define the callbacks and visualization function. 

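# visualization function
from torchbearer import callbacks
%matplotlib notebook
import matplotlib
import matplotlib.pyplot as plt

# on_step_training turns the function into a callback that runs after each
# training step; only_if restricts it to every 10th batch
@callbacks.on_step_training
@callbacks.only_if(lambda state: state[torchbearer.BATCH] % 10 == 0)
def draw_margin(state):
    w = state[torchbearer.MODEL].w[0].detach().to('cpu').numpy()
    b = state[torchbearer.MODEL].b[0].detach().to('cpu').numpy()
    # score every grid point with the current hyperplane
    z = (w.dot(xy) + b).reshape(x.shape)
    # bucket the scores into four bands: outside each margin and on either side of the boundary
    z[np.where(z > 1.)] = 4
    z[np.where((z > 0.) & (z <= 1.))] = 3
    z[np.where((z > -1.) & (z <= 0.))] = 2
    z[np.where(z <= -1.)] = 1
    # clear the previous frame before redrawing
    plt.clf()
    plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
    plt.contourf(x, y, z, alpha=0.5)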

We’d like to use a soft-margin SVM because we don’t know if our data is linearly separable. To accomplish this, we can use the L2WeightDecay callback in torchbearer, which adds the weight decay term to the loss. Because the hinge loss is not differentiable everywhere, and because we only use a mini-batch at each step to approximate the gradient over all of the data, this process is a form of stochastic subgradient descent. Now let’s train the model and see the result.

# train the model
from torchbearer import Trial
from torchbearer.callbacks import L2WeightDecay, ExponentialLR
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fig = plt.figure(figsize=(5, 5))
svm = LinearSVM()
model = Trial(svm, optim.SGD(svm.parameters(), 0.1), hinge_loss, ['loss'],
              callbacks=[draw_margin, ExponentialLR(0.999, step_on_batch=True), L2WeightDecay(0.01, params=[svm.w])]).to(device)
# the number of epochs is an assumed value; adjust it as needed
model.with_train_data(X, Y, batch_size=32).run(epochs=50, verbose=1)
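To see what this configuration amounts to, the sketch below performs the same kind of update by hand in NumPy: a mini-batch subgradient step on the hinge loss, an L2 penalty on the weights, and a learning rate multiplied by 0.999 after every batch. It is an illustration under stated assumptions rather than a reproduction of torchbearer's internals: the penalty is taken as λ‖w‖² from the standard objective, and the toy data, the 20 epochs, and the function name `hinge_subgradient_step` are made up for the example.

```python
import numpy as np


def hinge_subgradient_step(w, b, xb, yb, lr, lam):
    """One mini-batch subgradient step on mean hinge loss + lam * ||w||^2."""
    margins = yb * (xb @ w + b)
    active = margins < 1.0  # samples inside the margin or misclassified
    # Subgradient of the mean hinge loss: -y_i * x_i for active samples, 0 otherwise.
    grad_w = -(yb[active][:, None] * xb[active]).sum(axis=0) / len(yb)
    grad_w = grad_w + 2.0 * lam * w  # gradient of the L2 penalty
    grad_b = -yb[active].sum() / len(yb)
    return w - lr * grad_w, b - lr * grad_b


rng = np.random.default_rng(0)
# Two well-separated 2-D clusters with labels -1 and +1 (toy data).
x = np.vstack([rng.normal(-2.0, 0.5, size=(64, 2)),
               rng.normal(2.0, 0.5, size=(64, 2))])
y = np.concatenate([-np.ones(64), np.ones(64)])

w, b, lr = np.zeros(2), 0.0, 0.1
for epoch in range(20):
    order = rng.permutation(len(x))
    for start in range(0, len(x), 32):
        idx = order[start:start + 32]
        w, b = hinge_subgradient_step(w, b, x[idx], y[idx], lr, lam=0.01)
        lr *= 0.999  # per-batch exponential decay of the learning rate

accuracy = np.mean(np.sign(x @ w + b) == y)
print(w, b, lr, accuracy)
```

Once every sample in a batch sits outside the margin, the hinge term contributes nothing and only the weight decay term moves w, which is what keeps the weights from growing without bound.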

The plot result can be saved in the working directory as below.

#fig will be saved in the working directory
fig.savefig('svm.png', bbox_inches='tight')

The resulting plot shows the SVM’s decision boundary and margins separating the synthetic data above.

Final words

In this article, we have discussed Torchbearer, a library that simplifies the training process and supports differentiable programming with PyTorch models. From the practical implementation above, we have seen how easily a model can be trained using this framework. A comprehensive set of built-in callbacks (such as logging, weight decay, and model checkpointing) and a strong metric API are two of torchbearer’s key features.


Vijaysinh Lendave
Vijaysinh is an enthusiast in machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, model building.
