Guide To PyTorch Metric Learning: A Library For Implementing Metric Learning Algorithms


Metric Learning is defined as learning distance functions over multiple objects. PyTorch Metric Learning (PML) is an open-source library that eases the tedious and time-consuming task of implementing various deep metric learning algorithms. It was introduced by Kevin Musgrave and Serge Belongie of Cornell Tech and Ser-Nam Lim of Facebook AI in August 2020 (research paper).

The flexible and modular design of the PML library makes it possible to plug various combinations of algorithms into existing code. Several algorithms can also be combined into a complete train/test workflow. 

Modules of PyTorch Metric Learning

  1. Losses – classes to apply various loss functions
  2. Distances – include classes that compute pairwise distances or similarities between input embeddings
  3. Reducers – specify ways to go from several loss values to a single loss value
  4. Regularizers – applied to weights and embeddings for regularization.
  5. Miners

PML provides two types of mining function:


  • Subset Batch Miners
  • Tuple Miners
  6. Samplers – extensions of the torch.utils.data.Sampler class that determine how batches of samples are formed.
  7. Trainers – provide access to metric learning algorithms that require data augmentation, additional networks, etc., beyond the loss or mining functions.
  8. Testers – take a model and a dataset as input and compute nearest-neighbour-based accuracy metrics. (Using testers requires the faiss package.)
  9. Utils 
  • AccuracyCalculator class to calculate several accuracy metrics given a query and reference embeddings
  • Inference models: utils.inference comprises classes for finding matching pairs within a batch, or from a set of pairs
  • Logging Presets – The logging_presets module provides hooks for logging data, early stopping during training, validation and saving models. It requires the record-keeper and tensorboard packages, which you can install as follows:

pip install record-keeper tensorboard

The following figure gives an overview of the main modules of the PML library:

[Figure: PyTorch Metric Learning modules]

Components of a loss function:

[Figure: PyTorch Metric Learning loss function components]

Image source: research paper
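How a distance, a loss and a reducer fit together can be sketched in plain Python. This is only a conceptual illustration, not PML's implementation: the function names below are hypothetical, Euclidean distance is assumed, and the reducer shown (averaging only the non-zero losses) is one possible choice.

```python
import math

def euclidean(a, b):
    """Distance component: L2 distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_margin_losses(embeddings, labels, margin=0.1):
    """Loss component: one value per (anchor, positive, negative) triplet."""
    n = len(embeddings)
    values = []
    for a in range(n):
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # a positive must be a different sample of the same class
            for neg in range(n):
                if labels[neg] == labels[a]:
                    continue  # a negative must come from another class
                d_ap = euclidean(embeddings[a], embeddings[p])
                d_an = euclidean(embeddings[a], embeddings[neg])
                values.append(max(0.0, d_ap - d_an + margin))
    return values

def mean_of_nonzero(values):
    """Reducer component: average only the triplets that still violate the margin."""
    active = [v for v in values if v > 0]
    return sum(active) / len(active) if active else 0.0

# Toy 2-D embeddings for two classes (made-up numbers)
embeddings = [[0.0, 0.0], [0.0, 1.0], [0.5, 0.0], [3.0, 3.0]]
labels = [0, 0, 1, 1]
per_triplet = triplet_margin_losses(embeddings, labels)
loss_value = mean_of_nonzero(per_triplet)
print(len(per_triplet), "triplets, reduced loss:", loss_value)
```

Swapping `euclidean` for another distance, or `mean_of_nonzero` for a plain mean, changes the behaviour without touching the loss itself, which is the kind of modularity the library aims for.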

Required PyTorch version for PyTorch Metric Learning

  • pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
  • pytorch-metric-learning < v0.9.90 has no specific version requirement, but was tested with torch >= 1.2

Practical implementation of PyTorch Metric Learning

Here’s a demonstration of using the TrainWithClassifier trainer of PML on the CIFAR100 dataset. The code was run in Google Colab with Python 3.7.10 and torch 1.8.0. A step-wise explanation of the code follows:

  1. Install required packages
 #Install PML
 !pip install -q pytorch-metric-learning[with-hooks]
 #Install record keeper for logging information
 !pip install record_keeper 
  2. Import required libraries
 %matplotlib inline
 from pytorch_metric_learning import losses, miners, samplers, trainers, testers
 from pytorch_metric_learning.utils import common_functions
 import pytorch_metric_learning.utils.logging_presets as logging_presets
 import numpy as np
 import torchvision
 from torchvision import datasets, transforms
 import torch
 import torch.nn as nn
 from PIL import Image
 import logging
 import matplotlib.pyplot as plt
 import umap
 from cycler import cycler
 import record_keeper
 import pytorch_metric_learning
 logging.getLogger().setLevel(logging.INFO)
 logging.info("VERSION %s"%pytorch_metric_learning.__version__)

3. Define the model

 class MLP(nn.Module): #multilayer perceptron model
     # sizes[0] is the dimension of input
     # sizes[-1] is the dimension of output
     def __init__(self, sizes, final_relu=False):  #constructor method
         super().__init__()  #initialize the base class
         layers = []  #list of model layers
         sizes = [int(x) for x in sizes]  #number of neurons in each layer
         num = len(sizes) - 1 #number of layers
         #layers with an index below this value are preceded by a ReLU
         final_relu_layer = num if final_relu else num - 1
         for i in range(len(sizes) - 1):  #for each layer
             ip_size = sizes[i]  #number of input features
             op_size = sizes[i + 1] #number of output features
             if i < final_relu_layer:  #for each intermediate layer
                 #apply ReLU activation function
                 layers.append(nn.ReLU(inplace=False))
             #apply linear transformation
             layers.append(nn.Linear(ip_size, op_size))
         #a sequential container holding the layers
         self.net = nn.Sequential(*layers)
         self.last_linear = self.net[-1]  #output layer
     def forward(self, x):  #forward propagation
         return self.net(x)

4. Specify device on which torch.Tensor will be allocated

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

5. Set trunk model and replace its final fully connected layer with an identity function

Here, we have used an 18-layer deep convolutional network (ResNet18)

 trunk = torchvision.models.resnet18(pretrained=True)
 trunk_output_size = trunk.fc.in_features #number of inputs for linear layer
 trunk.fc = common_functions.Identity() #replace the fully connected layer
 trunk = torch.nn.DataParallel(trunk.to(device)) #.to() performs device conversion

6. Set embedder model. Output of the trunk model is fed to it as input and it outputs 64 dimensional embeddings.

emb = torch.nn.DataParallel(MLP([trunk_output_size, 64]).to(device))

7. Training set here has the first 50 classes of the CIFAR100 dataset. Define the classifier which will take the embeddings as input and output a 50 dimensional vector.

classifier = torch.nn.DataParallel(MLP([64, 50])).to(device)

8. Initialize optimizers

We have used the Adam optimization algorithm

 #optimize trunk model ('lr' denotes learning rate)
 trunk_opt = torch.optim.Adam(trunk.parameters(), lr=0.00001, weight_decay=0.0001)
 #optimize embedder 
 embedder_opt = torch.optim.Adam(emb.parameters(), lr=0.0001, weight_decay=0.0001)
 #optimize classifier
 classifier_opt = torch.optim.Adam(classifier.parameters(), lr=0.0001, weight_decay=0.0001) 

9. Set the image transforms

 #For training data
 #compose several image transforms together
 train_trf = transforms.Compose([transforms.Resize(64),
                 #crop a random region of the given scale and aspect ratio, then resize it
                 #(the upper ratio bound 1.33 and size=64 are assumed values)
                 transforms.RandomResizedCrop(scale=(0.16, 1), ratio=(0.75, 1.33), size=64),
                 #Randomly flip the images horizontally with 50% probability
                 transforms.RandomHorizontalFlip(0.5),
                 #Convert the images to tensors
                 transforms.ToTensor(),
                 #Normalize the tensor image with specified mean and standard deviation
                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,
                 0.224, 0.225])])
 #Repeat similar process for validation data
 val_transform = transforms.Compose([transforms.Resize(64),
                           transforms.ToTensor(),
                           transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])]) 
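What the normalization step does to a single pixel can be worked out by hand. The sketch below is plain Python rather than torchvision: `normalize_pixel` is a hypothetical helper, and it assumes the pixel is first scaled from the 0–255 range to [0, 1], as tensor conversion in torchvision does for 8-bit images.

```python
# Per-channel statistics used in the transforms above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_0_255):
    """Scale an 8-bit RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    scaled = [c / 255 for c in rgb_0_255]
    return [(c - m) / s for c, m, s in zip(scaled, MEAN, STD)]

white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print("white:", white)
print("black:", black)
```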

10. Download the original training and validation datasets

orig_train = datasets.CIFAR100(root="CIFAR100_Dataset", train=True, transform=None, download=True)
orig_val = datasets.CIFAR100(root="CIFAR100_Dataset", train=False, transform=None, download=True) 

11. Create training and validation sets that are class-disjoint

 class ClassDisjointCIFAR100(torch.utils.data.Dataset):
     def __init__(self, orig_train, orig_val, train, transform):
         #rule: first 50 classes for training, remaining 50 for validation
         rule = (lambda x: x < 50) if train else (lambda x: x >=50)
         #update training and validation sets with above defined rule
         #index of suitable records
         train_new = [i for i,x in enumerate(orig_train.targets) if rule(x)]
         val_new = [i for i,x in enumerate(orig_val.targets) if rule(x)]
         #add the updated data and corresponding labels
         self.data = np.concatenate([orig_train.data[train_new], orig_val.data[val_new]], axis=0)
         self.targets = np.concatenate([np.array(orig_train.targets)[train_new],
                                        np.array(orig_val.targets)[val_new]], axis=0)
         self.transform = transform
     def __len__(self):  #function to get length of updated data
         return len(self.data)
     def __getitem__(self, index):
         #function to extract image and corresponding label
         img, target = self.data[index], self.targets[index]
         img = Image.fromarray(img)  #creates a PIL image from a NumPy array
         if self.transform is not None:
             img = self.transform(img) #perform image transformation
         return img, target 

12. Initialize class disjoint training and validation set

 train_set = ClassDisjointCIFAR100(orig_train, orig_val, True, train_trf)
 val_set = ClassDisjointCIFAR100(orig_train, orig_val, False, val_transform)
 #verify that the training and validation classes are disjoint
 assert set(train_set.targets).isdisjoint(set(val_set.targets)) 
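The class-disjoint rule is easy to check on a toy label list. The following is a minimal plain-Python sketch; `class_disjoint_indices` and the toy labels are hypothetical and only mirror the filtering idea, with a threshold of 2 classes instead of 50.

```python
def class_disjoint_indices(targets, num_train_classes, train):
    """Return indices of samples whose class falls in the requested split."""
    if train:
        keep = lambda y: y < num_train_classes
    else:
        keep = lambda y: y >= num_train_classes
    return [i for i, y in enumerate(targets) if keep(y)]

toy_targets = [0, 3, 1, 2, 3, 0, 2, 1]  # made-up labels for 4 classes
train_idx = class_disjoint_indices(toy_targets, 2, train=True)
val_idx = class_disjoint_indices(toy_targets, 2, train=False)

train_classes = {toy_targets[i] for i in train_idx}
val_classes = {toy_targets[i] for i in val_idx}
print(train_idx, val_idx, train_classes.isdisjoint(val_classes))
```

Because validation classes never appear during training, the accuracy measured on them reflects how well the embedding generalises to unseen categories.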

13.  Initialize the loss function

 loss = losses.TripletMarginLoss(margin=0.1)
 #Classification loss
 clf_loss = torch.nn.CrossEntropyLoss() 

14. Initialize the mining function

m = miners.MultiSimilarityMiner(epsilon=0.1)
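The idea behind multi-similarity mining can be sketched without the library. The code below is a simplified, hypothetical illustration (not PML's code) that assumes cosine similarity: for each anchor it keeps negatives that are nearly as similar as the least similar positive, and positives that are nearly as dissimilar as the most similar negative, with `epsilon` as the tolerance.

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def mine_hard_pairs(embeddings, labels, epsilon=0.1):
    """Return (hard positive pairs, hard negative pairs) as lists of (anchor, other)."""
    hard_pos, hard_neg = [], []
    n = len(embeddings)
    for a in range(n):
        pos = {j: cosine(embeddings[a], embeddings[j])
               for j in range(n) if j != a and labels[j] == labels[a]}
        neg = {j: cosine(embeddings[a], embeddings[j])
               for j in range(n) if labels[j] != labels[a]}
        if not pos or not neg:
            continue  # nothing to compare against for this anchor
        min_pos = min(pos.values())  # least similar positive
        max_neg = max(neg.values())  # most similar negative
        hard_pos += [(a, j) for j, s in pos.items() if s - epsilon < max_neg]
        hard_neg += [(a, j) for j, s in neg.items() if s + epsilon > min_pos]
    return hard_pos, hard_neg

# Toy embeddings (made-up numbers): sample 2 is a negative lying close to anchor 0
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [-1.0, 0.0]]
labels = [0, 0, 1, 1]
hard_pos, hard_neg = mine_hard_pairs(embeddings, labels)
print("hard positives:", hard_pos)
print("hard negatives:", hard_neg)
```

For anchor 0, the nearby negative (sample 2) is kept while the far-away negative (sample 3) is discarded, so the loss concentrates on pairs that are still informative.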

Set the data loader sampler; if not specified, random sampling is used

smpl = samplers.MPerClassSampler(train_set.targets, m=4, length_before_new_iter=len(train_set))
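To see why an m-per-class sampler matters for metric learning, consider this plain-Python sketch. It is an assumption-laden illustration rather than the library's sampler: `m_per_class_batches` is a hypothetical function that picks `batch_size // m` classes per batch and `m` samples from each, so every sample in the batch has at least one positive partner.

```python
import random
from collections import defaultdict

def m_per_class_batches(labels, m, batch_size, num_batches, seed=0):
    """Build batches that contain exactly m samples for each selected class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    classes_per_batch = batch_size // m
    batches = []
    for _ in range(num_batches):
        chosen = rng.sample(sorted(by_class), classes_per_batch)
        batch = []
        for c in chosen:
            pool = by_class[c]
            # draw without replacement when the class is large enough
            picks = rng.sample(pool, m) if len(pool) >= m else rng.choices(pool, k=m)
            batch.extend(picks)
        batches.append(batch)
    return batches

toy_labels = [i % 10 for i in range(200)]  # 10 classes, 20 samples each
batches = m_per_class_batches(toy_labels, m=4, batch_size=32, num_batches=3)
first_batch_labels = sorted(toy_labels[i] for i in batches[0])
print(first_batch_labels)
```

With purely random sampling over many classes, a batch of 32 would often contain classes represented by a single sample, leaving few positive pairs for the triplet loss to work with.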

Set other training parameters

 batch_size = 32
 epochs = 4  #number of epochs 

15. Form dictionaries of the models, optimizers, loss functions and mining functions defined above

 models = {"trunk": trunk, "embedder": emb, "classifier": classifier}
 opt = {"trunk_optimizer": trunk_opt, "embedder_optimizer": embedder_opt, "classifier_optimizer": classifier_opt}
 loss_f = {"metric_loss": loss, "classifier_loss": clf_loss}
 mining_f = {"tuple_miner": m}
 # Specify loss weights  
 loss_wts = {"metric_loss": 1, "classifier_loss": 0.5}
 #a dictionary mapping loss names to numbers 
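The effect of the loss weights can be illustrated with a small calculation. This sketch assumes the individual losses are combined as a weighted sum; `total_loss` and the loss values are hypothetical and chosen only for the arithmetic.

```python
def total_loss(loss_values, loss_weights):
    """Weighted sum of named loss values."""
    return sum(loss_weights[name] * value for name, value in loss_values.items())

loss_wts = {"metric_loss": 1, "classifier_loss": 0.5}
step_losses = {"metric_loss": 0.08, "classifier_loss": 3.2}  # made-up values
combined = total_loss(step_losses, loss_wts)
print(combined)  # 1 * 0.08 + 0.5 * 3.2
```

Halving the classifier weight keeps the (typically larger) cross-entropy term from dominating the gradient that reaches the embedder and trunk.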

16. Create training and testing hooks using logging_presets module

 record_keeper, _, _ = logging_presets.get_record_keeper("logs", "tensorboard")
 hooks = logging_presets.get_hook_container(record_keeper)
 dataset_dictionary = {"validation set": val_set} 
 model_folder = "saved_models"   

17. Define a function for visualiser hook

 def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
     #log the message with level INFO on the logger
     logging.info("UMAP plot for the {} split and label set {}".format(split_name, keyname))
     label_set = np.unique(labels)   #unique labels
     num_classes = len(label_set)  #number of output classes 
     fig = plt.figure(figsize=(20,15)) 
     #plt.gca() gets the current axes and set_prop_cycle() sets their property cycle
     #(the nipy_spectral colormap and markersize are assumed choices)
     plt.gca().set_prop_cycle(cycler("color", [plt.cm.nipy_spectral(i) for i
                                               in np.linspace(0, 0.9, num_classes)]))
     for i in range(num_classes):  
         #get index of records having ith unique label
         index = labels == label_set[i]  
         plt.plot(umap_embeddings[index, 0], umap_embeddings[index, 1], ".", markersize=1)
     plt.show()

UMAP (Uniform Manifold Approximation and Projection) is a dimension reduction technique that can be used for visualization.

18. Create the tester

 tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = hooks.end_of_testing_hook,
                                             visualizer = umap.UMAP(),
                                             visualizer_hook = visualizer_hook,
                                             dataloader_num_workers = 32)

testers.GlobalEmbeddingSpaceTester() finds nearest neighbours by considering all the points in the embedding space
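One of the simplest nearest-neighbour accuracy metrics is precision at 1: the fraction of queries whose closest neighbour shares their label. The sketch below is a brute-force plain-Python illustration under that definition; `precision_at_1` is a hypothetical helper, not the library's AccuracyCalculator, which relies on faiss for the neighbour search.

```python
def squared_distance(a, b):
    """Squared Euclidean distance (sufficient for ranking neighbours)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def precision_at_1(query_emb, query_labels, ref_emb, ref_labels, same_source=False):
    """Fraction of queries whose nearest reference point has the same label."""
    hits = 0
    for qi, q in enumerate(query_emb):
        best_j, best_d = None, float("inf")
        for rj, r in enumerate(ref_emb):
            if same_source and rj == qi:
                continue  # do not let a point match itself
            d = squared_distance(q, r)
            if d < best_d:
                best_j, best_d = rj, d
        hits += int(ref_labels[best_j] == query_labels[qi])
    return hits / len(query_emb)

# Toy embeddings (made-up): the last point sits inside the wrong cluster
emb = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.2, 0.1]]
lab = [0, 0, 1, 1, 1]
p_at_1 = precision_at_1(emb, lab, emb, lab, same_source=True)
print(p_at_1)
```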

19. Initialize hook for end of epoch. It performs some operation such as logging data at the end of every epoch.

 end_of_epoch = hooks.end_of_epoch_hook(tester, 
                                             dataset_dictionary,
                                             model_folder,
                                             test_interval = 1,
                                             patience = 1) 
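The role of a patience setting can be shown with a generic early-stopping loop. This is a simplified sketch of the general technique, not the hook's actual logic; `epochs_until_stop` is hypothetical, and the exact point at which the library stops (for example, whether it allows one extra epoch) may differ.

```python
def epochs_until_stop(val_accuracies, patience=1):
    """Return the epoch (1-based) at which training would stop."""
    best = float("-inf")
    epochs_without_gain = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:
            best = acc
            epochs_without_gain = 0  # improvement resets the counter
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                return epoch  # no improvement for `patience` tests in a row
    return len(val_accuracies)

history = [0.30, 0.35, 0.34, 0.40]  # made-up validation accuracies
stop_strict = epochs_until_stop(history, patience=1)
stop_lenient = epochs_until_stop(history, patience=2)
print(stop_strict, stop_lenient)
```

A small patience saves compute but can stop before a later improvement, as the fourth value in the toy history shows.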

20. Model trainer 

Since we have a trunk model -> embedder model -> classifier architecture, we have used the TrainWithClassifier trainer. It applies a metric loss to the output of the embedder network and a classification loss to the output of the classifier network.

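 trainer = trainers.TrainWithClassifier(models = models,
                                 optimizers = opt,  #optimizers
                                 batch_size = batch_size,
                                 loss_funcs = loss_f, #loss function
                                 mining_funcs = mining_f, #mining function
                                 dataset = train_set,
                                 sampler = smpl,
                                 dataloader_num_workers = 32,
                                 loss_weights = loss_wts,
                                 end_of_iteration_hook = hooks.end_of_iteration_hook,
                                 end_of_epoch_hook = end_of_epoch) 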

21. Model training

 trainer.train(num_epochs=epochs)


Sample output plots for two epochs:


For a detailed understanding of the PML library, refer to the following sources:


Nikita Shiledarbaxi
A zealous learner aspiring to advance in the domain of AI/ML. Eager to grasp emerging techniques to get insights from data and hence explore realistic Data Science applications as well.
