
Guide To learn2learn: A Library For Meta-Learning Research


learn2learn is a software library designed for meta-learning research. It was introduced in 2020 by Sebastien M. R. Arnold (University of Southern California), Praateek Mahajan (Iterable Inc.), Debajyoti Datta (University of Virginia), Ian Bunner (University of Waterloo) and Konstantinos Saitas Zarkias (KTH Royal Institute of Technology and Research Institutes of Sweden, RISE) (research paper).

Before going into the details of learn2learn, let us first understand what meta-learning means.

What is Meta-Learning?

Meta-learning is a sub-domain of machine learning. It deals with the systematic observation of how various ML approaches perform on different learning tasks. A model then learns from this meta-data so that it can learn new tasks more quickly than would otherwise be possible: it absorbs information from one task and generalizes efficiently to unseen tasks. Hence meta-learning is also known as ‘learning how to learn’. A detailed explanation of this learning paradigm can be found here.

Overview of learn2learn

learn2learn is built on top of the PyTorch library. It tackles two major issues faced by meta-learning researchers: prototyping and reproducibility. Since modern meta-learning methods rely on unconventional ML frameworks, researchers are prone to mistakes while prototyping new tasks and algorithms. As a result, it becomes difficult to reproduce existing results, a problem compounded by the lack of standardized implementations and benchmarks. The learn2learn library addresses these issues, enabling fast prototyping and correct reproducibility.

learn2learn provides low-level utilities and a unified interface for creating new algorithms and domains. It also provides standardized benchmarks and high-quality implementations of existing algorithms. In addition, it remains compatible with PyTorch-based libraries such as torchvision, torchaudio, torchtext and cherry.

Components of learn2learn

Practical implementation

Here’s a demonstration of few-shot learning using the MAML wrapper for fast adaptation. MAML (Model-Agnostic Meta-Learning) is model-agnostic in the sense that it is compatible with any model trained using gradient descent, and it is applicable to a wide range of tasks such as reinforcement learning, classification and regression (MAML research paper). The code uses the benchmark interface to load the mini-ImageNet dataset.
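To make the adapt-then-evaluate structure of MAML concrete before the full walkthrough, here is a minimal sketch on a toy problem that does not use learn2learn or PyTorch. Everything in it is an illustrative assumption rather than part of the library: the scalar model y = w·x, tasks defined by a target slope a, the mean-squared-error loss, and the helper names mean_sq, task_loss, adapt, meta_gradient and meta_train. Because the loss is quadratic, the second-order meta-gradient can be written by hand with the chain rule.

```python
# Toy MAML on scalar linear regression (illustrative assumption, not learn2learn code).
# Model: y = w * x. A task is (a, support_xs, query_xs) with targets y = a * x.

def mean_sq(xs):
    """Mean of x^2 over a list of inputs."""
    return sum(x * x for x in xs) / len(xs)

def task_loss(w, a, xs):
    """Mean squared error between w*x and a*x, which equals mean(x^2) * (w - a)^2."""
    return mean_sq(xs) * (w - a) ** 2

def adapt(w, a, support_xs, fast_lr):
    """Inner loop: one gradient step on the support (adaptation) loss."""
    grad = 2 * mean_sq(support_xs) * (w - a)
    return w - fast_lr * grad

def meta_gradient(w, a, support_xs, query_xs, fast_lr):
    """Gradient of the post-adaptation query loss with respect to the initial w."""
    w_adapted = adapt(w, a, support_xs, fast_lr)
    d_query = 2 * mean_sq(query_xs) * (w_adapted - a)      # dL_query / dw_adapted
    d_adapted = 1 - fast_lr * 2 * mean_sq(support_xs)      # dw_adapted / dw (second-order term)
    return d_query * d_adapted

def meta_train(tasks, w=0.0, fast_lr=0.1, meta_lr=0.05, iterations=300):
    """Outer loop: average the meta-gradients over tasks, then update the initial w."""
    for _ in range(iterations):
        grad = sum(meta_gradient(w, a, s, q, fast_lr) for a, s, q in tasks) / len(tasks)
        w -= meta_lr * grad
    return w

# Two tasks with slopes 1.0 and 3.0 that share the same inputs
tasks = [
    (1.0, [1.0, 2.0], [1.5, 0.5]),
    (3.0, [1.0, 2.0], [1.5, 0.5]),
]
w_meta = meta_train(tasks)
print("meta-learned initialization:", w_meta)
```

With two symmetric tasks, the meta-learned initialization settles between the two slopes, which is the starting point from which a single adaptation step helps both tasks equally. Dropping the d_adapted factor would correspond to a first-order approximation of the meta-gradient.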

The code has been implemented in Google Colab using Python 3.7.10 and learn2learn 0.1.5. A step-wise explanation of the code is as follows:

  1. Clone the learn2learn GitHub repository

!git clone https://github.com/learnables/learn2learn

  2. Install the learn2learn library using the pip command

!pip install learn2learn

  3. Import the required libraries

 import random
 import numpy as np
 import torch
 from torch import nn, optim
 import learn2learn as l2l
 from learn2learn.data.transforms import (NWays,
                                          KShots,
                                          LoadData,
                                          RemapLabels,
                                          ConsecutiveLabels)

learn2learn.data.transforms is a collection of general task transforms (objects which implement the callable interface). Each transform takes a task description and returns a new one. Initially, the task description contains all samples from the dataset under consideration; each task transform then modifies this list so that a particular task is created. Among the task transforms imported above:

  • NWays keeps samples from N randomly chosen output labels present in the task description
  • KShots keeps K samples for each output label
  • LoadData loads the sample with a given index from the dataset
  • RemapLabels maps the labels of samples taken from K classes to 0, …, K−1
  • ConsecutiveLabels reorders the samples in the task description so that samples with the same label appear consecutively
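The effect of chaining these transforms can be sketched in plain Python. This is not how learn2learn implements them; the function name sample_task and the toy label list are hypothetical, and the sketch only imitates what NWays, KShots, RemapLabels and ConsecutiveLabels achieve together when building one N-way K-shot task.

```python
import random
from collections import defaultdict

def sample_task(labels, ways, shots, rng):
    """Return (indices, remapped_labels) for one N-way K-shot task.

    Hypothetical helper that imitates the combined effect of the task transforms.
    """
    # Group sample indices by their original label
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # NWays-like step: keep `ways` randomly chosen classes
    chosen = rng.sample(sorted(by_class), ways)

    # RemapLabels-like step: map the chosen labels to 0, ..., ways - 1
    remap = {label: new_label for new_label, label in enumerate(chosen)}

    indices, new_labels = [], []
    for label in chosen:  # ConsecutiveLabels-like step: emit samples class by class
        # KShots-like step: keep `shots` random samples of this class
        for index in rng.sample(by_class[label], shots):
            indices.append(index)
            new_labels.append(remap[label])
    return indices, new_labels

# Toy dataset: 10 classes with 6 samples each (only labels are needed here)
labels = [c for c in range(10) for _ in range(6)]
task_indices, task_labels = sample_task(labels, ways=3, shots=2, rng=random.Random(0))
print(task_indices)
print(task_labels)
```

The returned indices play the role of the task description, and a LoadData-like step would then fetch the actual samples at those indices.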

  4. Define a function for computing model accuracy

 def accuracy(pred, targets):
     # Predicted class = index of the maximum score for each sample
     pred = pred.argmax(dim=1).view(targets.shape)
     # Fraction of predictions that match the actual target labels
     return (pred == targets).sum().float() / targets.size(0)

  5. Define a method for fast adaptation of the neural network

 def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
     data, labels = batch  # separate features and labels
     data, labels = data.to(device), labels.to(device)

     # Separate data into adaptation/evaluation sets:
     # start from an all-False mask and mark every even index for adaptation
     adapt_indices = np.zeros(data.size(0), dtype=bool)
     adapt_indices[np.arange(shots * ways) * 2] = True
     # the remaining (odd) indices form the evaluation set
     eval_indices = torch.from_numpy(~adapt_indices)
     adapt_indices = torch.from_numpy(adapt_indices)
     adapt_data, adapt_labels = data[adapt_indices], labels[adapt_indices]
     eval_data, eval_labels = data[eval_indices], labels[eval_indices]

     # Adapt the model
     for step in range(adaptation_steps):
         adapt_error = loss(learner(adapt_data), adapt_labels)
         # learner.adapt() takes a gradient step on the loss and updates
         # the cloned parameters
         learner.adapt(adapt_error)

     # Evaluate the adapted model
     predictions = learner(eval_data)
     eval_error = loss(predictions, eval_labels)  # evaluation error
     eval_accuracy = accuracy(predictions, eval_labels)  # evaluation accuracy
     return eval_error, eval_accuracy
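The even/odd split used inside fast_adapt() can be checked in isolation with plain Python lists. The sketch below rests on an assumption: each sampled task holds 2 × shots samples per class, ordered so that samples of the same class sit next to each other. The helper name split_indices is hypothetical and not part of learn2learn.

```python
def split_indices(ways, shots):
    """Split 2 * shots * ways sample positions into adaptation and evaluation sets.

    Even positions go to adaptation and odd positions to evaluation, matching
    the pattern produced by np.arange(shots * ways) * 2 in fast_adapt().
    """
    total = 2 * shots * ways
    adapt = [i for i in range(total) if i % 2 == 0]
    evaluate = [i for i in range(total) if i % 2 == 1]
    return adapt, evaluate

ways, shots = 2, 2
# Labels of one task, assuming classes are grouped consecutively
task_labels = [c for c in range(ways) for _ in range(2 * shots)]  # [0, 0, 0, 0, 1, 1, 1, 1]

adapt_idx, eval_idx = split_indices(ways, shots)
adapt_labels = [task_labels[i] for i in adapt_idx]
eval_labels = [task_labels[i] for i in eval_idx]
print(adapt_idx, adapt_labels)
print(eval_idx, eval_labels)
```

Under that ordering assumption, both halves contain exactly `shots` samples of every class, so the adaptation set is a proper N-way K-shot problem and the evaluation set is disjoint from it.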

  6. Define the main() method, which marks the point where code execution begins

 def main(
         ways=5,  # number of classes per task
         shots=5,  # number of labeled examples per class
         meta_lr=0.003,  # learning rate for the meta-update
         fast_lr=0.5,  # learning rate for fast adaptation
         batch_size=32,  # number of tasks per meta-update
         adaptation_steps=1,  # number of steps for fast adaptation
         num_iterations=100,  # number of meta-training iterations
         cuda=True,
         seed=42,
 ):
     # Seed the random number generators for reproducibility
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     # torch.device represents the device on which a torch.Tensor is allocated
     device = torch.device('cpu')
     # Use the GPU only if 'cuda' is True and at least one GPU is available
     if cuda and torch.cuda.device_count():
         torch.cuda.manual_seed(seed)  # seed the current GPU
         device = torch.device('cuda')

     # Create tasksets using the learn2learn.vision.benchmarks module,
     # which provides an interface to standardized benchmarks
     tasksets = l2l.vision.benchmarks.get_tasksets(
         'mini-imagenet',  # name of the benchmark
         train_samples=2 * shots,  # samples per class in each train task
         train_ways=ways,  # number of classes per train task
         test_samples=2 * shots,  # samples per class in each test task
         test_ways=ways,  # number of classes per test task
         root='~/data',  # where the data is stored
     )

  7. Model creation

     model = l2l.vision.models.MiniImagenetCNN(ways)
     model.to(device)  # move the model parameters to the selected device
     # Use the high-level implementation of the MAML algorithm:
     # 'model' is the module to be wrapped, 'lr' is the fast-adaptation
     # learning rate, and 'first_order' selects whether to use the
     # first-order approximation of MAML
     maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
     # Adam optimizer: maml.parameters() gives the parameters to be
     # optimized and 'meta_lr' is the learning rate
     opt = optim.Adam(maml.parameters(), meta_lr)
     loss = nn.CrossEntropyLoss(reduction='mean')  # loss function

  8. Compute training and validation errors and accuracies

     for i in range(num_iterations):
         # Zero out the gradients first so that the parameters update correctly
         opt.zero_grad()
         train_error = 0.0  # training error
         train_accuracy = 0.0  # training accuracy
         valid_error = 0.0  # validation error
         valid_accuracy = 0.0  # validation accuracy
         for task in range(batch_size):  # for each task in the batch
             # Compute meta-training loss.
             # maml.clone() returns a MAML-wrapped copy of the module whose
             # parameters and buffers are cloned from the original module
             learner = maml.clone()
             batch = tasksets.train.sample()  # sample a training task
             # Call fast_adapt() defined in step 5; it returns the
             # evaluation error and accuracy
             evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                learner,
                                                                loss,
                                                                adaptation_steps,
                                                                shots,
                                                                ways,
                                                                device)
             # backward() accumulates the gradients of the evaluation error
             # across the tasks of the batch
             evaluation_error.backward()
             # Accumulate the training error and accuracy
             train_error += evaluation_error.item()
             train_accuracy += evaluation_accuracy.item()

             # Compute meta-validation loss
             learner = maml.clone()
             batch = tasksets.validation.sample()
             evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                learner,
                                                                loss,
                                                                adaptation_steps,
                                                                shots,
                                                                ways,
                                                                device)
             # Accumulate the validation error and accuracy
             valid_error += evaluation_error.item()
             valid_accuracy += evaluation_accuracy.item()

         # Print the computed metrics for each iteration
         print('\n')
         print('Iteration', i)  # iteration number
         print('Meta Train Error', train_error / batch_size)
         print('Meta Train Accuracy', train_accuracy / batch_size)
         print('Meta Valid Error', valid_error / batch_size)
         print('Meta Valid Accuracy', valid_accuracy / batch_size)

         # Average the accumulated gradients and optimize
         for para in maml.parameters():
             para.grad.data.mul_(1.0 / batch_size)  # divide gradients by batch size
         opt.step()  # perform a single optimization step

  9. Compute test error and accuracy

     # Initialize test error and test accuracy to zero before accumulating them
     test_error = 0.0
     test_accuracy = 0.0
     # Repeat the adapt-and-evaluate process of step 8 on tasks sampled
     # from the test taskset
     for task in range(batch_size):
         # Compute meta-testing loss
         learner = maml.clone()
         batch = tasksets.test.sample()
         evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                            learner,
                                                            loss,
                                                            adaptation_steps,
                                                            shots,
                                                            ways,
                                                            device)
         # Accumulate the test error and accuracy
         test_error += evaluation_error.item()
         test_accuracy += evaluation_accuracy.item()
     # Print the test metrics
     print('Meta Test Error', test_error / batch_size)
     print('Meta Test Accuracy', test_accuracy / batch_size)

  10. Call the main() method

 if __name__ == '__main__':
     main()

The output will show the train and validation metrics for each of the 100 iterations, followed by the test metrics.

Sample output for the first 5 iterations:

Note: The output may vary every time you execute the code and also depending upon the execution environment you choose.



Nikita Shiledarbaxi

A zealous learner aspiring to advance in the domain of AI/ML. Eager to grasp emerging techniques to get insights from data and hence explore realistic Data Science applications as well.