The recent success of self-supervised models can be attributed to researchers’ renewed interest in contrastive learning, a paradigm of self-supervised learning. Consider how humans can identify objects in the wild even if we do not recall exactly what an object looks like.
We do this by remembering high-level features and ignoring details at the microscopic level. The question, then, is: can we build representation learning algorithms that do not concentrate on pixel-level details and only encode high-level features sufficient to distinguish different objects? Contrastive learning is one way researchers are trying to address this.
Recently, Google’s SimCLR demonstrated the effectiveness of contrastive learning, which we will briefly go into later in this article.
Principle Of Contrastive Learning
Contrastive learning is an approach that frames the learning task as finding similar and dissimilar things. Using this approach, a machine learning model can be trained to distinguish between similar and dissimilar images.
The inner workings of contrastive learning can be formulated around a score function, a metric that measures the similarity between two features.
Here, for an anchor data point x:
x+ is a data point similar to x, referred to as a positive sample
x− is a data point dissimilar to x, referred to as a negative sample
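As an illustrative formulation (the notation below is a common convention rather than a quote from the SimCLR paper), the goal is to learn an encoder f whose score for the positive pair is much higher than its score for any negative pair:

```latex
% The encoder f should place x close to its positive sample
% and far from any negative sample:
\mathrm{score}\big(f(x), f(x^{+})\big) \;\gg\; \mathrm{score}\big(f(x), f(x^{-})\big)
```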
On top of this score, a softmax classifier can be built that classifies positive and negative samples correctly.
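Concretely, with one positive sample and N − 1 negative samples, such a softmax classifier is usually written as a cross-entropy loss over the scores; the following is a sketch of that common formulation, again using our own notation:

```latex
% Cross-entropy loss that pushes the positive score above all negative scores
\mathcal{L} \;=\; -\log
\frac{\exp\!\big(\mathrm{score}(f(x), f(x^{+}))\big)}
     {\exp\!\big(\mathrm{score}(f(x), f(x^{+}))\big)
      \;+\; \sum_{j=1}^{N-1} \exp\!\big(\mathrm{score}(f(x), f(x^{-}_{j}))\big)}
```

A similar application of this technique can be found in the recently introduced SimCLR framework.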
Applying Contrastive Learning
Google has introduced a framework called “SimCLR” that uses contrastive learning. The framework first learns generic representations of images on an unlabelled dataset and is then fine-tuned with a small dataset of labelled images for a given classification task.
The generic representations are learned by simultaneously maximising agreement between differently augmented views of the same image and minimising agreement between views of different images, using contrastive learning.
Updating the parameters of a neural network with this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other.
A more detailed explanation of the original paper was given in this blog.
The procedure is as follows:
- First, generate batches of a certain size, say N, from the raw images
- For each image in the batch, a random transformation function is applied to get a pair of augmented images
- Each augmented image in a pair is passed through an encoder to get image representations
- The representations of the two augmented images are then passed through a projection head: a non-linear dense layer followed by a ReLU, followed by another dense layer, which applies a non-linear transformation and projects the representations into the space where the contrastive loss is applied
- For each augmented image in the batch, this yields an embedding vector (a minimal sketch of the pipeline follows this list)
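To make the pipeline concrete, here is a minimal sketch in PyTorch; the specific augmentations, the ResNet-50 encoder, and the 128-dimensional projection are choices made for illustration and may differ from a particular SimCLR implementation.

```python
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet50

# Illustrative augmentation: applying it twice to one image yields a positive pair
augment = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.4, 0.4, 0.4, 0.1),
    T.RandomGrayscale(p=0.2),
    T.ToTensor(),
])

class SimCLRModel(nn.Module):
    def __init__(self, proj_dim=128):
        super().__init__()
        self.encoder = resnet50(weights=None)       # base encoder f(.)
        feat_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()             # drop the supervised classification head
        # Projection head g(.): dense layer -> ReLU -> dense layer
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, proj_dim),
        )

    def forward(self, x):
        h = self.encoder(x)      # representation kept for downstream tasks
        z = self.projector(h)    # embedding fed to the contrastive loss
        return z
```

During pretraining, every image in the batch is augmented twice and both views are passed through this model; the resulting embeddings are what the loss below compares.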
Now, the similarity between the two augmented versions of an image is calculated using cosine similarity. On these similarities, SimCLR uses the “NT-Xent loss” (Normalised Temperature-Scaled Cross-Entropy Loss), a form of contrastive loss.
First, the augmented pairs in the batch are taken one by one. Then, a softmax function is applied over the similarities to obtain the probability that the two images in a pair are similar.
For example, the softmax can measure how similar two augmented views of a cat image are, while all the remaining images in the batch are treated as dissimilar images (negative pairs).
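A minimal sketch of this loss in PyTorch is shown below; the tensor layout (the two views stacked into one batch) and the default temperature value are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    # z1, z2: [N, d] projections of the two augmented views of the same N images
    n = z1.size(0)
    z = torch.cat([z1, z2], dim=0)           # [2N, d] embeddings for the whole batch
    z = F.normalize(z, dim=1)                # unit norm, so a dot product is cosine similarity
    sim = z @ z.t() / temperature            # [2N, 2N] temperature-scaled similarities
    # An image should never be its own negative, so mask out the diagonal
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, -1e9)
    # For row i, the positive is the other augmented view of the same image
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Softmax over each row: the positive pair competes against all other images in the batch
    return F.cross_entropy(sim, targets)
```

Cross-entropy over the similarity rows is exactly the softmax described above: the numerator corresponds to the positive pair, and every other image in the batch contributes to the denominator as a negative.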
Driven by this loss, the encoder and projection head improve over time, and the representations obtained place similar images closer together in the embedding space.
The results from SimCLR showed that it outperformed previous self-supervised methods on ImageNet.