Hands-on Vision Transformers with PyTorch

Krishna Rastogi

As image and video analytics have grown in popularity, improved Convolutional Neural Networks (CNNs) have been researched and deployed in industry, where they outperform many earlier computer vision algorithms. Quantum computing may eventually ease the need for large amounts of computing power, but for now most vision models are trained on high-end GPUs or TPUs with massive datasets.

The Transformer architecture (Vaswani et al.) was introduced in 2017. It uses self-attention, which processes all positions of a sequence in parallel and so speeds up training. It was created primarily to address core challenges in natural language processing (NLP). Recurrent Neural Networks (RNNs), which were widely used for NLP tasks, process tokens one after another and parallelize poorly; Transformer architectures outperform RNN architectures while needing less training time.
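To make the parallelism concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It is an illustration rather than the full multi-head layer from the paper; the function name `self_attention` and the random weight matrices are assumptions made for this example. Every position attends to every other position through a few matrix products, so no step has to wait for the previous token as it would in an RNN.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (seq_len, dim) input."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])                # (seq_len, seq_len) similarities
    scores = scores - scores.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ v, weights

rng = np.random.default_rng(0)
seq_len, dim = 5, 8  # illustrative sizes
x = rng.normal(size=(seq_len, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```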


Visual Transformers (VTs) are a recent line of research that aims to outperform CNN models on several vision tasks. CNN architectures treat all pixels equally regardless of their importance, which can make it harder to learn the essential features of an image. In “Visual Transformers: Token-based Image Representation and Processing for Computer Vision” (2020; [arXiv:2006.03677]), Bichen Wu, Chenfeng Xu, Xiaoliang Dai, Alvin Wan, Peizhao Zhang, Zhicheng Yan, Masayoshi Tomizuka, Joseph Gonzalez, Kurt Keutzer and Peter Vajda show that, in place of convolutions, VTs apply transformers to relate semantic concepts directly in token space.

You can check the PyTorch implementation here

In this article, we will study and implement a recent paper from Google Research on Vision Transformers (ViT): Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby: “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, 2020; [arXiv:2010.11929].


How it works

ViT breaks an input image into a sequence of 16×16 patches, much like the sequence of word embeddings fed to an NLP Transformer. Each patch is flattened into a single vector by concatenating the channels of all its pixels, and is then linearly projected to the desired input dimension. Because transformers rely on self-attention, they do not depend on the structure of the input elements, which in turn helps the architecture learn and relate sparsely distributed information more efficiently. ViT is not told in advance how the patches of an image relate spatially; it learns these relationships from the training data and encodes them in its positional embeddings.
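The patch step can be sketched in a few lines of NumPy. This is an illustrative sketch, not the code of the paper or of `vit-pytorch`: the helper name `image_to_patches`, the 128-dimensional target size and the random projection matrix (a stand-in for the learned linear layer) are assumptions chosen for the example.

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping square patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    rows, cols = h // patch_size, w // patch_size
    patches = image.reshape(rows, patch_size, cols, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (rows, cols, p, p, c)
    return patches.reshape(rows * cols, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))                       # random stand-in for an RGB image
patches = image_to_patches(image, patch_size=16)        # 14 x 14 = 196 patches of 16*16*3 = 768 values
projection = rng.normal(size=(patches.shape[1], 128))   # stand-in for the learned projection
tokens = patches @ projection                           # one 128-dimensional token per patch
print(patches.shape, tokens.shape)
```

A real ViT would additionally prepend a class token and add positional embeddings to `tokens` before the transformer layers.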


We can download the dataset from the above link.


We will be implementing the Vision Transformers with PyTorch. 

Install the ViT PyTorch package and Linformer 

pip install vit-pytorch linformer

# loading Libraries

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt   

# import Linformer

from linformer import Linformer   
import glob   
from PIL import Image
from itertools import chain   
from vit_pytorch.efficient import ViT   
from tqdm.notebook import tqdm

# import torch and related libraries 

import torch   
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms   
from torch.optim.lr_scheduler import StepLR   
from torch.utils.data import DataLoader, Dataset

#to unzip the datasets
import zipfile   

#sklearn to split the data

from sklearn.model_selection import train_test_split  

#defining batch size, epochs, learning rate and gamma for training

batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7 #for learning rate scheduler 

#Load data

os.makedirs('data', exist_ok=True)
train_dir = 'data/train'
test_dir = 'data/test'

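#Unzipping dataset
#(the archive paths are blank in the source, so fill in your own; extracting into 'data' is an assumption based on train_dir and test_dir)
with zipfile.ZipFile('') as train_zip:
    train_zip.extractall('data')
with zipfile.ZipFile('') as test_zip:
    test_zip.extractall('data')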

#Creating train and test list 

train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

#printing length of the dataset

print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

#Defining labels

labels = [path.split('/')[-1].split('.')[0] for path in train_list]
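The label extraction assumes file names of the form `cat.0.jpg` or `dog.12.jpg`, and it splits on `/`, which would not work with Windows-style paths. A slightly more portable sketch follows, with `label_from_path` as a hypothetical helper name and made-up example paths:

```python
import os

def label_from_path(path):
    """Return the class name encoded before the first dot of the file name."""
    return os.path.basename(path).split('.')[0]

# Made-up paths that follow the assumed naming scheme
example_paths = ['data/train/cat.0.jpg', 'data/train/dog.12.jpg']
example_labels = [label_from_path(p) for p in example_paths]
print(example_labels)
```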

# printing few images 

random_idx = np.random.randint(1, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))
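for idx, ax in enumerate(axes.ravel()):
    img = Image.open(train_list[idx])
    # show each image with its label as the title
    ax.set_title(labels[idx])
    ax.imshow(img)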

#Splitting train and validation list

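# the arguments after train_list are truncated in the source; with none given, scikit-learn holds out 25% of the data
train_list, valid_list = train_test_split(train_list)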
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

# Torch transforms

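train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # assumption: only Resize survives in the source; ToTensor is needed to batch images
])
val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
])
test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
])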

#Loading dataset for training 

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength
    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0
        return img_transformed, label

#defining train, validation and test dataset

train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

#loading dataloader

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

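#Linformer (efficient attention); dim, depth, heads and k are missing from the source, so the values here are assumptions

efficient_transformer = Linformer(
    dim=128, depth=12, heads=8, k=64,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
)

#Visual transformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViT(
    dim=128,  # assumed; must match the Linformer dim
    image_size=224, patch_size=32,  # 224 / 32 = 7 patches per side
    num_classes=2,  # cat and dog
    transformer=efficient_transformer,
).to(device)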
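The `seq_len=49+1` value follows from the image and patch sizes: a 224×224 image cut into 7×7 patches implies a patch side of 32 pixels, and one extra position is reserved for the class token. A small sketch of that arithmetic, where `vit_sequence_length` is a hypothetical helper and not part of `vit-pytorch`:

```python
def vit_sequence_length(image_size, patch_size, cls_token=True):
    """Number of tokens a ViT sees for a square image and square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side + (1 if cls_token else 0)

print(vit_sequence_length(224, 32))  # 7 * 7 + 1 = 50
print(vit_sequence_length(224, 16))  # 14 * 14 + 1 = 197
```

Smaller patches give a longer sequence, which is where a linear-complexity attention such as Linformer becomes more attractive.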

# loss function

criterion = nn.CrossEntropyLoss()

# optimizer

optimizer = optim.Adam(model.parameters(), lr=lr)

# scheduler

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
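With `step_size=1`, `StepLR` multiplies the learning rate by `gamma` each time `scheduler.step()` is called, which is usually done once per epoch. The sketch below reproduces that decay in plain Python, without PyTorch, using the hyperparameters defined earlier; `step_lr` is a hypothetical helper written for illustration.

```python
def step_lr(initial_lr, gamma, step_size, epoch):
    """Learning rate after `epoch` scheduler steps under step decay."""
    return initial_lr * gamma ** (epoch // step_size)

initial_lr, decay = 3e-5, 0.7  # same values as lr and gamma above
schedule = [step_lr(initial_lr, decay, step_size=1, epoch=e) for e in range(20)]
for e in (0, 1, 5, 19):
    print(f"epoch {e:2d}: lr = {schedule[e]:.3e}")
```

With these settings the rate shrinks by 30% per step, so it is worth checking that the later epochs still learn at a useful pace.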

#start training

device = next(model.parameters()).device  # run on whichever device the model lives on

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
        # backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    # evaluate on the validation set without tracking gradients
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)
            val_output = model(data)
            val_loss = criterion(val_output, label)
            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
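Note that adding `acc / len(train_loader)` for each batch yields the unweighted mean of the per-batch accuracies. When the last batch is smaller than the rest, this can differ from the accuracy computed over all samples. The plain-Python sketch below, with made-up batch counts, shows the difference.

```python
# (correct predictions, batch size) per batch; the numbers are made up
batches = [(60, 64), (56, 64), (1, 2)]

# What the training loop accumulates: the mean of per-batch accuracies
mean_of_batches = sum(correct / size for correct, size in batches) / len(batches)

# Accuracy over every sample, which weights each batch by its size
per_sample = sum(c for c, _ in batches) / sum(s for _, s in batches)

print(f"mean of batch accuracies: {mean_of_batches:.4f}")
print(f"per-sample accuracy:      {per_sample:.4f}")
```

For large datasets with a single short final batch the gap is usually small, but it is one reason reported epoch accuracies may not match an exact per-sample count.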

Please check the full code with the explanation here


In this article, we discussed Visual Transformers, which are poised to become mainstream for computer vision tasks such as object detection and image classification. You can also check 15 other paper implementations here with TensorFlow, PyTorch and Jax/Flax.

Copyright Analytics India Magazine Pvt Ltd
