How to use Torchbearer for fitting ML models with PyTorch?

Torchbearer is a Python-based model fitting library for PyTorch that offers a high-level metric and callback API for use in a variety of applications.

For TensorFlow-based models, we have Keras, which exposes the complex functionality of TensorFlow preprocessing, modelling, and callback management through a simple, high-level API. In the same space for PyTorch, there is a library called Torchbearer, which is essentially a model fitting library for PyTorch models and offers a high-level metric and callback API that can be used in a variety of applications. In this article, we discuss this library in detail with a hands-on implementation. We will also see how this framework helps in training and visualizing fitted machine learning models. Below are the major points that we are going to cover in this article.

Table of contents

  1. Need for Torchbearer
  2. Design of Torchbearer
  3. Training and visualizing an SVM

Let’s first understand the need for this library. 

Need for Torchbearer

Deep learning’s meteoric rise has spawned a slew of frameworks that enable hardware-accelerated tensor processing and automatic differentiation. The field is now often characterized more broadly as differentiable programming. Fitting involves optimizing the parameters of a differentiable program using gradient descent. PyTorch is one library that has grown in popularity in recent years, thanks to the ease with which it lets researchers create models that execute non-standard tensor operations.

This makes PyTorch particularly suitable for research projects in which any aspect of a model definition may need to be changed or modified. However, because PyTorch focuses solely on tensor processing and automatic differentiation, it lacks the high-level model fitting API found in frameworks like Keras. Furthermore, such libraries rarely support differentiable programming in its broadest sense.

This is the motivation for torchbearer, a Python library that supports research by speeding up the model fitting process while maintaining transparency and generality. A model fitting API is at the heart of torchbearer, allowing easy customization of all aspects of the fitting process. It also offers a robust metric API that can collect running statistics and averages. Finally, torchbearer ships with a number of advanced callbacks.

Design of Torchbearer

The torchbearer library is written in Python and builds on PyTorch, torchvision, and tqdm, with some functionality provided by NumPy, scikit-learn, and TensorBoard. Torchbearer’s core abstractions are trials, callbacks, and metrics.

These design concepts distinguish Torchbearer from similar libraries such as ignite and tnt. Neither library, for example, has a large set of built-in callbacks for sophisticated functionality. Furthermore, both ignite and tnt use an event-driven API for model fitting, which makes the code less clear and readable.

Let’s briefly discuss its major APIs.

Trial API

The Trial class implements a PyTorch model fitting interface based on the Trial.run(…) method. There are also predict(…) and evaluate(…) methods for running inference with and evaluating fitted models. The Trial class also has a state_dict method that returns a dictionary comprising the model parameters, optimizer state, and callback states, which can be stored and later reloaded using load_state_dict.
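To make this concrete, below is a minimal, self-contained sketch of the Trial workflow. The toy linear model, random tensors, and the ‘trial.pt’ file name are illustrative choices, not part of the library or of the example later in this article.

# a hedged sketch of the Trial workflow with a toy regression model
import torch
import torch.nn as nn
import torch.optim as optim
from torchbearer import Trial

net = nn.Linear(10, 1)                             # placeholder model
x, y = torch.randn(256, 10), torch.randn(256, 1)   # placeholder data

trial = Trial(net, optim.SGD(net.parameters(), lr=0.01), nn.MSELoss(), metrics=['loss'])
trial.with_train_data(x, y, batch_size=32)
trial.with_test_data(x)

trial.run(epochs=5)      # fit the model
preds = trial.predict()  # inference on the test data

# persist and restore model parameters, optimizer state and callback states
torch.save(trial.state_dict(), 'trial.pt')
trial.load_state_dict(torch.load('trial.pt'))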

Callback API

The callback API defines classes whose methods are called at various points during the fitting procedure to perform a variety of tasks. Each callback is passed the mutable state dictionary, a torchbearer construct that holds the intermediate variables required by the Trial. This allows callbacks to change the nature of the fitting process in real time. Callbacks can also be implemented as decorated functions using the decorator API.
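For illustration, here is a hedged example of a decorated callback; the log_epoch function is a hypothetical name, not part of torchbearer.

# print the epoch number and current metric values at the end of each
# epoch by reading the mutable state dictionary
import torchbearer
from torchbearer import callbacks

@callbacks.on_end_epoch
def log_epoch(state):
    print('epoch:', state[torchbearer.EPOCH], state[torchbearer.METRICS])

# it is registered like any other callback, e.g.
# Trial(model, optimizer, loss, metrics=['loss'], callbacks=[log_epoch])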

Metric API

The metric API uses a tree to allow data to flow from one metric to a set of children. This enables the computation of aggregates such as the running mean or standard deviation. Assembling these data structures by hand can be tedious, so torchbearer includes a decorator API to make it easier. The default_for_key(…) decorator allows a metric to be referenced in the Trial definition with a string.
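As a sketch of how these decorators chain, assuming the pattern shown in the torchbearer docs (the mae metric here is an illustrative assumption, not a torchbearer built-in): wrap a plain function as a metric, attach mean and running-mean aggregators as children in the tree, and register it under a string key.

import torchbearer
from torchbearer import metrics

@metrics.default_for_key('mae')   # lets a Trial request it as 'mae'
@metrics.running_mean             # child node: running mean across batches
@metrics.mean                     # child node: epoch-level mean
@metrics.lambda_metric('mae')     # build a metric from a plain function
def mae(y_pred, y_true):
    return (y_pred - y_true).abs()

# a Trial can now refer to it by string, e.g. metrics=['loss', 'mae']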

Training and visualizing an SVM

In this section, we will implement an SVM in PyTorch and train, evaluate, and visualize its hyperplane using the Torchbearer library. This example is taken from the official Torchbearer repository.

SVM seeks the hyperplane with the greatest margin of separation between the data classes. For a soft-margin SVM, where x is our data, we minimize:

λ‖w‖² + (1/N) Σᵢ max(0, 1 − yᵢ(w·xᵢ + b))

This can be expressed as an optimization over our weights w and bias b, where we minimize the hinge loss while accounting for an L2 weight decay term.

Now before modelling this in PyTorch let’s first install and import the Torchbearer library.

 # install and import torchbearer
!pip install -q torchbearer
import torchbearer

After this let’s define the SVM and hinge loss function.

# define the SVM model
import torch
import torch.nn as nn
 
class LinearSVM(nn.Module):
    """Support Vector Machine"""
 
    def __init__(self):
        super(LinearSVM, self).__init__()
        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)
        self.b = nn.Parameter(torch.randn(1), requires_grad=True)
 
    def forward(self, x):
        h = x.matmul(self.w.t()) + self.b
        return h

# define the hinge loss function
def hinge_loss(y_pred, y_true):
    # mean over the batch of max(0, 1 - y * f(x))
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))

Now we’ll create and normalize the synthetic data that the hyperplane will separate, and build a meshgrid over the input space for plotting the decision boundary.

# create synthetic data
import numpy as np
from sklearn.datasets import make_blobs
 
X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
X = (X - X.mean()) / X.std()
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)
# build a meshgrid over the input space for plotting
delta = 0.01
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))

Now we’ll define the visualization function as a torchbearer callback.

# visualization function
 
from torchbearer import callbacks
 
%matplotlib notebook
import matplotlib
import matplotlib.pyplot as plt
 
# redraw the margin every 10 training steps
@callbacks.on_step_training
@callbacks.only_if(lambda state: state[torchbearer.BATCH] % 10 == 0)
def draw_margin(state):
    # pull the current weights and bias out of the mutable state
    w = state[torchbearer.MODEL].w[0].detach().to('cpu').numpy()
    b = state[torchbearer.MODEL].b[0].detach().to('cpu').numpy()
 
    # band the plane into four regions by distance from the margin
    z = (w.dot(xy) + b).reshape(x.shape)
    z[np.where(z > 1.)] = 4
    z[np.where((z > 0.) & (z <= 1.))] = 3
    z[np.where((z > -1.) & (z <= 0.))] = 2
    z[np.where(z <= -1.)] = 1
 
    plt.clf()
    plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
    plt.contourf(x, y, z, cmap=plt.cm.jet, alpha=0.5)
    fig.canvas.draw()  # 'fig' is created in the training cell below

We’d like to use a soft-margin SVM because we don’t know whether our data is linearly separable. To accomplish this, we can use torchbearer’s L2WeightDecay callback. Because each step uses only a mini-batch to approximate the gradient over all of the data, and the hinge loss is not differentiable everywhere, this whole process is known as stochastic subgradient descent. Now let’s train the model and see the result.

# train the model
from torchbearer import Trial
from torchbearer.callbacks import L2WeightDecay, ExponentialLR
 
import torch.optim as optim
 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
fig = plt.figure(figsize=(5, 5))
 
svm = LinearSVM()
model = Trial(svm, optim.SGD(svm.parameters(), 0.1), hinge_loss, ['loss'],
              callbacks=[draw_margin, ExponentialLR(0.999, step_on_batch=True), L2WeightDecay(0.01, params=[svm.w])]).to(device)
model.with_train_data(X, Y, batch_size=32)
model.run(epochs=50, verbose=1)

The resulting figure can be saved in the working directory as shown below.

#fig will be saved in the working directory
fig.savefig('svm.png', bbox_inches='tight')

And here are the SVM’s hyperplanes, which separate the above synthetic data cleanly.

Final words

Through this article, we have discussed Torchbearer, a library that simplifies the training process and supports differentiable programming with PyTorch models. The practical implementation above shows how seamlessly we can train a model using this framework. A complete set of built-in callbacks (such as logging, weight decay, and model checkpointing) and a strong metric API are two of torchbearer’s key features.
