For TensorFlow-based models, we have Keras, which exposes TensorFlow's preprocessing, modelling and callback-management functionality through a simple, high-level API. For PyTorch, a similar role is played by Torchbearer, a model-fitting library for PyTorch models that offers a high-level metric and callback API usable in a variety of applications. In this article, we discuss this library in detail with a hands-on implementation. We also see how it can be used to visualize a model while it is being fitted. Below are the major points that we are going to cover in this article.
Table of contents
- The need for Torchbearer
- Design of Torchbearer
- Training and visualizing an SVM
Let’s first understand the need for this library.
The need for Torchbearer
Deep learning's meteoric rise has spawned a slew of frameworks that provide hardware-accelerated tensor processing and automatic differentiation. More broadly, this style of work is increasingly described as differentiable programming. Fitting is the process of optimizing the parameters of a differentiable algorithm, typically with gradient descent. PyTorch is one library that has grown in popularity in recent years, thanks to how easily it lets users build models that perform non-standard tensor operations.
This makes PyTorch particularly suitable for research projects in which any aspect of a model's definition may need to change. However, because PyTorch focuses solely on tensor processing and automatic differentiation, it lacks the high-level model-fitting API found in frameworks such as Keras. Furthermore, such high-level libraries rarely support differentiable programming in its broadest sense.
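To see what a model-fitting API saves you from, here is a minimal sketch of the loop one typically writes by hand in plain PyTorch: batching, zeroing gradients, the backward pass, the optimizer step and metric bookkeeping. The toy linear-regression data, learning rate and epoch count are arbitrary choices for illustration only.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: a noise-free linear target with three features
x = torch.randn(256, 3)
true_w = torch.tensor([[1.5], [-2.0], [0.5]])
y = x @ true_w + 0.3

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

epoch_losses = []
for epoch in range(20):
    running = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)    # forward pass and criterion
        loss.backward()                  # automatic differentiation
        optimizer.step()                 # parameter update
        running += loss.item() * xb.size(0)
    # hand-rolled metric bookkeeping: sample-weighted mean loss for the epoch
    epoch_losses.append(running / len(loader.dataset))

print(f"first epoch loss: {epoch_losses[0]:.4f}, last epoch loss: {epoch_losses[-1]:.4f}")
```

Every line inside the two nested loops is generic boilerplate; a fitting library such as torchbearer wraps this kind of loop so that only the model, optimizer, criterion and data need to be supplied.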
Torchbearer was created to fill this gap: it is a Python library that supports research by speeding up the model-fitting process while maintaining transparency and generality. At its heart is a model-fitting API that allows easy customization of every aspect of the fitting process. It also offers a robust metric API for collecting rolling statistics and averages, and it ships with a number of advanced callbacks.
Design of Torchbearer
The torchbearer library is written in Python and builds on PyTorch, torchvision and tqdm, with some functionality provided by NumPy, scikit-learn and TensorBoard. Its core abstractions are trials, callbacks and metrics.
These design concepts set torchbearer apart from similar libraries such as ignite and tnt. Neither of those libraries, for example, ships a large number of built-in callbacks for sophisticated functionality. Furthermore, both ignite and tnt use an event-driven API for model fitting, which can make the resulting code harder for humans to read.
Let’s briefly discuss its major APIs.
The Trial class implements a PyTorch model-fitting interface built around the Trial.run(…) method. It also provides predict(…) and evaluate(…) methods for running inference and for evaluating models, including previously saved ones. In addition, Trial has a state_dict method that returns a dictionary containing the model parameters, the optimizer state and the callback states; this dictionary can be stored and later reloaded with load_state_dict.
The callback API defines classes that can perform a variety of tasks during the fitting procedure. Each callback receives the mutable state dictionary, a central torchbearer component that holds the intermediate variables the Trial needs. Because callbacks can modify this state, they can change the fitting process on the fly. Callbacks can also be written as plain functions using the decorator API.
The metric API uses a tree structure so that data can flow from one metric to a set of child metrics. This makes it possible to compute aggregates such as the running mean or standard deviation of a metric. Because assembling these trees by hand can be tedious, torchbearer includes a decorator API to simplify it; for example, the default_for_key(…) decorator lets the metric be referenced by a string in the Trial definition.
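Torchbearer's own metric classes and decorators are more elaborate, so the following is only a hypothetical, dependency-free sketch of the idea. A root metric computes one value per batch and pushes it to child aggregators (mean, standard deviation and a windowed running mean), while a small registry maps a string key to a metric factory, loosely mirroring the role of default_for_key(…). All class and function names here are invented for illustration.

```python
from collections import deque

# Hypothetical sketch of a metric tree; not torchbearer's actual classes.


class Mean:
    name = "mean"

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += value
        self.count += 1

    def value(self):
        return self.total / self.count


class Std:
    name = "std"

    def __init__(self):
        self.values = []

    def update(self, value):
        self.values.append(value)

    def value(self):
        m = sum(self.values) / len(self.values)
        return (sum((v - m) ** 2 for v in self.values) / len(self.values)) ** 0.5


class RunningMean:
    name = "running_mean"

    def __init__(self, window=3):
        self.buffer = deque(maxlen=window)  # keeps only the last `window` values

    def update(self, value):
        self.buffer.append(value)

    def value(self):
        return sum(self.buffer) / len(self.buffer)


class MetricTree:
    def __init__(self, key, fn, children):
        self.key, self.fn, self.children = key, fn, children

    def process(self, y_pred, y_true):
        value = self.fn(y_pred, y_true)  # root: one value per batch
        for child in self.children:      # the value flows down to every child
            child.update(value)
        return value

    def results(self):
        return {f"{self.key}_{c.name}": c.value() for c in self.children}


# String key -> factory, standing in for what default_for_key(...) enables
METRIC_REGISTRY = {}


def register_metric(key):
    def decorator(factory):
        METRIC_REGISTRY[key] = factory
        return factory
    return decorator


@register_metric("acc")
def accuracy_tree():
    def accuracy(y_pred, y_true):
        return sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)
    return MetricTree("acc", accuracy, [Mean(), Std(), RunningMean(window=3)])


# A trial could resolve the string "acc" to a fresh metric tree like this
acc = METRIC_REGISTRY["acc"]()
batches = [([1, 0, 1, 1], [1, 0, 0, 1]), ([0, 0], [0, 0]),
           ([1, 0], [0, 0]), ([1, 1], [1, 1])]
for y_pred, y_true in batches:
    acc.process(y_pred, y_true)
results = acc.results()
print(results)
```

For the four batch accuracies 0.75, 1.0, 0.5 and 1.0, the tree reports a mean of 0.8125 and a running mean over the last three batches of about 0.833.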
Training and visualizing an SVM
In this section, we implement a linear SVM in PyTorch and use the Torchbearer library to train it, evaluate it and visualize its hyperplane. This example is taken from the official Torchbearer repository.
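SVM seeks the hyperplane with the greatest margin of separation between the data classes. For a soft-margin SVM, where x_i are the data points and y_i ∈ {-1, +1} their labels, we minimize:

$$\left[\frac{1}{n}\sum_{i=1}^{n}\max\left(0,\ 1 - y_i(\mathbf{w}\cdot\mathbf{x}_i + b)\right)\right] + \lambda\lVert\mathbf{w}\rVert^2$$

This can be expressed as an optimization over the weights w and bias b, in which we minimize the hinge loss together with an L2 weight-decay term λ‖w‖². For model outputs z = w·x + b with targets y, the hinge loss is:

$$\ell(y, z) = \max\left(0,\ 1 - yz\right)$$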
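The soft-margin objective, the mean hinge loss max(0, 1 - y_i(w·x_i + b)) plus λ‖w‖², can be checked by hand on a few points. The three points, the weights and λ = 0.01 (the same rate later passed to L2WeightDecay) are hypothetical choices that cover the three cases: a point outside the margin, one exactly on it, and one inside it.

```python
import numpy as np


def soft_margin_objective(w, b, X, y, lam):
    """Mean hinge loss of z = X w + b plus the L2 penalty lam * ||w||^2."""
    z = X @ w + b                          # model outputs z_i = w . x_i + b
    hinge = np.maximum(0.0, 1.0 - y * z)   # per-sample hinge loss max(0, 1 - y_i z_i)
    return hinge.mean() + lam * np.dot(w, w), hinge


# Hypothetical hand-picked example
w = np.array([1.0, -1.0])
b = 0.0
X = np.array([[2.0, 0.0],    # y*z = 2.0: outside the margin, no loss
              [0.0, 1.0],    # y*z = 1.0: exactly on the margin, no loss
              [0.5, 0.0]])   # y*z = 0.5: inside the margin, penalized
y = np.array([1.0, -1.0, 1.0])

objective, hinge = soft_margin_objective(w, b, X, y, lam=0.01)
print(hinge, objective)
```

Only the third point contributes to the hinge term (0.5 / 3), and the penalty adds 0.01 × ‖w‖² = 0.02, so the objective is about 0.187. This is the same quantity the hinge_loss function below computes per mini-batch, before weight decay is added.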
Before modelling this in PyTorch, let's install and import the Torchbearer library.
```python
# install and import torchbearer
!pip install -q torchbearer
import torchbearer
```
Next, let's define the SVM and the hinge loss function.
```python
# define SVM
import torch
import torch.nn as nn

class LinearSVM(nn.Module):
    """Linear Support Vector Machine: h(x) = x w^T + b"""
    def __init__(self):
        super(LinearSVM, self).__init__()
        # weights for 2-D inputs and a scalar bias, randomly initialised
        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)
        self.b = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        h = x.matmul(self.w.t()) + self.b  # shape (batch, 1)
        return h

# define the loss function: mean hinge loss max(0, 1 - y * h)
def hinge_loss(y_pred, y_true):
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))
```
Now we'll create and normalize the synthetic two-class data that the hyperplane will separate, and build a grid of points over which the decision function will later be plotted.
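```python
# load data
import numpy as np
import torch
from sklearn.datasets import make_blobs

X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
# normalize the data (using the global mean and standard deviation)
X = (X - X.mean()) / X.std()
# relabel class 0 as -1 so that the targets are in {-1, +1}
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)

# build a grid over the data range for plotting the decision regions
delta = 0.01
x = np.arange(X[:, 0].min().item(), X[:, 0].max().item(), delta)
y = np.arange(X[:, 1].min().item(), X[:, 1].max().item(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))
```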
Next, we define the visualization function as a callback that redraws the hyperplane and margins every 10 training batches.
```python
# visualization function
from torchbearer import callbacks
%matplotlib notebook
import matplotlib
import matplotlib.pyplot as plt

# run after every training step, but only when the batch index is a multiple of 10
@callbacks.on_step_training
@callbacks.only_if(lambda state: state[torchbearer.BATCH] % 10 == 0)
def draw_margin(state):
    # current hyperplane parameters, copied to NumPy
    w = state[torchbearer.MODEL].w.detach().to('cpu').numpy()
    b = state[torchbearer.MODEL].b.detach().to('cpu').numpy()

    # decision function over the grid, bucketed into four regions:
    # 4: z > 1, 3: 0 < z <= 1, 2: -1 < z <= 0, 1: z <= -1
    z = (w.dot(xy) + b).reshape(x.shape)
    z[np.where(z > 1.)] = 4
    z[np.where((z > 0.) & (z <= 1.))] = 3
    z[np.where((z > -1.) & (z <= 0.))] = 2
    z[np.where(z <= -1.)] = 1

    plt.clf()
    plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
    plt.contourf(x, y, z, cmap=plt.cm.jet, alpha=0.5)
    fig.canvas.draw()  # `fig` is created in the training cell
```
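Since we don't know whether our data is linearly separable, we use a soft-margin SVM, that is, one in which not every point has to lie outside the margin. This amounts to adding a weight-decay term λ‖w‖² to the hinge loss, which torchbearer's L2WeightDecay callback takes care of; here it is applied to w only, with λ = 0.01. The ExponentialLR callback additionally multiplies the learning rate by 0.999 after every batch. Because each step estimates the gradient from a mini-batch of 32 samples rather than from all of the data, and because the hinge loss has a kink at yz = 1 where only a subgradient exists, the whole procedure is stochastic subgradient descent. Now let's train the model and see the result.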
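The name "weight decay" comes from the gradient of the penalty: the gradient of λ‖w‖² is 2λw, so a gradient step on the penalty alone shrinks w by a constant factor. The sketch below checks this with PyTorch autograd. It uses the textbook squared-norm penalty and an assumed learning rate of 0.1; it is not a reproduction of torchbearer's L2WeightDecay internals.

```python
import torch

torch.manual_seed(0)
lam, lr = 0.01, 0.1
w = torch.randn(1, 2, requires_grad=True)

# Textbook L2 penalty lam * ||w||^2 (torchbearer's exact formulation may differ)
penalty = lam * (w ** 2).sum()
penalty.backward()

# Its gradient is 2 * lam * w ...
grad_matches = torch.allclose(w.grad, 2 * lam * w.detach())

# ... so one plain SGD step on the penalty alone multiplies w by (1 - 2 * lr * lam)
with torch.no_grad():
    w_before = w.clone()
    w -= lr * w.grad
shrink_factor = 1 - 2 * lr * lam

print(grad_matches, shrink_factor)
```

In the full objective this shrinkage is combined with the hinge-loss subgradient, pulling w towards smaller norms and therefore wider margins.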
```python
# train the model
from torchbearer import Trial
from torchbearer.callbacks import L2WeightDecay, ExponentialLR
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'

fig = plt.figure(figsize=(5, 5))

svm = LinearSVM()
model = Trial(svm, optim.SGD(svm.parameters(), 0.1), hinge_loss, ['loss'],
              callbacks=[draw_margin,
                         # multiply the learning rate by 0.999 after every batch
                         ExponentialLR(0.999, step_on_batch=True),
                         # L2 penalty on w only, giving the soft margin
                         L2WeightDecay(0.01, params=[svm.w])]).to(device)
model.with_train_data(X, Y, batch_size=32)
model.run(epochs=50, verbose=1)
```
The final plot can be saved to the working directory as follows.
```python
# fig will be saved in the working directory
fig.savefig('svm.png', bbox_inches='tight')
```
The resulting plot shows the SVM's hyperplane and margins cleanly separating the synthetic data.
Through this article, we have discussed Torchbearer, a library that simplifies the training process and supports differentiable programming with PyTorch models. From the practical implementation above, we have seen how easily a model can be trained with this framework. A comprehensive set of built-in callbacks (such as logging, weight decay and model checkpointing) and a powerful metric API are two of torchbearer's key features.