Metric learning is the task of learning distance functions over objects. PyTorch Metric Learning (PML) is an open-source library that eases the tedious and time-consuming task of implementing various deep metric learning algorithms. It was introduced by Kevin Musgrave and Serge Belongie of Cornell Tech and Ser-Nam Lim of Facebook AI in August 2020 (research paper).
The flexible and modular design of the PML library makes it possible to implement various combinations of algorithms in existing code. Several algorithms can also be combined for a complete train/test workflow.
Modules of PyTorch Metric Learning
- Losses – classes to apply various loss functions
- Distances – classes that compute pairwise distances or similarities between input embeddings
- Reducers – specify ways to go from several loss values to a single loss value
- Regularizers – applied to weights and embeddings for regularization
- Miners
PML provides two types of mining functions:
- Subset Batch Miners
- Tuple Miners
- Samplers – extensions of the torch.utils.data.Sampler class; they determine how batches of samples should be formed
- Trainers – provide access to metric learning algorithms that require data augmentation, additional networks, etc., beyond just the loss or mining functions
- Testers – take a model and dataset as input and compute nearest-neighbour-based accuracy metrics. (Using testers requires installing the faiss package)
- Utils
- AccuracyCalculator class to calculate several accuracy metrics given query and reference embeddings (see the sketch after this list)
- Inference models: utils.inference comprises classes for finding matching pairs within a batch, or from a set of pairs
- Logging Presets – The logging_presets module provides hooks for logging data, early stopping during training, validating and saving models. It requires the record-keeper and tensorboard packages, which you can install as follows:
pip install record-keeper tensorboard
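As referenced in the Utils list above, AccuracyCalculator computes accuracy metrics from query and reference embeddings. Below is a minimal sketch, assuming query_embeddings, reference_embeddings and their label arrays have already been computed; the get_accuracy argument order shown here matches the v0.9.x documentation and may differ in later releases:

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

#compute precision@1 over the query set, using the reference set as the search index
calculator = AccuracyCalculator(include=("precision_at_1",), k=1)
accuracies = calculator.get_accuracy(query_embeddings,
                                     reference_embeddings,
                                     query_labels,
                                     reference_labels,
                                     embeddings_come_from_same_source=False)
print(accuracies["precision_at_1"])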
The following figure gives an overview of the main modules of the PML library:
Components of a loss function (images' source: research paper):
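To make the figure concrete, the snippet below sketches how the distance, reducer and regularizer modules plug into a loss function; the particular argument values are illustrative:

from pytorch_metric_learning import losses, distances, reducers, regularizers

#a triplet loss assembled from its components
loss_func = losses.TripletMarginLoss(margin=0.1,
                                     distance=distances.CosineSimilarity(),        #how embeddings are compared
                                     reducer=reducers.ThresholdReducer(high=0.3),  #how per-triplet losses become one value
                                     embedding_regularizer=regularizers.LpRegularizer())  #penalty on embedding norms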
Required PyTorch version for PyTorch Metric Learning
- pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
- pytorch-metric-learning < v0.9.90 has no specific version requirement, but was tested with torch >= 1.2
Practical implementation of PyTorch Metric Learning
Here’s a demonstration of the TrainWithClassifier trainer of PML on the CIFAR100 dataset. The code has been implemented in Google Colab with Python 3.7.10 and torch 1.8.0. A step-wise explanation of the code follows:
1. Install required packages
#Install PML
!pip install -q pytorch-metric-learning[with-hooks]
#Install record-keeper for logging information
!pip install record_keeper
2. Import required libraries
%matplotlib inline
from pytorch_metric_learning import losses, miners, samplers, trainers, testers
from pytorch_metric_learning.utils import common_functions
import pytorch_metric_learning.utils.logging_presets as logging_presets
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from PIL import Image
import logging
import matplotlib.pyplot as plt
import umap
from cycler import cycler
import record_keeper
import pytorch_metric_learning

logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s"%pytorch_metric_learning.__version__)
3. Define the model
class MLP(nn.Module): #multilayer perceptron model
    # sizes[0] is the dimension of input
    # sizes[-1] is the dimension of output
    def __init__(self, sizes, final_relu=False): #constructor method
        super().__init__() #refer to the base class
        layers = [] #list of model layers
        sizes = [int(x) for x in sizes] #number of neurons in each layer
        num = len(sizes) - 1 #number of layers
        #output layer
        final_relu_layer = num if final_relu else num - 1
        for i in range(len(sizes) - 1): #for each layer
            ip_size = sizes[i] #number of input features
            op_size = sizes[i + 1] #number of output features
            if i < final_relu_layer: #for each intermediate layer
                #apply ReLU activation function
                layers.append(nn.ReLU(inplace=False))
            #apply linear transformation
            layers.append(nn.Linear(ip_size, op_size))
        #a sequential container holding the modules of the network
        self.net = nn.Sequential(*layers)
        self.last_linear = self.net[-1] #output layer

    def forward(self, x): #forward propagation
        return self.net(x)
4. Specify the device on which tensors will be allocated
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5. Set the trunk model and replace its final classification (softmax) layer with an identity function
Here, we have used an 18-layer deep convolutional network (ResNet-18):
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features #number of inputs of the final linear layer
trunk.fc = common_functions.Identity() #replace the fully connected layer with an identity
trunk = torch.nn.DataParallel(trunk.to(device)) #torch.Tensor.to() performs device conversion
6. Set the embedder model. The output of the trunk model is fed to it as input, and it outputs 64-dimensional embeddings.
embedder = torch.nn.DataParallel(MLP([trunk_output_size, 64]).to(device))
7. The training set here has the first 50 classes of the CIFAR100 dataset. Define the classifier, which takes the embeddings as input and outputs a 50-dimensional vector.
classifier = torch.nn.DataParallel(MLP([64, 50])).to(device)
8. Initialize optimizers
We have used the Adam optimization algorithm:
#optimize trunk model; 'lr' denotes learning rate
trunk_optimizer = torch.optim.Adam(trunk.parameters(), lr=0.00001, weight_decay=0.0001)
#optimize embedder
embedder_optimizer = torch.optim.Adam(embedder.parameters(), lr=0.0001, weight_decay=0.0001)
#optimize classifier
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0001, weight_decay=0.0001)
9. Set the image transforms
#For training data
#transforms.Compose chains many image transforms together
train_transform = transforms.Compose([transforms.Resize(64),
                      #crop the image to the specified aspect ratio and size
                      transforms.RandomResizedCrop(scale=(0.16, 1), ratio=(0.75, 1.33), size=64),
                      #randomly flip the images horizontally with 50% probability
                      transforms.RandomHorizontalFlip(0.5),
                      #convert the images to tensors
                      transforms.ToTensor(),
                      #normalize the tensor image with the specified mean and standard deviation
                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

#Repeat a similar process for validation data
val_transform = transforms.Compose([transforms.Resize(64),
                      transforms.ToTensor(),
                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
10. Download the original training and validation datasets
orig_train = datasets.CIFAR100(root="CIFAR100_Dataset", train=True, transform=None, download=True)
orig_val = datasets.CIFAR100(root="CIFAR100_Dataset", train=False, transform=None, download=True)
11. Create training and validation sets that are class-disjoint
class ClassDisjointCIFAR100(torch.utils.data.Dataset):
    def __init__(self, orig_train, orig_val, train, transform):
        #Rule to choose samples belonging to the first 50 output categories for training
        #and the last 50 for validation
        rule = (lambda x: x < 50) if train else (lambda x: x >= 50)
        #indices of suitable records in the original splits
        train_new = [i for i, x in enumerate(orig_train.targets) if rule(x)]
        val_new = [i for i, x in enumerate(orig_val.targets) if rule(x)]
        #add the selected data and corresponding labels
        self.data = np.concatenate([orig_train.data[train_new], orig_val.data[val_new]], axis=0)
        self.targets = np.concatenate([np.array(orig_train.targets)[train_new],
                                       np.array(orig_val.targets)[val_new]], axis=0)
        self.transform = transform

    def __len__(self): #length of the combined data
        return len(self.data)

    def __getitem__(self, index): #extract an image and its corresponding label
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img) #create an image memory from an array
        if self.transform is not None:
            img = self.transform(img) #perform image transformation
        return img, target
12. Initialize class disjoint training and validation set
train_set = ClassDisjointCIFAR100(orig_train, orig_val, True, train_transform)
val_set = ClassDisjointCIFAR100(orig_train, orig_val, False, val_transform)
#check that the class sets are disjoint using the assert keyword
assert set(train_set.targets).isdisjoint(set(val_set.targets))
13. Initialize the loss function
#Metric loss
loss = losses.TripletMarginLoss(margin=0.1)
#Classification loss
classification_loss = torch.nn.CrossEntropyLoss()
14. Initialize the mining function
miner = miners.MultiSimilarityMiner(epsilon=0.1)
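The trainer calls the miner and the loss internally, but the pattern is easy to sketch. Assuming embeddings is a (batch_size, embedding_dim) tensor produced by the embedder and labels holds the corresponding class labels, a tuple miner selects informative pairs/triplets for the loss to operate on:

#done internally by the trainer; shown here for illustration
hard_tuples = miner(embeddings, labels) #mine informative pairs/triplets from the batch
loss_value = loss(embeddings, labels, hard_tuples) #compute the metric loss only on the mined tuples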
Set the data loader sampler; if not specified, random sampling is used
sampler = samplers.MPerClassSampler(train_set.targets, m=4, length_before_new_iter=len(train_set)) #each batch contains m=4 samples per class
Set other training parameters
batch_size = 32
epochs = 4 #number of epochs
15. Form dictionaries of the above-defined models, optimizers, loss functions and mining functions
models = {"trunk": trunk, "embedder": embedder, "classifier": classifier} opt = {"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer, "classifier_optimizer": classifier_optimizer} loss_f = {"metric_loss": loss, "classifier_loss": classification_loss} mining_f = {"tuple_miner": miner} # Specify loss weights loss_wts = {"metric_loss": 1, "classifier_loss": 0.5} #a dictionary mapping loss names to numbers
16. Create training and testing hooks using logging_presets module
record_keeper, _, _ = logging_presets.get_record_keeper("logs", "tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)
dataset_dictionary = {"validation set": val_set}
model_folder = "saved_models"
17. Define a function for the visualizer hook
def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    #logging.info() logs the message with level INFO on the logger
    logging.info("UMAP plot for the {} split and label set {}".format(split_name, keyname))
    label_set = np.unique(labels) #unique labels
    num_classes = len(label_set) #number of output classes
    fig = plt.figure(figsize=(20, 15))
    #plt.gca() gets the current axes; set_prop_cycle() sets the colour cycle of the axes
    plt.gca().set_prop_cycle(cycler("color", [plt.cm.nipy_spectral(i) for i in np.linspace(0, 0.9, num_classes)]))
    for i in range(num_classes):
        #indices of records having the ith unique label
        index = labels == label_set[i]
        plt.plot(umap_embeddings[index, 0], umap_embeddings[index, 1], ".", markersize=1)
    plt.show()
UMAP (Uniform Manifold Approximation and Projection) is a dimension reduction technique that can be used for visualization.
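Outside the tester, UMAP can be applied to embeddings directly; a minimal sketch, assuming embeddings is an (n_samples, embedding_dim) NumPy array:

import umap

#reduce high-dimensional embeddings to 2D for plotting
embeddings_2d = umap.UMAP(n_neighbors=15, min_dist=0.1).fit_transform(embeddings)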
18. Create the tester
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook,
                                            visualizer=umap.UMAP(),
                                            visualizer_hook=visualizer_hook,
                                            dataloader_num_workers=32)
testers.GlobalEmbeddingSpaceTester() finds nearest neighbours by considering all the points in the embedding space.
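The end-of-epoch hook defined in the next step invokes the tester automatically, but it can also be called directly; a minimal sketch, with the argument order as in the v0.9.x documentation:

#evaluate the trunk + embedder on every dataset in dataset_dictionary
#the second argument is the epoch number used to tag the logged results
all_accuracies = tester.test(dataset_dictionary, 1, trunk, embedder)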
19. Initialize the hook for end of epoch. It performs operations such as logging data at the end of every epoch.
end_of_epoch = hooks.end_of_epoch_hook(tester, dataset_dictionary, model_folder, test_interval=1, patience=1)
20. Model trainer
Since we have a trunk model -> embedder model -> classifier architecture, we have used the TrainWithClassifier trainer. It applies a metric loss to the output of the embedder network and a classification loss to the output of the classifier network.
trainer = trainers.TrainWithClassifier(models,
                                       opt, #optimizers
                                       batch_size,
                                       loss_f, #loss functions
                                       mining_f, #mining functions
                                       train_set,
                                       sampler=sampler,
                                       dataloader_num_workers=32,
                                       loss_weights=loss_wts,
                                       end_of_iteration_hook=hooks.end_of_iteration_hook,
                                       end_of_epoch_hook=end_of_epoch)
21. Model training
trainer.train(num_epochs=epochs)
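While training runs, the hooks record logs in the "logs" and "tensorboard" folders passed to get_record_keeper() in step 16 and save the best models to model_folder. Assuming those folder names, the training curves can be inspected with:

tensorboard --logdir tensorboard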
Sample output plots for two epochs:
- Code source
- Google Colab notebook of the above implementation
References
For a detailed understanding of the PML library, refer to the following sources: