learn2learn is a software library for meta-learning research. It was introduced in 2020 by Sebastien M. R. Arnold (University of Southern California), Praateek Mahajan (Iterable Inc.), Debajyoti Datta (University of Virginia), Ian Bunner (University of Waterloo) and Konstantinos Saitas Zarkias (KTH Royal Institute of Technology and Research Institutes of Sweden, RISE) (research paper).
Before going into the details of learn2learn, let us first understand what meta-learning means.
What is Meta-Learning?
Meta-learning is a sub-domain of machine learning. It deals with the systematic observation of how various ML approaches perform on different learning tasks. A model then learns from this meta-data to pick up new tasks more quickly than would otherwise be possible: it can absorb information from one task and generalize efficiently to unseen tasks. Hence, meta-learning is also known as ‘learning how to learn’. A detailed explanation of this learning paradigm can be found here.
Overview of learn2learn
learn2learn is built on top of the PyTorch library. It tackles two major issues faced by meta-learning researchers: prototyping and reproducibility. Since modern meta-learning methods rely on unconventional uses of ML frameworks, researchers are likely to make mistakes while prototyping new tasks and algorithms. The lack of standardized implementations and benchmarks then makes it difficult to reproduce existing results. learn2learn addresses both issues, enabling fast prototyping and correct reproducibility.
learn2learn provides low-level utilities and a unified interface for creating new algorithms and domains. It also provides standardized benchmarks and high-quality implementations of existing algorithms. In addition, it remains compatible with PyTorch-based libraries such as torchvision, torchaudio, torchtext and cherry.
Components of learn2learn
- learn2learn.data: A set of utilities for loading, preprocessing and sampling of data and tasks
- learn2learn.algorithms: A set of high-level implementations of several algorithms
- learn2learn.optim: A set of utilities for differentiable optimization algorithms
- learn2learn.nn: A set of modules used for meta-learning
- learn2learn.vision: Models, datasets and other utilities for Computer Vision tasks
- learn2learn.gym: Models, environments and other utilities for reinforcement learning and OpenAI Gym
Practical implementation
Here’s a demonstration of few-shot learning using the MAML wrapper for fast adaptation. MAML (Model-Agnostic Meta-Learning) is compatible with any model trained using gradient descent and is applicable to a wide range of tasks such as reinforcement learning, classification and regression (MAML research paper). The code uses the benchmark interface for loading the mini-ImageNet dataset.
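To make the adapt-then-evaluate idea behind MAML concrete before the full example, here is a minimal pure-Python sketch on a toy problem. It does not use learn2learn. It assumes each task has the one-dimensional loss 0.5 * (theta - center) ** 2, so that gradients can be written by hand; the function names and the task centers are hypothetical.

```python
def inner_step(theta, center, fast_lr):
    """One gradient step on the task loss 0.5 * (theta - center) ** 2."""
    grad = theta - center
    return theta - fast_lr * grad


def meta_gradient(theta, center, fast_lr):
    """Gradient of the post-adaptation loss with respect to the initial theta.

    adapted = theta - fast_lr * (theta - center), so
    d(adapted)/d(theta) = 1 - fast_lr, and by the chain rule
    d loss(adapted)/d(theta) = (adapted - center) * (1 - fast_lr).
    """
    adapted = inner_step(theta, center, fast_lr)
    return (adapted - center) * (1.0 - fast_lr)


def meta_train(centers, theta=0.0, fast_lr=0.5, meta_lr=0.1, iterations=200):
    """Average the meta-gradient over all tasks and take a plain SGD meta-step."""
    for _ in range(iterations):
        grads = [meta_gradient(theta, c, fast_lr) for c in centers]
        theta -= meta_lr * sum(grads) / len(grads)
    return theta


task_centers = [1.0, 2.0, 3.0]  # hypothetical tasks
theta_star = meta_train(task_centers)
print(theta_star)  # approaches 2.0, the initialization that adapts best on average
```

In this toy setting the meta-learned initialization moves towards the mean of the task centers, because that is the starting point from which a single inner step lands closest to every task on average.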
The demonstration was implemented in Google Colab using Python 3.7.10 and learn2learn 0.1.5. A step-wise explanation of the code follows:
- Clone the learn2learn GitHub repository
!git clone https://github.com/learnables/learn2learn
- Install the learn2learn library using pip command
!pip install learn2learn
- Import the required libraries
import random
import numpy as np
import torch
from torch import nn, optim
import learn2learn as l2l
from learn2learn.data.transforms import (NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels)
learn2learn.data.transforms is a collection of general task transforms, i.e. objects which implement the callable interface. A task description is the list of samples that make up a task; it initially contains all samples from the dataset under consideration. Each task transform takes a task description and returns a new, modified one, so chaining transforms creates a particular kind of task. Among the task transforms imported above:
- NWays keeps samples from N random output labels present in the task description
- KShots keeps K samples for each output label
- LoadData loads a sample with given index from the dataset
- RemapLabels maps the labels of samples taken from N classes to 0, …, N-1
- ConsecutiveLabels rearranges the samples in the task description such that they get sorted in consecutive order
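The combined effect of NWays, KShots and RemapLabels can be imitated in a few lines of plain Python. This is an illustrative sketch only, not learn2learn's implementation; the `sample_task` helper and the toy dataset are hypothetical.

```python
import random
from collections import defaultdict


def sample_task(labels, ways, shots, rng):
    """Return (sample_index, remapped_label) pairs for one N-way K-shot task.

    labels: list where labels[i] is the class of sample i.
    """
    # Group sample indices by class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # NWays-like step: keep N randomly chosen classes.
    chosen = rng.sample(sorted(by_class), ways)

    # RemapLabels-like step: map the chosen classes to 0, ..., N-1.
    remap = {label: new for new, label in enumerate(chosen)}

    task = []
    for label in chosen:
        # KShots-like step: keep K random samples of each chosen class.
        for index in rng.sample(by_class[label], shots):
            task.append((index, remap[label]))
    return task


# Hypothetical toy dataset: 10 classes with 20 samples each.
toy_labels = [cls for cls in range(10) for _ in range(20)]
task = sample_task(toy_labels, ways=5, shots=2, rng=random.Random(42))
print(len(task))  # 5 ways * 2 shots = 10 samples
```

Because the samples are appended class by class, the resulting task is also ordered by label, which is the property ConsecutiveLabels provides in the library.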
- Define a function for computing model accuracy
def accuracy(pred, targets):
    # The index of the highest score in each row is the predicted class
    pred = pred.argmax(dim=1).view(targets.shape)
    # Fraction of predictions that match the actual target labels
    return (pred == targets).sum().float() / targets.size(0)
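The same computation can be followed by hand on a tiny example. The sketch below mirrors the logic with plain Python lists instead of tensors; `accuracy_from_scores` and the score values are hypothetical and only serve as illustration.

```python
def accuracy_from_scores(scores, targets):
    """Fraction of rows whose highest-scoring class equals the target label."""
    # Equivalent of argmax(dim=1): position of the largest score in each row.
    predictions = [row.index(max(row)) for row in scores]
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


# Three samples, three classes; the predicted classes are 1, 0 and 2.
scores = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
]
targets = [1, 0, 1]
result = accuracy_from_scores(scores, targets)
print(result)  # two of the three predictions match the targets
```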
- Define a method for fast adaptation of the neural network
def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    data, labels = batch  # separate features and labels
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets
    # Start from an all-False mask and mark every other sample for adaptation
    adapt_indices = np.zeros(data.size(0), dtype=bool)
    adapt_indices[np.arange(shots*ways) * 2] = True
    # The evaluation set is the complement of the adaptation set
    eval_indices = torch.from_numpy(~adapt_indices)
    adapt_indices = torch.from_numpy(adapt_indices)
    # Separate adaptation and evaluation data and labels
    adapt_data, adapt_labels = data[adapt_indices], labels[adapt_indices]
    eval_data, eval_labels = data[eval_indices], labels[eval_indices]

    # Adapt the model
    for step in range(adaptation_steps):
        adapt_error = loss(learner(adapt_data), adapt_labels)
        # learner.adapt() takes a gradient step on the loss and updates the
        # cloned parameters
        learner.adapt(adapt_error)

    # Evaluate the adapted model
    predictions = learner(eval_data)
    eval_error = loss(predictions, eval_labels)
    eval_accuracy = accuracy(predictions, eval_labels)
    return eval_error, eval_accuracy
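The index arithmetic that splits a task into adaptation and evaluation sets is easy to misread, so here is a small pure-Python sketch of the same masking logic. The helper name `split_indices` is hypothetical; it assumes, as in the code above, that each task holds 2 * shots samples per class, ordered by class.

```python
def split_indices(num_samples, shots, ways):
    """Even positions go to the adaptation set, the remaining ones to evaluation."""
    adapt_mask = [False] * num_samples
    for i in range(shots * ways):
        adapt_mask[2 * i] = True
    # The evaluation mask is the logical complement of the adaptation mask.
    eval_mask = [not flag for flag in adapt_mask]
    adapt = [i for i, flag in enumerate(adapt_mask) if flag]
    evaluation = [i for i, flag in enumerate(eval_mask) if flag]
    return adapt, evaluation


# 3 ways, 1 shot: 2 samples per class, 6 samples in total.
adapt, evaluation = split_indices(num_samples=6, shots=1, ways=3)
print(adapt)       # [0, 2, 4]
print(evaluation)  # [1, 3, 5]
```

With samples ordered by class, each class contributes `shots` samples to the adaptation set and `shots` samples to the evaluation set, and the two sets never overlap.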
- Define main() method which marks the point from where the code execution begins
def main(
        ways=5,              # number of classes per task
        shots=5,             # number of labeled examples per class
        meta_lr=0.003,       # learning rate for meta-learning
        fast_lr=0.5,         # learning rate for fast adaptation
        batch_size=32,       # number of tasks per meta-update
        adaptation_steps=1,  # number of steps for fast adaptation
        num_iterations=100,  # number of meta-training iterations
        cuda=True,
        seed=42,
):
    # Seed the random number generators for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.device represents the device on which a torch.Tensor is allocated
    device = torch.device('cpu')
    # Use the GPU if 'cuda' is True and at least one GPU is available
    if cuda and torch.cuda.device_count():
        # Seed random number generation on the current GPU
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create tasksets using the learn2learn.vision.benchmarks module, which
    # provides an interface to standardized benchmarks
    tasksets = l2l.vision.benchmarks.get_tasksets(
        'mini-imagenet',        # name of the benchmark
        train_samples=2*shots,  # samples per class in each train task
        train_ways=ways,        # number of classes per train task
        test_samples=2*shots,   # samples per class in each test task
        test_ways=ways,         # number of classes per test task
        root='~/data',          # where the data is stored
    )
- Create the model, the MAML wrapper, the optimizer and the loss (continuing inside main())
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)  # move the model parameters to the selected device
    # Wrap the model with the high-level MAML implementation: 'lr' is the
    # fast-adaptation learning rate and 'first_order' selects whether to use
    # the first-order approximation of MAML
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
    # Adam optimizes the parameters returned by maml.parameters() with
    # learning rate 'meta_lr'
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')  # loss function
- Compute training and validation errors and accuracies
    for i in range(num_iterations):
        # Zero out the gradients so they do not accumulate across iterations
        opt.zero_grad()
        train_error = 0.0
        train_accuracy = 0.0
        valid_error = 0.0
        valid_accuracy = 0.0
        for task in range(batch_size):  # for each task in the batch
            # Compute meta-training loss
            # maml.clone() returns a MAML-wrapped copy of the module whose
            # parameters and buffers are torch.cloned from the original module
            learner = maml.clone()
            batch = tasksets.train.sample()  # sample a task from the training taskset
            # fast_adapt() (defined above) returns the evaluation error and accuracy
            evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)
            # backward() accumulates the gradients of the evaluation error
            evaluation_error.backward()
            # Accumulate the training error and accuracy over the batch
            train_error += evaluation_error.item()
            train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            learner = maml.clone()
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)
            # Accumulate the validation error and accuracy over the batch
            valid_error += evaluation_error.item()
            valid_accuracy += evaluation_accuracy.item()

        # Print the computed metrics for each iteration
        print('\n')
        print('Iteration', i)
        print('Meta Train Error', train_error / batch_size)
        print('Meta Train Accuracy', train_accuracy / batch_size)
        print('Meta Valid Error', valid_error / batch_size)
        print('Meta Valid Accuracy', valid_accuracy / batch_size)

        # Average the accumulated gradients and optimize
        for para in maml.parameters():
            # Divide the gradients by the batch size
            para.grad.data.mul_(1.0 / batch_size)
        # opt.step() performs a single optimization step
        opt.step()
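The gradient-scaling step at the end of each iteration may look odd at first. Calling backward() once per task sums the per-task gradients, so multiplying by 1 / batch_size turns that sum into the gradient of the mean loss. The pure-Python sketch below checks this on a hypothetical toy loss, 0.5 * (theta - center) ** 2, whose gradient is simply theta - center; none of the names come from learn2learn.

```python
def task_gradient(theta, center):
    """Gradient of the toy task loss 0.5 * (theta - center) ** 2."""
    return theta - center


theta = 0.5
centers = [1.0, 2.0, 4.0, 5.0]  # one hypothetical task per center

# Summing per-task gradients mimics repeated backward() calls.
accumulated = 0.0
for center in centers:
    accumulated += task_gradient(theta, center)

# Scaling by 1 / batch_size mimics para.grad.data.mul_(1.0 / batch_size).
averaged = accumulated * (1.0 / len(centers))

# Gradient of the mean loss, computed directly for comparison.
direct = theta - sum(centers) / len(centers)

print(averaged, direct)  # the two values agree
```

Without the scaling, the effective step size of the optimizer would grow with the number of tasks in the batch.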
- Compute test error and accuracy (still inside main(), after the training loop)
    # Initialize the test error and accuracy to zero before accumulating them
    test_error = 0.0
    test_accuracy = 0.0
    # Same process as for the training and validation tasks above
    for task in range(batch_size):
        # Compute meta-testing loss
        learner = maml.clone()
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device)
        # Accumulate the test error and accuracy
        test_error += evaluation_error.item()
        test_accuracy += evaluation_accuracy.item()
    # Print the test metrics
    print('Meta Test Error', test_error / batch_size)
    print('Meta Test Accuracy', test_accuracy / batch_size)
- Call the main() method
if __name__ == '__main__':
    main()
The output will show the train and validation metrics for each of the 100 iterations, followed by the test metrics computed once training has finished.
Sample output for the first 5 iterations:
Note: The output may vary every time you execute the code and also depending upon the execution environment you choose.
- Code source
- Google colab notebook for the above implementation
References
Refer to the following sources for a detailed understanding of the learn2learn library: