Since the inception of AlexNet, convolutional neural networks have been the dominant design paradigm for computer vision tasks. In recent years, attention-based neural networks have driven a slew of breakthroughs in natural language processing. Inspired by this success, several architectures incorporate the attention mechanism within convnets, such as ResNeSt and SE-Net. However, none of these architectures has been able to outperform pure CNNs.
Only vision transformers (ViT) have been able to achieve state-of-the-art performance on ImageNet without using convolutions, and even then only when trained on a large private labelled image dataset with extensive computing resources. In their paper, “Training data-efficient image transformers & distillation through attention”, Hugo Touvron, Matthieu Cord, et al. propose a convolution-free transformer network, DeiT, that achieves a top-1 accuracy of 83.1% on ImageNet with no external data. DeiT introduces a new teacher-student strategy specific to transformers that relies on a distillation token, similar to the class token already employed in transformer networks.
Architecture & Approach
DeiT builds upon the ViT transformer block. It uses a simple architecture that processes input images as a sequence of input tokens. The RGB image is decomposed into a sequence of N fixed-size patches of 16 x 16 pixels. These patches are then projected with a linear layer that conserves the overall dimension (3 x 16 x 16 = 768). The transformer block is invariant to the order of the patch embeddings and does not account for their relative positions, so positional embeddings are added to the patch tokens, along with the learnable class token, before the first transformer block. DeiT uses a linear classifier for pre-training instead of the MLP head used in ViT.
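To make the tokenization concrete, here is a minimal sketch of the patch-embedding step in PyTorch, assuming a single 224 x 224 RGB input; the variable names are illustrative and not those of the actual DeiT code:

import torch
import torch.nn as nn

patch_size, embed_dim = 16, 768          # 3 * 16 * 16 = 768
img = torch.randn(1, 3, 224, 224)        # one RGB image

# cut the image into N = (224 / 16)^2 = 196 non-overlapping 16 x 16 patches
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * patch_size * patch_size)

# dimension-preserving linear projection of each flattened patch
proj = nn.Linear(3 * patch_size * patch_size, embed_dim)
tokens = proj(patches)                   # shape (1, 196, 768)

# prepend the learnable class token and add positional embeddings
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, tokens.shape[1] + 1, embed_dim))
tokens = torch.cat([cls_token.expand(1, -1, -1), tokens], dim=1) + pos_embed
print(tokens.shape)                      # torch.Size([1, 197, 768])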
Distillation through attention
Soft distillation minimizes the Kullback-Leibler divergence between the teacher model’s softmax and the student model’s softmax. Let Zt denote the logits of the teacher model and Zs the logits of the student model. The distillation objective is given by:
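Following the notation of the DeiT paper, this objective can be written as:

\mathcal{L}_{\mathrm{global}} = (1-\lambda)\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y\big) + \lambda\,\tau^{2}\,\mathrm{KL}\big(\psi(Z_s/\tau),\, \psi(Z_t/\tau)\big)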
Here τ denotes the distillation temperature, ψ the softmax function, and λ the coefficient that balances the contribution of the Kullback–Leibler divergence loss (KL) against the cross-entropy loss (L_CE) on the ground-truth labels y.
DeiT introduces a variant of distillation where the teacher’s hard decision is taken as the true label, i.e., the teacher prediction yt is used instead of the true label y. Let yt = argmaxc Zt(c) denote the hard decision of the teacher; the objective associated with this hard-label distillation is given by:
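In the same notation, hard-label distillation replaces the KL term with a cross-entropy term on the teacher’s hard decision, with both terms weighted equally:

\mathcal{L}_{\mathrm{global}}^{\mathrm{hardDistill}} = \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y\big) + \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y_t\big)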
Hard labels can easily be converted into soft labels with label smoothing: the true label is assigned a probability of 1 − ε, and the remaining ε is shared uniformly across the other classes.
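Concretely, for a problem with K classes the smoothed target assigns:

y_{\mathrm{smooth}}(c) = 1-\varepsilon \ \text{if } c = y, \qquad \varepsilon/(K-1) \ \text{otherwise}

The DeiT paper fixes this parameter to ε = 0.1.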
DeiT introduces a new token, the distillation token, which is used similarly to the class token. It is added to the initial embeddings (the patch and class tokens) before the first transformer block, interacts with the other embeddings through self-attention, and is output by the network after the last block. The distillation token’s objective is to reproduce the hard label yt predicted by the teacher network; it lets the model learn from the teacher’s output while remaining complementary to the class token, which is tasked with reproducing the true label y. Much like the class embedding, the distillation embedding is learned by the transformer network through back-propagation.
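A schematic sketch of how the class and distillation tokens join the patch embeddings, with illustrative variable names rather than the actual DeiT implementation:

import torch
import torch.nn as nn

batch, num_patches, embed_dim = 8, 196, 768
patch_tokens = torch.randn(batch, num_patches, embed_dim)   # projected patch embeddings

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))      # trained to reproduce the true label y
dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))     # trained to reproduce the teacher's hard label yt
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))

# both learnable tokens join the patch sequence before the first transformer block
x = torch.cat([cls_token.expand(batch, -1, -1),
               dist_token.expand(batch, -1, -1),
               patch_tokens], dim=1) + pos_embed
print(x.shape)   # torch.Size([8, 198, 768]); separate heads read the two tokens at the output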
In terms of the trade-off between accuracy and throughput, the vision transformer produced by this distillation process is on par with the best convolutional networks. Convnet teachers produce better students than transformer teachers, likely because the student inherits the convnet’s inductive bias through distillation. Using RegNetY-16GF as the teacher yields the best DeiT model, which achieves a top-1 accuracy of 83.1%. Interestingly, the distilled model outperforms its teacher, RegNetY-16GF, which reaches 82.9% top-1 accuracy on ImageNet.
Image Classification with a pre-trained DeiT model
The following code implementation is adapted from this Colab Notebook provided by the DeiT developers.
- Install PyTorch Image Models (timm)
!pip install timm==0.3.2
- Download ImageNet class labels and create a list.
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

# Read the ImageNet categories
with open("imagenet_classes.txt", "r") as f:
    imagenet_categories = [s.strip() for s in f.readlines()]
- Import necessary libraries and classes
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torch
import timm
import torchvision
import torchvision.transforms as T
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

torch.set_grad_enabled(False);
- Create the data transform expected by DeiT
transform = T.Compose([
    T.Resize(256, interpolation=3),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
- Load the pre-trained model from TorchHub and get an image to perform inference on.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval();

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

# display the image
im
- Transform the image and perform inference.
# transform the original image and add a batch dimension
img = transform(im).unsqueeze(0)

# compute the predictions
out = model(img)

# and convert them into probabilities
scores = torch.nn.functional.softmax(out, dim=-1)[0]

# get the indices of the five predictions with the highest scores
topk_scores, topk_label = torch.topk(scores, k=5, dim=-1)

for i in range(5):
    pred_name = imagenet_categories[topk_label[i]]
    print(f"Prediction index {i}: {pred_name:<25}, score: {topk_scores[i].item():.3f}")
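For completeness, the DeiT Torch Hub entry also lists distilled variants. Assuming the same transform and image as above, one such variant can be swapped in as follows; the model name is taken from the repository’s hub listing and the rest of the pipeline is unchanged:

# load the distilled base model and reuse the same preprocessing
model_distilled = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_224', pretrained=True)
model_distilled.eval()

out = model_distilled(img)
scores = torch.nn.functional.softmax(out, dim=-1)[0]   # the same softmax / top-k steps apply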
Last Epoch
This article introduced DeiT, an image transformer that does not require a very large amount of training data, thanks to improved training methods and a novel distillation procedure. Convolutional neural networks have benefited from a decade’s worth of architectural and training optimizations; DeiT, by contrast, is one of the first such optimizations introduced for vision transformers, and image transformers are already on par with convnets. With their smaller memory footprint for a given accuracy, transformers are poised to become the method of choice.
References
For a more in-depth understanding of the improved training method and the new distillation token, please refer to the following resources:
Want to visualize the decision process of DeiT? Check out this article.