A Guide to Barlow Twins: Self-Supervised Learning via Redundancy Reduction

Barlow Twins is a novel architecture inspired by the redundancy-reduction principle in the work of neuroscientist H. Barlow.


Redundancy is a recurring problem in large neural network architectures. Redundancy in the rank of the parameters makes models extremely large while contributing little to their performance. Another kind of redundancy, more prevalent in self-supervised learning models, is redundancy in the representations: some elements of the final representations collapse to trivial constants. This kind of redundancy is often detrimental to both the performance and the efficiency of models. Moreover, biological processes such as sensory processing suggest that it is efficient to recode highly redundant inputs into factorial codes (codes with statistically independent components).

Barlow Twins is conceptually simple and easy to implement, and it learns high-quality representations from distorted (noisy) image data. The model was proposed by Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny on 4 March 2021.



Architecture of Barlow Twins

Barlow Twins, as the name suggests, contains two identical networks. Each backbone network consists of a ResNet-50 trunk followed by a three-layer linear projector. The two networks are fed different distorted versions of the same image. Their output representations are then compared, and a loss function is used to ensure the representations are neither trivial nor constant.

Loss function in Barlow Twins

The loss function is computed on the cross-correlation matrix between the components of the two representations.

If zA and zB are the batch-normalized outputs of the twin networks, then the cross-correlation matrix is given by

C_ij = ( Σ_b zA_{b,i} · zB_{b,j} ) / ( sqrt(Σ_b (zA_{b,i})²) · sqrt(Σ_b (zB_{b,j})²) )

where b indexes the samples in a batch and i, j index the components of the representations. The loss is then calculated as

L_BT = Σ_i (1 − C_ii)² + λ Σ_i Σ_{j≠i} C_ij²

The first term (the invariance term) pushes the diagonal elements of C towards 1, making the networks robust to the applied distortions, whereas the second term (the redundancy-reduction term) pushes the off-diagonal elements towards 0, making the representation components statistically independent. Using this loss function, we are training the model to output representations whose cross-correlation matrix is the identity matrix. This loss function continues to perform well as the number of dimensions in the representations increases.
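To see how the two terms behave, the loss can be computed by hand on toy cross-correlation matrices. This is a hypothetical NumPy sketch (not the authors' code), with λ set to an illustrative 0.005: an identity matrix yields zero loss, while a correlated, weakly invariant matrix is penalized.

```python
import numpy as np

def barlow_twins_loss(c, lambd=0.005):
    """Barlow Twins loss on a cross-correlation matrix c (illustrative lambda)."""
    on_diag = ((1.0 - np.diagonal(c)) ** 2).sum()            # invariance term
    off_diag = (c ** 2).sum() - (np.diagonal(c) ** 2).sum()  # redundancy term
    return on_diag + lambd * off_diag

c_perfect = np.eye(4)            # fully invariant, fully decorrelated
c_bad = np.full((4, 4), 0.5)     # correlated components, weak invariance

print(barlow_twins_loss(c_perfect))  # 0.0
print(barlow_twins_loss(c_bad))      # 1.015
```

The ideal representations make C the identity, which zeroes both terms; any deviation on or off the diagonal raises the loss.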


The following code snippets are excerpts from this implementation.

 import torch
 import torch.nn as nn
 import torchvision

 class BarlowTwins(nn.Module):
     def __init__(self, args):
         super().__init__()
         self.args = args
         self.backbone = torchvision.models.resnet50(zero_init_residual=True)
         self.backbone.fc = nn.Identity()  # drop the classification head
         # projector, e.g. args.projector = '8192-8192-8192'
         sizes = [2048] + list(map(int, args.projector.split('-')))
         layers = []
         for i in range(len(sizes) - 2):
             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
             layers.append(nn.BatchNorm1d(sizes[i + 1]))
             layers.append(nn.ReLU(inplace=True))
         layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
         self.projector = nn.Sequential(*layers)
         # normalization layer for the representations z1 and z2
         self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

     def forward(self, y1, y2):
         z1 = self.projector(self.backbone(y1))
         z2 = self.projector(self.backbone(y2))
         # empirical cross-correlation matrix
         c = self.bn(z1).T @ self.bn(z2)
         c.div_(self.args.batch_size)
         # use --scale-loss to multiply the loss by a constant factor
         on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.args.scale_loss)
         off_diag = off_diagonal(c).pow_(2).sum().mul(self.args.scale_loss)
         loss = on_diag + self.args.lambd * off_diag
         return loss
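The forward pass above calls an `off_diagonal` helper that is defined elsewhere in the repository; it returns a flattened view of all off-diagonal elements of a square matrix:

```python
import torch

def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

# e.g. for a 3x3 matrix with entries 0..8, the off-diagonal entries
# come out as 1, 2, 3, 5, 6, 7
x = torch.arange(9.).view(3, 3)
print(off_diagonal(x))
```

The reshape trick works because in an n×n matrix flattened row-major, diagonal entries are exactly n + 1 positions apart.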

The two distorted views of each image are obtained using the following transforms.

 import torchvision.transforms as transforms
 from PIL import Image

 class Transform:
     def __init__(self):
         self.transform = transforms.Compose([
             transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.RandomApply(
                 [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                         saturation=0.2, hue=0.1)], p=0.8),
             transforms.RandomGrayscale(p=0.2),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])])
         # second view: same pipeline (the paper additionally varies
         # blur and solarization between the two views)
         self.transform_prime = transforms.Compose([
             transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.RandomApply(
                 [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                         saturation=0.2, hue=0.1)], p=0.8),
             transforms.RandomGrayscale(p=0.2),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])])

     def __call__(self, x):
         y1 = self.transform(x)
         y2 = self.transform_prime(x)
         return y1, y2

Usage Code

A PyTorch implementation of Barlow Twins is made available by Facebook Research in their GitHub repository. The model takes a long time to train (around 7 days on 16 V100 GPUs). Fortunately, pretrained weights of the model on ImageNet are available. Let's see how to use this model for image classification.

Loading Pretrained Model

 import torch
 model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50') 

Preprocessing Input Images

 from torchvision import datasets, transforms
 import torch

 val_dataset = datasets.ImageFolder(
     '/content/data/cats_and_dogs_filtered/validation',
     transforms.Compose([
         transforms.Resize(256),       # standard ImageNet eval preprocessing
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
     ]))
 val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)


 preds = []
 model.eval()
 with torch.no_grad():
     for images, labels in tqdm(val_loader):
         preds.extend(model(images).cpu().numpy())
 # label_map (defined elsewhere) maps class indices to class names
 predicted_labels = list(map(lambda x: label_map[x], np.argmax(np.array(preds), axis=1)))


Pavan Kandru
AI enthusiast with a flair for NLP. I love playing with exotic data.
