A Guide to Barlow Twins: Self-Supervised Learning via Redundancy Reduction

Barlow Twins is a novel architecture inspired by the redundancy-reduction principle in the work of neuroscientist H. Barlow.

Redundancy is a recurring problem in large neural network architectures. Redundant parameters (for example, weight matrices whose effective rank is far below their size) make models extremely large without contributing to their performance. Another kind of redundancy, more prevalent in self-supervised learning models, is redundancy in the representations: some elements of the final representations collapse to trivial constants. This kind of redundancy is often detrimental to both the performance and the efficiency of a model. Moreover, biological processes such as sensory processing suggest that it is efficient to recode highly redundant inputs into factorial codes (codes with statistically independent components).

Barlow Twins applies this redundancy-reduction principle, formulated by neuroscientist H. Barlow, to self-supervised learning. It is a conceptually simple, easy-to-implement architecture that learns high-quality representations from noisy image data. The model was proposed by Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun and Stéphane Deny on 4 March 2021.

Architecture of Barlow Twins

Barlow Twins, as the name suggests, contains two identical networks. Each network consists of a ResNet-50 backbone followed by a projector made of three linear layers. Each network is fed a different distorted version of the same image; in the implementation, the two twins share weights, so the same backbone and projector process both views. The output representations of the two networks are then compared, and a dedicated loss function steers them away from trivial, constant representations.
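The projector's shape is set by a dash-separated size string that the implementation prefixes with the backbone's 2048 output features. As a rough illustration, assuming the paper's setting of three layers with 8192 output units each (the string `'8192-8192-8192'`), this plain-Python sketch lists the layer shapes and counts the projector's linear weights. `projector_layers` and `linear_weight_count` are hypothetical helper names.

```python
def projector_layers(spec, backbone_dim=2048):
    """(in_features, out_features) of each Linear layer built from a spec like '8192-8192-8192'."""
    # Mirrors the size list built in the implementation: [2048] + parsed sizes.
    sizes = [backbone_dim] + [int(s) for s in spec.split("-")]
    return list(zip(sizes[:-1], sizes[1:]))


def linear_weight_count(layers):
    """Weights in the bias-free Linear layers (BatchNorm parameters are not counted)."""
    return sum(n_in * n_out for n_in, n_out in layers)


layers = projector_layers("8192-8192-8192")  # assumed spec, following the paper
print(layers)                       # [(2048, 8192), (8192, 8192), (8192, 8192)]
print(linear_weight_count(layers))  # 150994944
```

Under that assumption the projector alone holds about 151 million weights, far more than the roughly 25 million parameters of a standard ResNet-50.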

Loss function in Barlow Twins

The loss function is computed on the cross-correlation matrix between the components of the two networks' representations, measured across the samples of a batch.

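If $z^A$ and $z^B$ are the batch-normalized outputs of the twin networks for a batch of $N$ images, the cross-correlation matrix $\mathcal{C}$ is given by

$$\mathcal{C}_{ij} = \frac{\sum_b z^A_{b,i}\, z^B_{b,j}}{\sqrt{\sum_b \left(z^A_{b,i}\right)^2}\,\sqrt{\sum_b \left(z^B_{b,j}\right)^2}}$$

where $b$ indexes the images in the batch and $i$, $j$ index the components of the representations. Because the outputs are batch-normalized (zero mean and unit variance along the batch), each denominator equals $N$, so $\mathcal{C} = (z^A)^\top z^B / N$ and every entry lies between -1 and 1.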

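The loss is then calculated as

$$\mathcal{L}_{\mathcal{BT}} = \sum_i \left(1 - \mathcal{C}_{ii}\right)^2 + \lambda \sum_i \sum_{j \neq i} \mathcal{C}_{ij}^2$$

where $\lambda > 0$ balances the two terms.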

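The first term (the invariance term) makes the representations robust to the distortions by pushing the diagonal entries of the cross-correlation matrix towards 1, whereas the second term (the redundancy-reduction term) decorrelates the components of the representation by pushing the off-diagonal entries towards 0. In other words, the model is trained to output representations whose cross-correlation matrix is as close as possible to the identity matrix. According to the paper, this loss keeps performing well as the dimensionality of the representations increases.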
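To make the two terms concrete, here is a small NumPy sketch of the loss. It is our own illustration, not the official code: it standardizes each embedding dimension over the batch (mimicking `BatchNorm1d(affine=False)` in training mode, including its default eps of 1e-5), forms $\mathcal{C} = (z^A)^\top z^B / N$, and adds the invariance and redundancy-reduction terms. The default `lambd=0.005` is an assumption, close to the value used in the paper.

```python
import numpy as np


def batch_normalize(z, eps=1e-5):
    """Standardize each column over the batch (biased variance, like BatchNorm1d in training)."""
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


def cross_correlation(z_a, z_b):
    """Empirical cross-correlation matrix C = z_a^T z_b / N of batch-normalized embeddings."""
    n = z_a.shape[0]
    return batch_normalize(z_a).T @ batch_normalize(z_b) / n


def barlow_twins_loss(z_a, z_b, lambd=0.005):
    c = cross_correlation(z_a, z_b)
    on_diag = np.sum((1.0 - np.diag(c)) ** 2)            # invariance term
    off_diag = np.sum(c ** 2) - np.sum(np.diag(c) ** 2)  # redundancy-reduction term
    return on_diag + lambd * off_diag


# A batch of 4 samples with two zero-mean, unit-variance, uncorrelated columns.
z = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
print(barlow_twins_loss(z, z))           # ≈ 0: C is (almost exactly) the identity
print(barlow_twins_loss(z, z[:, ::-1]))  # ≈ 2.01: all correlation sits off the diagonal
```

Comparing the toy batch with itself gives a near-identity $\mathcal{C}$ and a loss near zero; swapping the columns leaves the diagonal at 0 (invariance term 2) and moves the correlation off the diagonal (redundancy term 2, scaled by $\lambda$).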

Implementation

The following code snippets are excerpts from the official PyTorch implementation.

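 import torch
 import torch.nn as nn
 import torchvision
 def off_diagonal(x):
     # return a flattened view of the off-diagonal elements of a square matrix
     n, m = x.shape
     assert n == m
     return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
 class BarlowTwins(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
         # ResNet-50 backbone; replacing fc with Identity exposes the 2048-d pooled features
         self.backbone = torchvision.models.resnet50(zero_init_residual=True)
         self.backbone.fc = nn.Identity()
         # projector: args.projector is a dash-separated size string, e.g. '8192-8192-8192'
         sizes = [2048] + list(map(int, args.projector.split('-')))
         layers = []
         for i in range(len(sizes) - 2):
             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
             layers.append(nn.BatchNorm1d(sizes[i + 1]))
             layers.append(nn.ReLU(inplace=True))
         layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
         self.projector = nn.Sequential(*layers)
         # normalization layer for the representations z1 and z2
         self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
     def forward(self, y1, y2):
         # the same (weight-shared) backbone and projector embed both views
         z1 = self.projector(self.backbone(y1))
         z2 = self.projector(self.backbone(y2))
         # empirical cross-correlation matrix, divided by the batch size N
         # (the official multi-GPU code divides by the global batch size and
         # then sums c across GPUs with torch.distributed.all_reduce)
         c = self.bn(z1).T @ self.bn(z2) / z1.shape[0]
         # use --scale-loss to multiply the loss by a constant factor
         on_diag = torch.diagonal(c).add(-1).pow(2).sum().mul(self.args.scale_loss)
         off_diag = off_diagonal(c).pow(2).sum().mul(self.args.scale_loss)
         loss = on_diag + self.args.lambd * off_diag
         return loss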
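The `off_diagonal` helper used in the loss relies on a reshape trick that is easy to misread. This NumPy version of the same trick (NumPy is our substitution for PyTorch here, with `reshape` playing the role of `view`) shows why it works:

```python
import numpy as np


def off_diagonal(x):
    """Flattened off-diagonal entries of a square matrix, in row-major order."""
    n, m = x.shape
    assert n == m, "off_diagonal expects a square matrix"
    # Diagonal entries sit at flat indices 0, n+1, 2(n+1), ...
    # Dropping the last element (itself diagonal) leaves n*n - 1 = (n-1)*(n+1)
    # entries; reshaped to rows of length n+1, every remaining diagonal entry
    # lands in column 0, so slicing that column off removes exactly the diagonal.
    return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()


m = np.arange(9).reshape(3, 3)  # [[0 1 2] [3 4 5] [6 7 8]]
print(off_diagonal(m))          # [1 2 3 5 6 7]
```

The trick avoids building a boolean mask, which matters when the matrix is 8192 x 8192.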

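The two distorted views of each image are produced by the following transforms. The two pipelines are deliberately asymmetric: the first always applies Gaussian blur and never solarization, while the second applies blur with probability 0.1 and solarization with probability 0.2.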

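 import random
 from PIL import ImageFilter, ImageOps
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 class GaussianBlur:
     # blur with probability p using a random sigma in [0.1, 2.0]
     # (modeled on the helper in the official repository)
     def __init__(self, p):
         self.p = p
     def __call__(self, img):
         if random.random() < self.p:
             sigma = random.random() * 1.9 + 0.1
             return img.filter(ImageFilter.GaussianBlur(sigma))
         return img
 class Solarization:
     # solarize with probability p (PIL inverts pixel values above a threshold)
     def __init__(self, p):
         self.p = p
     def __call__(self, img):
         if random.random() < self.p:
             return ImageOps.solarize(img)
         return img
 class Transform:
     def __init__(self):
         # first view: always blurred, never solarized
         self.transform = transforms.Compose([
             transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.RandomApply(
                 [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                         saturation=0.2, hue=0.1)],
                 p=0.8
             ),
             transforms.RandomGrayscale(p=0.2),
             GaussianBlur(p=1.0),
             Solarization(p=0.0),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
         ])
         # second view: blurred with p=0.1, solarized with p=0.2
         self.transform_prime = transforms.Compose([
             transforms.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.RandomApply(
                 [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                         saturation=0.2, hue=0.1)],
                 p=0.8
             ),
             transforms.RandomGrayscale(p=0.2),
             GaussianBlur(p=0.1),
             Solarization(p=0.2),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
         ])
     def __call__(self, x):
         # return two independently distorted views of the same PIL image x
         y1 = self.transform(x)
         y2 = self.transform_prime(x)
         return y1, y2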
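The Solarization step may be unfamiliar. It is commonly implemented with PIL's `ImageOps.solarize`, which inverts every 8-bit pixel value above a threshold (128 by default) and leaves darker values untouched. The per-value rule and the resulting lookup table can be sketched in plain Python; `solarize_value` and `solarize_lut` are hypothetical names, and the `< threshold` boundary mirrors PIL's source as we recall it.

```python
def solarize_value(p, threshold=128):
    """Solarize one 8-bit channel value: values below threshold pass through, the rest are inverted."""
    return p if p < threshold else 255 - p


def solarize_lut(threshold=128):
    """256-entry lookup table applying the same rule to every possible 8-bit value."""
    return [solarize_value(i, threshold) for i in range(256)]


lut = solarize_lut()
print(lut[10], lut[200], lut[255])  # 10 55 0
```

Bright regions turn dark while dark regions are unchanged, a strong photometric distortion that the second view sees only occasionally.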

Usage Code

A PyTorch implementation of Barlow Twins is available in Facebook Research's GitHub repository (facebookresearch/barlowtwins). Training the model takes a long time (7 days on 16 V100 GPUs). Fortunately, weights pretrained on ImageNet are available. Let's see how to use this model for image classification.

Loading Pretrained Model

 import torch
 model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50') 

Preprocessing Input Images

 from torchvision import datasets, transforms
 import torch
 val_dataset = datasets.ImageFolder(
     '/content/data/cats_and_dogs_filtered/validation',
     transforms.Compose([
         transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
     ]))
 val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)
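The preprocessing above resizes the shorter side to 256 pixels (keeping the aspect ratio), takes a centered 224 x 224 crop, scales 8-bit values to [0, 1] and normalizes each channel with the ImageNet statistics. This plain-Python sketch reproduces only that arithmetic, as we understand torchvision's behavior for an integer `Resize` size; it never touches pixels, and the function names are hypothetical.

```python
def resize_shorter_side(width, height, size=256):
    """(width, height) after Resize(size) with an int size: the shorter side becomes `size`."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means, as in Normalize above
STD = (0.229, 0.224, 0.225)   # ImageNet channel standard deviations


def normalize_pixel(rgb):
    """8-bit RGB -> ToTensor scaling to [0, 1] -> per-channel (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


w, h = resize_shorter_side(1024, 512)
print((w, h), center_crop_box(w, h))  # (512, 256) (144, 16, 368, 240)
```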

Inference

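 import numpy as np
 import torch
 from tqdm import tqdm
 # label_map (not defined in the original snippet) maps an output index to a class name.
 # Note: the self-supervised backbone was trained without labels, so its raw outputs
 # are not cat/dog scores; a classifier trained on top of its features is needed.
 preds = []
 model.eval()
 with torch.no_grad():
     for images, labels in tqdm(val_loader):
         preds.extend(model(images).numpy())
 predicted_labels = [label_map[i] for i in np.argmax(np.array(preds), axis=1)]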
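Because the Barlow Twins backbone is trained without labels, a common way to use it for classification is a linear probe: freeze the backbone, extract its features, and fit a single linear layer on labelled training images. The sketch below is our own illustration, not code from the article or the official repository. `extract_features`, `train_linear_probe`, `predict` and `train_loader` are hypothetical names, and a `train/` folder next to `validation/` in cats_and_dogs_filtered is assumed. Full-batch gradient descent keeps the sketch short; a real probe would typically use mini-batches and a tuned learning rate.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(backbone, loader):
    """Run a frozen backbone over a DataLoader; return stacked features and labels."""
    backbone.eval()
    feats, labels = [], []
    for images, targets in loader:
        feats.append(backbone(images).flatten(start_dim=1))
        labels.append(targets)
    return torch.cat(feats), torch.cat(labels)


def train_linear_probe(features, labels, num_classes, epochs=100, lr=0.1):
    """Fit one nn.Linear layer on fixed features with full-batch gradient descent."""
    probe = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.SGD(probe.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(probe(features), labels)
        loss.backward()
        optimizer.step()
    return probe


@torch.no_grad()
def predict(probe, features):
    """Index of the highest-scoring class for each feature vector."""
    return probe(features).argmax(dim=1)


# Hypothetical usage (not run here, since it downloads the pretrained weights):
# model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# model.fc = nn.Identity()  # use the 2048-d pooled features, not the fc outputs
# train_feats, train_labels = extract_features(model, train_loader)  # hypothetical loader
# probe = train_linear_probe(train_feats, train_labels, num_classes=2)
# val_feats, val_labels = extract_features(model, val_loader)
# accuracy = (predict(probe, val_feats) == val_labels).float().mean().item()
```

Keeping the backbone frozen means features can be computed once and reused, so fitting the probe is cheap compared with the days of self-supervised pretraining.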

Pavan Kandru
AI enthusiast with a flair for NLP. I love playing with exotic data.
