Hands-on Vision Transformers with PyTorch

ViT splits an input image into a sequence of 16×16 patches, much like the sequence of word embeddings fed to an NLP Transformer. Each patch is flattened into a single vector by concatenating the channels of all its pixels, and is then linearly projected to the desired input dimension.

With the rise in popularity of image and video analytics, better Convolutional Neural Networks (CNNs) have been researched and deployed in industry, outperforming many earlier computer vision algorithms. We may be moving towards an era of quantum computing that could ease the need for large computing power. For now, vision models are typically trained on high-end GPUs/TPUs with massive datasets.

The Transformer architecture (Vaswani et al.) was introduced in 2017 and uses self-attention to speed up training. It was created primarily to address core challenges in natural language processing (NLP). Recurrent Neural Networks (RNNs), which were widely used for NLP tasks, parallelize poorly; Transformer architectures outperform them while requiring less training time.
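To make the parallelization point concrete, here is a minimal sketch of single-head scaled dot-product self-attention in PyTorch. The function name `self_attention`, the random weight matrices, and the toy sizes are assumptions made for illustration; a real Transformer layer adds multiple heads, learned projections, and residual connections.

```python
import math
import torch

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a whole sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Every position attends to every other position in one matrix product,
    # so there is no step-by-step recurrence as in an RNN.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = scores.softmax(dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
seq_len, dim = 5, 8                      # toy sizes, chosen for the example
x = torch.randn(1, seq_len, dim)         # (batch, sequence, features)
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

Each row of `weights` sums to 1 and describes how strongly one position attends to all the others.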


Visual Transformers (VTs) are a recent line of research that aims to outperform CNN models on several vision tasks. Convolutions treat all pixels equally regardless of their importance, which makes it harder to focus on the essential features of an image. In “Visual Transformers: Token-based Image Representation and Processing for Computer Vision” (Bichen Wu, Chenfeng Xu, Xiaoliang Dai, Alvin Wan, Peizhao Zhang, Zhicheng Yan, Masayoshi Tomizuka, Joseph Gonzalez, Kurt Keutzer, Peter Vajda, 2020; [http://arxiv.org/abs/2006.03677 arXiv:2006.03677]), the authors show that, in place of convolutions, VTs apply transformers to relate semantic concepts directly in token space.


You can check the PyTorch implementation here

In this article, we will learn about and implement a recent paper by Google Research teams on Vision Transformers (ViT): Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby: “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, 2020; [http://arxiv.org/abs/2010.11929 arXiv:2010.11929].

Source: https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html

How it works

ViT splits an input image into a sequence of 16×16 patches, much like the sequence of word embeddings fed to an NLP Transformer. Each patch is flattened into a single vector by concatenating the channels of all its pixels, and is then linearly projected to the desired input dimension. Because Transformers rely on self-attention, they do not depend on the structure of the input elements, which helps the architecture learn and relate sparsely distributed information more efficiently. ViT is not told how the patches are arranged in the image; it learns those relationships from the training data and encodes them in its positional embeddings.
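The patch step described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the vit-pytorch implementation: the `PatchEmbedding` class and its default `dim=128` are hypothetical names and values chosen for the example.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each one."""

    def __init__(self, patch_size=16, channels=3, dim=128):
        super().__init__()
        self.patch_size = patch_size
        # Each flattened patch holds patch_size * patch_size * channels values
        self.proj = nn.Linear(patch_size * patch_size * channels, dim)

    def forward(self, images):
        b, c, h, w = images.shape
        p = self.patch_size
        # (b, c, h, w) -> (b, c, h/p, p, w/p, p)
        x = images.reshape(b, c, h // p, p, w // p, p)
        # Bring the patch-grid axes forward, then the pixels of each patch
        x = x.permute(0, 2, 4, 3, 5, 1)            # (b, h/p, w/p, p, p, c)
        x = x.reshape(b, (h // p) * (w // p), p * p * c)
        return self.proj(x)                         # (b, num_patches, dim)

images = torch.randn(2, 3, 224, 224)   # a dummy batch of two RGB images
tokens = PatchEmbedding()(images)
print(tokens.shape)   # torch.Size([2, 196, 128])
```

With a 224×224 input and 16×16 patches this yields 14 × 14 = 196 tokens per image. A full ViT would then prepend a class token and add positional embeddings before the Transformer encoder.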



We can download the dataset from the above link.


We will be implementing the Vision Transformers with PyTorch. 

Install the ViT PyTorch package and Linformer 

pip install vit-pytorch linformer

# loading Libraries

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt   

# import Linformer

from linformer import Linformer   
import glob   
from PIL import Image
from itertools import chain   
from vit_pytorch.efficient import ViT   
from tqdm.notebook import tqdm

# import torch and related libraries 

import torch   
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms   
from torch.optim.lr_scheduler import StepLR   
from torch.utils.data import DataLoader, Dataset

#to unzip the datasets
import zipfile   

#sklearn to split the data

from sklearn.model_selection import train_test_split  

#defining batch size, epochs, learning rate and gamma for training

batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7 #for learning rate scheduler 
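To get a feel for what `gamma` does, the sketch below steps a `StepLR` scheduler with the values above and records the learning rate at each epoch. It assumes the scheduler is stepped once per epoch, and the single dummy parameter is only a stand-in for the model weights.

```python
import torch
from torch.optim.lr_scheduler import StepLR

lr, gamma = 3e-5, 0.7

# A single dummy parameter stands in for the model weights (illustration only)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

schedule = []
for epoch in range(5):
    schedule.append(optimizer.param_groups[0]["lr"])
    optimizer.step()      # a real loop would process the batches here
    scheduler.step()      # decay the learning rate once per epoch

for epoch, value in enumerate(schedule, start=1):
    print(f"epoch {epoch}: lr = {value:.2e}")
```

Each epoch multiplies the learning rate by 0.7, so the rate shrinks quickly. The scheduler only takes effect if `scheduler.step()` is actually called in the training loop.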

#Load data

os.makedirs('data', exist_ok=True)
train_dir = 'data/train'
test_dir = 'data/test'

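#Unzipping dataset (assumes the archives contain top-level train/ and test/ folders)
with zipfile.ZipFile('train.zip') as train_zip:
    train_zip.extractall('data')
with zipfile.ZipFile('test.zip') as test_zip:
    test_zip.extractall('data')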

#Creating train and test list 

train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

#printing length of the dataset

print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

#Defining labels

labels = [path.split('/')[-1].split('.')[0] for path in train_list]
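The label is read from the file name, and the split on '/' above assumes POSIX-style paths. A slightly more portable sketch uses `os.path.basename`; the `label_from_path` helper and the example paths (in the form `cat.0.jpg`) are hypothetical and only illustrate the naming scheme the code relies on.

```python
import os

def label_from_path(path):
    """Return the class name encoded in a file name such as 'dog.42.jpg'."""
    return os.path.basename(path).split(".")[0]

# Hypothetical paths following the '<class>.<index>.jpg' naming scheme
example_paths = ["data/train/cat.0.jpg", "data/train/dog.42.jpg"]
example_labels = [label_from_path(p) for p in example_paths]

# Same encoding the dataset class uses later: dog -> 1, anything else -> 0
example_targets = [1 if name == "dog" else 0 for name in example_labels]
print(example_labels, example_targets)
```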

# printing few images 

random_idx = np.random.randint(0, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))
for idx, ax in zip(random_idx, axes.ravel()):
    img = Image.open(train_list[idx])
    ax.set_title(labels[idx])
    ax.imshow(img)
plt.show()

#Splitting train and validation list

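# test_size and random_state are assumed values; the original arguments were cut off
train_list, valid_list = train_test_split(
    train_list, test_size=0.2, stratify=labels, random_state=42
)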
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

# Torch transforms

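# Only the Resize step survives in the source; ToTensor is an assumed addition
# so that the datasets yield tensors. Any augmentation steps are not shown.
train_transforms = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()])
val_transforms = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()])
test_transforms = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()])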

#Loading dataset for training 

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength
    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0
        return img_transformed, label

#defining train, validation and test dataset

train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

#loading dataloader

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

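#Linformer (efficient self-attention)

# dim, depth, heads and k are assumed values (only seq_len survives in the source)
efficient_transformer = Linformer(
    dim=128, depth=12, heads=8, k=64,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
)

#Visual transformer

# Assumed arguments: patch_size=32 gives the 7x7 grid above; two classes (cat, dog)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViT(
    dim=128, image_size=224, patch_size=32, num_classes=2,
    transformer=efficient_transformer, channels=3,
).to(device)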

# loss function

criterion = nn.CrossEntropyLoss()

# optimizer

optimizer = optim.Adam(model.parameters(), lr=lr)

# scheduler

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

#start training

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
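        loss = criterion(output, label)
        # Backpropagate the loss and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()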
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)
            val_output = model(data)
            val_loss = criterion(val_output, label)
            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
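Once training finishes, the model can be used to score new images. The sketch below is one possible way to do that, not part of the original code: `predict_dog_probability` is a hypothetical helper, and a small stand-in network replaces the trained ViT so the example runs on its own. It relies on the label encoding used in the dataset class, where index 1 means "dog".

```python
import torch
import torch.nn as nn

@torch.no_grad()
def predict_dog_probability(model, images):
    """Return the predicted probability of class 1 ("dog") for each image."""
    model.eval()                          # disable dropout and similar layers
    logits = model(images)                # shape: (batch, 2)
    return logits.softmax(dim=1)[:, 1]

# Stand-in for the trained model, used only to keep the sketch self-contained
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 2))
batch = torch.randn(4, 3, 224, 224)       # dummy batch of four images
probs = predict_dog_probability(stand_in, batch)
print(probs.shape)
```

In practice the batch would come from `test_loader` and be moved to the same device as the model before the call.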

Please check the full code with the explanation here


In this article, we have discussed Visual Transformers, which are likely to become mainstream for object detection and image classification tasks in computer vision. You can also check 15 other paper implementations here with TensorFlow, PyTorch and Jax/Flax.

Krishna Rastogi
Krishna is currently working as an Associate Director at ADaSci. He has 6+ years of experience in research & development and cutting-edge engineering, developing products from idea to deployment. He has expertise in building deep learning computer vision applications using both hardware and software solutions in several domains. His interests include distributed learning and Edge AI.
