Guide To learn2learn: A Library For Meta-Learning Research


learn2learn is a software library designed for meta-learning research. It was introduced in 2020 by Sebastien M. R. Arnold from University of Southern California, Praateek Mahajan from Iterable Inc., Debajyoti Datta from University of Virginia, Ian Bunner from University of Waterloo and Konstantinos Saitas Zarkias from KTH Royal Institute of Technology and Research Institute of Sweden (RISE) (research paper).

Before going into the details of learn2learn, let us first understand what meta-learning means.

What is Meta-Learning?

Meta-learning is a sub-domain of machine learning. It involves systematically observing how various ML approaches perform on different learning tasks, and then learning from this meta-data to learn new tasks more quickly than would otherwise be possible. A meta-learned model can absorb information from one set of tasks and efficiently generalize to unseen tasks. Hence meta-learning is also known as ‘learning how to learn’. A detailed explanation of this learning paradigm can be found here.
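In practice, the tasks a meta-learner trains on are usually small N-way K-shot episodes: pick N classes, then K labeled examples per class to adapt on plus a few more per class to evaluate on. The plain-Python sketch below illustrates this episode sampling on a toy label list. `sample_episode` and its parameters are hypothetical names used only for illustration; they are not part of learn2learn.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_ways, k_shots, n_queries, rng):
    """Sample one N-way K-shot episode (hypothetical helper for illustration).

    labels[i] is the class of sample i. Returns (support, query): lists of
    (sample_index, task_label) pairs with task labels remapped to 0..n_ways-1.
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # only classes with enough samples for both support and query can be used
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shots + n_queries]
    classes = rng.sample(eligible, n_ways)
    support, query = [], []
    for task_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shots + n_queries)
        support += [(i, task_label) for i in picked[:k_shots]]  # K labeled examples to adapt on
        query += [(i, task_label) for i in picked[k_shots:]]    # held-out examples to evaluate on
    return support, query


# Toy "dataset": 20 classes with 10 samples each
labels = [c for c in range(20) for _ in range(10)]
support, query = sample_episode(labels, n_ways=5, k_shots=1, n_queries=3, rng=random.Random(0))
print(len(support), len(query))  # 5 15
```

Each call draws a different episode, so the meta-learner sees many small tasks instead of one large, fixed dataset.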



Overview of learn2learn

learn2learn is built on top of the PyTorch library. It tackles two major issues faced by meta-learning researchers: prototyping and reproducibility. Since modern meta-learning methods rely on unconventional ML frameworks, researchers are likely to make mistakes while prototyping new tasks and algorithms, which makes existing results difficult to reproduce. The lack of standardized implementations and benchmarks adds to this problem. The learn2learn library addresses these issues, enabling fast prototyping and correct reproducibility.

learn2learn provides low-level utilities and a unified interface for creating new algorithms and domains. It also provides standardized benchmarks and high-quality implementations of existing algorithms. Besides these functionalities, it remains compatible with PyTorch-based libraries such as torchvision, torchaudio, torchtext and cherry.

Components of learn2learn

Practical implementation

Here’s a demonstration of few-shot learning using the MAML wrapper for fast adaptation. MAML (Model-Agnostic Meta-Learning) is a model-agnostic meta-learning algorithm, i.e. it is compatible with any model trained using gradient descent and is applicable to a wide range of tasks such as classification, regression and reinforcement learning (MAML research paper). The code uses the benchmark interface to load the mini-ImageNet dataset.
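To get a feel for what MAML computes, here is a minimal NumPy sketch on toy 1-D linear-regression tasks. This is our own illustrative setup, not the mini-ImageNet experiment below and not learn2learn's implementation (which differentiates through the adaptation step with PyTorch autograd). Because the model is linear with a squared loss, the exact second-order meta-gradient can be written by hand: the inner step is θ' = θ − α∇L_support(θ) and the meta-gradient is (I − α·H_support)∇L_query(θ'). The step sizes `alpha` and `beta` and the task distribution are assumptions chosen for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_task():
    """Sample a task: a 1-D linear function y = a*x + b with its own slope and intercept."""
    a, b = rng.uniform(1.5, 2.5), rng.uniform(0.5, 1.5)

    def draw(n):
        x = rng.uniform(-1.0, 1.0, size=n)
        # features [x, 1], so theta = [slope, intercept]
        return np.stack([x, np.ones(n)], axis=1), a * x + b

    return draw


def loss_grad_hess(theta, X, y):
    """MSE of the linear model X @ theta, with its exact gradient and Hessian."""
    n = len(y)
    r = X @ theta - y
    return r @ r / n, 2.0 / n * X.T @ r, 2.0 / n * X.T @ X


def maml_step(theta, alpha, beta, meta_batch, shots):
    """One meta-update using the exact (second-order) MAML gradient."""
    meta_grad = np.zeros_like(theta)
    for _ in range(meta_batch):
        task = sample_task()
        Xs, ys = task(shots)                      # adaptation (support) set
        Xq, yq = task(shots)                      # evaluation (query) set
        _, gs, Hs = loss_grad_hess(theta, Xs, ys)
        theta_adapted = theta - alpha * gs        # inner fast-adaptation step
        _, gq, _ = loss_grad_hess(theta_adapted, Xq, yq)
        # chain rule through the inner step: d(theta_adapted)/d(theta) = I - alpha * Hs
        meta_grad += (np.eye(len(theta)) - alpha * Hs) @ gq
    return theta - beta * meta_grad / meta_batch  # outer (meta) update


def post_adaptation_loss(theta, alpha, shots, n_tasks=500):
    """Average query loss after one adaptation step, over fresh tasks."""
    total = 0.0
    for _ in range(n_tasks):
        task = sample_task()
        Xs, ys = task(shots)
        Xq, yq = task(shots)
        _, gs, _ = loss_grad_hess(theta, Xs, ys)
        total += loss_grad_hess(theta - alpha * gs, Xq, yq)[0]
    return total / n_tasks


alpha, beta, shots = 0.1, 0.05, 5
theta = np.zeros(2)  # meta-learned initialization [slope, intercept]
before = post_adaptation_loss(theta, alpha, shots)
for _ in range(500):
    theta = maml_step(theta, alpha, beta, meta_batch=8, shots=shots)
after = post_adaptation_loss(theta, alpha, shots)
print(f"query loss after adaptation: {before:.3f} -> {after:.3f}; theta = {theta}")
```

The meta-learned initialization moves toward the region the tasks have in common, so a single adaptation step already yields a low query loss on a new task.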

The code has been implemented in Google Colab using Python 3.7.10 and learn2learn 0.1.5. A step-wise explanation of the code is as follows:

  1. Clone the learn2learn GitHub repository

!git clone

  2. Install the learn2learn library using the pip command

!pip install learn2learn

  3. Import the required libraries
 import random
 import numpy as np
 import torch
 from torch import nn, optim
 import learn2learn as l2l
 from learn2learn.data.transforms import (NWays,
                                          KShots,
                                          LoadData,
                                          RemapLabels,
                                          ConsecutiveLabels)

learn2learn.data.transforms is a collection of general task transformations (objects which implement the callable interface). Each transformation is called on a task description and returns a new task description. Initially, the task description contains all samples of the dataset under consideration; the task transforms modify it to create a particular task. Among the task transforms imported above:

  • NWays keeps samples from N random labels present in the task description
  • KShots keeps K samples for each label present in the task description
  • LoadData loads a sample from the dataset given its index
  • RemapLabels maps the labels of samples drawn from K classes to 0, …, K-1
  • ConsecutiveLabels re-orders the samples in the task description so that samples with the same label appear consecutively
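To see how these transforms compose into a task, here is a simplified, library-independent sketch that treats a task description as a list of `(sample_index, label)` pairs. The helper functions are hypothetical stand-ins that mimic the behaviour described above; they are not learn2learn's classes or signatures, and loading the actual data (LoadData) is omitted. They are applied in the order the transforms were imported.

```python
import random

# A "task description" here is simply a list of (sample_index, label) pairs.
# All helpers below are hypothetical, simplified stand-ins, NOT learn2learn's API.


def n_ways(task, n, rng):
    """Keep samples from n randomly chosen labels present in the task."""
    keep = set(rng.sample(sorted({y for _, y in task}), n))
    return [(i, y) for i, y in task if y in keep]


def k_shots(task, k, rng):
    """Keep k random samples for each label present in the task."""
    out = []
    for y in sorted({y for _, y in task}):
        out += rng.sample([s for s in task if s[1] == y], k)
    return out


def remap_labels(task):
    """Map the K labels present in the task to 0, ..., K-1."""
    mapping = {y: j for j, y in enumerate(sorted({y for _, y in task}))}
    return [(i, mapping[y]) for i, y in task]


def consecutive_labels(task):
    """Re-order samples so that samples with the same label are adjacent."""
    return sorted(task, key=lambda s: s[1])


# Toy dataset: 100 samples over 10 classes (label = index % 10)
dataset = [(i, i % 10) for i in range(100)]
rng = random.Random(1)
task = consecutive_labels(remap_labels(k_shots(n_ways(dataset, 5, rng), 2, rng)))
print(task)  # 10 pairs: 2 samples for each of the labels 0..4, grouped by label
```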

  4. Define a function for computing model accuracy

 def accuracy(pred, targets):
     # Index of the maximum predicted score, i.e. the predicted class
     pred = pred.argmax(dim=1).view(targets.shape)
     # Fraction of predictions which match the actual target labels
     return (pred == targets).sum().float() / targets.size(0)
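For a concrete check of what accuracy() computes, here is a NumPy translation (`accuracy_np` is a hypothetical name used only for illustration) applied to four hand-written rows of class scores over three classes:

```python
import numpy as np


def accuracy_np(pred, targets):
    """NumPy equivalent of accuracy(): fraction of argmax predictions equal to the targets."""
    pred = pred.argmax(axis=1).reshape(targets.shape)
    return (pred == targets).sum() / targets.shape[0]


# 4 samples, 3 classes: each row holds the scores (logits) for one sample
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.3, 0.2, 0.9],
                   [1.2, 0.8, 0.1]])
targets = np.array([0, 1, 2, 1])
print(accuracy_np(logits, targets))  # 0.75
```

The argmax of the rows is [0, 1, 2, 0], so three of the four predictions match the targets.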

  5. Define a method for fast adaptation of the neural network

 def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
     data, labels = batch  # separate features and labels
     data, labels = data.to(device), labels.to(device)  # move them to the device
     # Separate data into adaptation/evaluation sets:
     # initialize the adaptation mask to False everywhere...
     adapt_indices = np.zeros(data.size(0), dtype=bool)
     # ...then mark every other sample (shots*ways of them) for adaptation
     adapt_indices[np.arange(shots*ways) * 2] = True
     # the remaining samples (the complement of the mask) are used for evaluation
     eval_indices = torch.from_numpy(~adapt_indices)
     adapt_indices = torch.from_numpy(adapt_indices)
     # Separate adaptation and evaluation data and labels
     adapt_data, adapt_labels = data[adapt_indices], labels[adapt_indices]
     eval_data, eval_labels = data[eval_indices], labels[eval_indices]
     # Adapt the model
     for step in range(adaptation_steps):  # for each adaptation step
         # adaptation error
         adapt_error = loss(learner(adapt_data), adapt_labels)
         # learner.adapt() takes a gradient step on the loss and updates the
         # cloned parameters in place
         learner.adapt(adapt_error)
     # Evaluate the adapted model
     predictions = learner(eval_data)
     eval_error = loss(predictions, eval_labels)  # evaluation error
     # evaluation accuracy
     eval_accuracy = accuracy(predictions, eval_labels)
     return eval_error, eval_accuracy
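The index trick in fast_adapt() relies on each task's samples arriving sorted by label, with 2*shots samples per class. We assume this is what `train_samples=2*shots` and the mini-ImageNet benchmark's task transforms (including ConsecutiveLabels) arrange. A small NumPy check with toy sizes (shots=2, ways=3, chosen only for readability) shows that the even positions give `shots` adaptation samples per class, and the odd positions, the complement `~adapt_indices`, give the evaluation samples:

```python
import numpy as np

shots, ways = 2, 3
# a task sorted by label, with 2*shots samples per class
labels = np.repeat(np.arange(ways), 2 * shots)  # [0 0 0 0 1 1 1 1 2 2 2 2]
adapt_mask = np.zeros(labels.size, dtype=bool)
adapt_mask[np.arange(shots * ways) * 2] = True  # even positions 0, 2, ..., 10
eval_mask = ~adapt_mask                          # odd positions 1, 3, ..., 11
print(labels[adapt_mask])  # [0 0 1 1 2 2]
print(labels[eval_mask])   # [0 0 1 1 2 2]
```

Without the `~`, the evaluation set would be a copy of the adaptation set, and the reported accuracy would measure memorization rather than generalization to held-out samples.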
  6. Define the main() method, which marks the point from where the code execution begins
 def main(
         ways=5,  # number of classes whose samples are to be considered
         shots=5,  # number of labeled examples per class/category
         meta_lr=0.003,  # learning rate for meta-learning
         fast_lr=0.5,  # learning rate for fast adaptation
         batch_size=32,  # number of tasks per meta-batch (assumed value; not stated in the original)
         adaptation_steps=1,  # number of steps for fast adaptation
         num_iterations=100,  # number of meta-training iterations
         cuda=True,  # use a GPU if one is available (assumed default)
         seed=42,  # seed for the random number generators (assumed value)
 ):
     # Seed the random number generators for reproducibility
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     # torch.device is an object which represents the device on which a
     # torch.Tensor will be allocated
     device = torch.device('cpu')
     # If the 'cuda' parameter passed to main() is True and at least one GPU
     # is available, the condition is satisfied
     if cuda and torch.cuda.device_count():
         # set the seed for random number generation on the current GPU
         torch.cuda.manual_seed(seed)
         # allocate tensors on the GPU
         device = torch.device('cuda')
     # Create tasksets using the l2l.vision.benchmarks module, which provides
     # an interface to standardized benchmarks
     tasksets = l2l.vision.benchmarks.get_tasksets(
         'mini-imagenet',        # name of the benchmark
         train_samples=2*shots,  # number of samples per class in each train task
         train_ways=ways,        # number of classes per train task
         test_samples=2*shots,   # number of samples per class in each test task
         test_ways=ways,         # number of classes per test task
         root='~/data',          # where the data is stored
     )
  7. Model creation
     # CNN for mini-ImageNet with one output per class ('way')
     model = l2l.vision.models.MiniImagenetCNN(ways)
     model.to(device)  # sends model parameters to the GPU (if one is used)
     # Use the high-level implementation of the MAML algorithm:
     # 'model' is the module to be wrapped, 'lr' is the fast-adaptation learning
     # rate and 'first_order' selects whether to use the first-order
     # approximation of MAML
     maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
     # Use the Adam optimization algorithm (torch.optim.Adam):
     # maml.parameters() gives the model parameters to be optimized and
     # 'meta_lr' is the learning rate
     opt = optim.Adam(maml.parameters(), meta_lr)
     loss = nn.CrossEntropyLoss(reduction='mean')  # loss function
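The `first_order` flag decides whether the meta-gradient is propagated through the adaptation step. On a toy scalar problem with quadratic support and query losses (our own illustrative assumption, unrelated to the mini-ImageNet model), the two variants differ by a single factor, (1 − α·a): the derivative of the adapted parameter with respect to the initial one. A central finite difference confirms that the full gradient is the exact derivative of the meta-loss:

```python
# Support loss Ls(t) = 0.5 * a * (t - c)**2  (curvature a, minimum at c)
# Query loss   Lq(t) = 0.5 * b * (t - d)**2  (curvature b, minimum at d)
a, c, b, d = 2.0, 1.0, 3.0, -0.5
alpha = 0.1  # fast-adaptation learning rate (plays the role of MAML's `lr`)


def adapted(theta):
    """One inner gradient step on the support loss."""
    return theta - alpha * a * (theta - c)


def meta_loss(theta):
    """Query loss evaluated at the adapted parameter."""
    return 0.5 * b * (adapted(theta) - d) ** 2


theta = 0.4
tp = adapted(theta)
full_grad = b * (tp - d) * (1 - alpha * a)  # differentiates through the inner step
first_order = b * (tp - d)                  # treats d(adapted)/d(theta) as 1

# Central finite difference of the meta-loss
eps = 1e-6
numeric = (meta_loss(theta + eps) - meta_loss(theta - eps)) / (2 * eps)
print(full_grad, first_order, numeric)
```

Setting `first_order=True` drops that factor, which avoids second-order derivatives and saves memory and compute at the cost of an approximate meta-gradient.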
  8. Compute training and validation errors and accuracies
     for i in range(num_iterations):
         # zero out the gradients first so that the parameters update correctly
         opt.zero_grad()
         train_error = 0.0  # training error
         train_accuracy = 0.0  # training accuracy
         valid_error = 0.0  # validation error
         valid_accuracy = 0.0  # validation accuracy
         for task in range(batch_size):  # for each task in the batch
             # Compute meta-training loss
             # maml.clone() returns a MAML-wrapped copy of the module; its
             # parameters and buffers are torch.clone()d from the original module
             learner = maml.clone()
             batch = tasksets.train.sample()  # sample a task from the training taskset
             # Call the fast_adapt() method defined in step 5; it returns the
             # evaluation error and accuracy
             evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                learner,
                                                                loss,
                                                                adaptation_steps,
                                                                shots,
                                                                ways,
                                                                device)
             # backward() computes the gradients of the error and accumulates
             # (sums) them in the .grad attribute of the meta-parameters
             evaluation_error.backward()
             # Add the computed values to the training error and accuracy
             train_error += evaluation_error.item()
             train_accuracy += evaluation_accuracy.item()
             # Compute meta-validation loss (no backward pass here)
             learner = maml.clone()
             batch = tasksets.validation.sample()
             # Call the fast_adapt() method defined in step 5
             evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                learner,
                                                                loss,
                                                                adaptation_steps,
                                                                shots,
                                                                ways,
                                                                device)
             # Add the computed values to the validation error and accuracy
             valid_error += evaluation_error.item()
             valid_accuracy += evaluation_accuracy.item()
         # Print the computed metrics for each iteration
         print('Iteration', i)  # iteration number
         print('Meta Train Error', train_error / batch_size)
         print('Meta Train Accuracy', train_accuracy / batch_size)
         print('Meta Valid Error', valid_error / batch_size)
         print('Meta Valid Accuracy', valid_accuracy / batch_size)
         # Average the accumulated gradients and optimize
         for para in maml.parameters():
             # divide the gradients by the batch size
             para.grad.data.mul_(1.0 / batch_size)
         # torch.optim.Adam.step() performs a single optimization step
         opt.step()
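Why divide the gradients by batch_size? Each call to `.backward()` adds its task's gradient into `.grad`, so after the task loop `.grad` holds the sum over the meta-batch; multiplying by 1/batch_size turns it into the gradient of the mean meta-loss. The standalone NumPy sketch below checks this on toy quadratic per-task losses (an illustrative assumption), against a finite-difference gradient of the mean loss:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4
targets = rng.normal(size=(batch_size, 3))  # one toy quadratic loss per task
theta = np.zeros(3)


def task_loss(theta, t):
    return 0.5 * np.sum((theta - t) ** 2)


def mean_loss(theta):
    return np.mean([task_loss(theta, t) for t in targets])


# Mimic the training loop: each task's backward() adds its gradient into .grad
grad = np.zeros(3)
for t in targets:
    grad += theta - t       # gradient of task_loss with respect to theta
grad *= 1.0 / batch_size    # the in-place rescaling done before opt.step()

# Central finite-difference gradient of the mean loss over the meta-batch
eps = 1e-6
numeric = np.array([(mean_loss(theta + eps * e) - mean_loss(theta - eps * e)) / (2 * eps)
                    for e in np.eye(3)])
print(np.allclose(grad, numeric))  # True
```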
  9. Compute test error and accuracy
     # Initialize test error and test accuracy to zero before updating them
     test_error = 0.0
     test_accuracy = 0.0
     # Evaluate each task of a test batch (same process as for training and
     # validation in step 8, but without any optimization step)
     for task in range(batch_size):
         # Compute meta-testing loss
         learner = maml.clone()
         batch = tasksets.test.sample()
         evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                            learner,
                                                            loss,
                                                            adaptation_steps,
                                                            shots,
                                                            ways,
                                                            device)
         # Update the test error and accuracy
         test_error += evaluation_error.item()
         test_accuracy += evaluation_accuracy.item()
     # Print the test metrics
     print('Meta Test Error', test_error / batch_size)
     print('Meta Test Accuracy', test_accuracy / batch_size)
  10. Call the main() method
 if __name__ == '__main__':
     main()

The output shows the meta-training and meta-validation error and accuracy for each of the 100 iterations, followed by the meta-test error and accuracy.

Sample output for the first 5 iterations:

Note: The output may vary every time you execute the code and also depending upon the execution environment you choose.


Refer to the following sources for a detailed understanding of the learn2learn library:


Nikita Shiledarbaxi
A zealous learner aspiring to advance in the domain of AI/ML. Eager to grasp emerging techniques to get insights from data and hence explore realistic Data Science applications as well.
