Hands-on Vision Transformers with PyTorch


As image and video analytics have grown in popularity, better Convolutional Neural Networks (CNNs) have been researched and deployed in industry, outperforming many earlier computer vision algorithms. Quantum computing may eventually ease the demand for large amounts of computing power, but for now vision models are trained on high-end GPUs/TPUs with massive datasets.

The Transformer architecture (Vaswani et al.) was introduced in 2017. It relies on self-attention, which lets every position in a sequence be processed in parallel and so speeds up training. It was created primarily to address core challenges in natural language processing (NLP). Recurrent Neural Networks (RNNs), which were widely used for NLP tasks, process tokens one after another and parallelize poorly; Transformer architectures outperform them while needing less training time.
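To make the parallelism concrete, here is a minimal sketch of single-head scaled dot-product self-attention in PyTorch. The function name `self_attention`, the random weight matrices and the tensor sizes are illustrative assumptions, not code from the Vaswani et al. paper. Note that every position is handled by the same few matrix multiplications, with no step-by-step loop over the sequence.

```python
import math
import torch

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (seq_len, dim) input."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / math.sqrt(k.shape[-1])  # (seq_len, seq_len) pairwise similarities
    weights = torch.softmax(scores, dim=-1)    # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
seq_len, dim = 5, 8                            # hypothetical sizes
x = torch.randn(seq_len, dim)                  # stand-in for 5 token embeddings
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

Each row of `weights` says how much one token attends to every other token, which is the quantity a ViT later computes between image patches.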


Visual Transformers (VTs) are an active research area and are starting to outperform CNN models on several vision tasks. CNN architectures give equal weight to all pixels, which makes it harder for them to focus on the essential features of an image. In “Visual Transformers: Token-based Image Representation and Processing for Computer Vision” (Bichen Wu, Chenfeng Xu, Xiaoliang Dai, Alvin Wan, Peizhao Zhang, Zhicheng Yan, Masayoshi Tomizuka, Joseph Gonzalez, Kurt Keutzer, Peter Vajda, 2020; arXiv:2006.03677), the authors show that, in place of convolutions, VTs apply transformers to relate semantic concepts directly in token space.



You can check the PyTorch implementation here.

In this article, we will learn about and implement a recent paper from Google Research on Vision Transformers (ViT): Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby: “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, 2020; [arXiv:2010.11929].



How it works

ViT splits an input image into a sequence of fixed-size patches (for example 16×16 pixels), which play the same role as the word embeddings fed to an NLP Transformer. Each patch is flattened into a single vector by concatenating the channels of all its pixels and is then linearly projected to the desired input dimension. Because self-attention does not depend on the structure of its input elements, the architecture can learn and relate sparsely distributed information efficiently. ViT is not told how the patches are arranged in the image; it learns those relationships from the training data and encodes them in its positional embeddings.
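The patch pipeline described above can be sketched roughly as follows. The class name `PatchEmbedding`, the default sizes and the initialisation are assumptions made for illustration; the vit-pytorch package used later in this article ships its own implementation, which may differ in detail.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into patches, flatten, project, add a class token and positions."""

    def __init__(self, image_size=224, patch_size=16, channels=3, dim=128):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Linear(channels * patch_size * patch_size, dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # learned positional embedding: one vector per patch plus one for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)

    def forward(self, images):
        b, c, h, w = images.shape
        p = self.patch_size
        # cut the image into non-overlapping p x p patches: (b, c, h/p, w/p, p, p)
        patches = images.unfold(2, p, p).unfold(3, p, p)
        # flatten every patch (all channels, all pixels) into one vector
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * p * p)
        tokens = self.proj(patches)                      # (b, num_patches, dim)
        cls = self.cls_token.expand(b, -1, -1)           # one class token per image
        tokens = torch.cat([cls, tokens], dim=1)         # (b, num_patches + 1, dim)
        return tokens + self.pos_embedding

tokens = PatchEmbedding()(torch.randn(2, 3, 224, 224))
print(tokens.shape)
```

With a 224×224 image and 16×16 patches there are 14 × 14 = 196 patches, so the sequence holds 197 tokens once the class token is added.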


We can download the dataset from the above link.


We will be implementing the Vision Transformers with PyTorch. 

Install the ViT PyTorch package and Linformer 

pip install vit-pytorch linformer

# loading Libraries

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt   

# import Linformer

from linformer import Linformer   
import glob   
from PIL import Image
from itertools import chain   
from vit_pytorch.efficient import ViT   
from tqdm.notebook import tqdm   

# import torch and related libraries 

import torch   
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms   
from torch.optim.lr_scheduler import StepLR   
from torch.utils.data import DataLoader, Dataset

#to unzip the datasets
import zipfile   

#sklearn to split the data

from sklearn.model_selection import train_test_split  

#defining batch size, epochs, learning rate and gamma for training

batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7 #for learning rate scheduler 

#Load data

os.makedirs('data', exist_ok=True)
train_dir = 'data/train'
test_dir = 'data/test'

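#Unzipping dataset (set the two archive paths before running)
with zipfile.ZipFile('') as train_zip:
    train_zip.extractall('data')  # assumed target, so the images land under data/train
with zipfile.ZipFile('') as test_zip:
    test_zip.extractall('data')  # assumed target, so the images land under data/test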

#Creating train and test list 

train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

#printing length of the dataset

print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

#Defining labels

labels = [path.split('/')[-1].split('.')[0] for path in train_list]
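The label extraction above relies on the class name being the first dot-separated part of the file name. A small sketch of that parsing, assuming files are named like `cat.0.jpg` and `dog.0.jpg` (the example paths and the helper `label_from_path` are hypothetical):

```python
import os
from collections import Counter

def label_from_path(path):
    """Return the class prefix of a file named like 'cat.0.jpg'."""
    # os.path.basename handles the platform's path separator for us
    return os.path.basename(path).split(".")[0]

# Hypothetical paths following the assumed naming scheme
example_paths = ["data/train/cat.0.jpg", "data/train/dog.0.jpg", "data/train/dog.1.jpg"]
example_labels = [label_from_path(p) for p in example_paths]
label_counts = Counter(example_labels)
print(label_counts)
```

Counting the labels this way is also a quick check that the two classes are roughly balanced before splitting.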

# printing few images 

random_idx = np.random.randint(1, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))
for idx, ax in enumerate(axes.ravel()):
    img = Image.open(train_list[idx])  # open the image with PIL
    ax.set_title(labels[idx])
    ax.imshow(img)

#Splitting train and validation list

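train_list, valid_list = train_test_split(train_list,
                                          test_size=0.2,    # assumed split ratio
                                          stratify=labels)  # assumed: keep the class ratio equal in both splits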
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")
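As a rough illustration of what a stratified split does, the sketch below splits a made-up list of ten file paths. The paths, the 20% validation share and the seed are assumptions chosen for the example, not values taken from this article.

```python
from sklearn.model_selection import train_test_split

# Hypothetical file list: five cats and five dogs
demo_paths = [f"data/train/{name}.{i}.jpg" for name in ("cat", "dog") for i in range(5)]
demo_labels = [p.split("/")[-1].split(".")[0] for p in demo_paths]

demo_train, demo_valid = train_test_split(
    demo_paths,
    test_size=0.2,         # assumed ratio: 20% held out for validation
    stratify=demo_labels,  # keep the cat/dog ratio the same in both splits
    random_state=42,       # fixed seed so the split is reproducible
)

valid_labels = sorted(p.split("/")[-1].split(".")[0] for p in demo_valid)
print(len(demo_train), len(demo_valid), valid_labels)
```

Because the split is stratified, the two validation images contain one cat and one dog regardless of the seed.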

# Torch transforms

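train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # assumed final step: converts the PIL image to a tensor for the model
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # assumed final step, as above
])
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # assumed final step, as above
])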

#Loading dataset for training 

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength
    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0
        return img_transformed, label

#defining train, validation and test dataset

train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

#loading dataloader

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

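#Linformer (efficient attention)

efficient_transformer = Linformer(
    dim=128, depth=12, heads=8, k=64,  # assumed values; pick your own
    seq_len=49+1,  # 7x7 patches + 1 cls-token
)

#Visual transformer

model = ViT(
    dim=128,         # assumed; must equal the Linformer dim
    image_size=224,  # matches the Resize((224, 224)) transforms
    patch_size=32,   # 224 / 32 = 7 patches per side, 7x7 = 49 in total
    num_classes=2,   # cat and dog
    transformer=efficient_transformer,
)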
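The `seq_len=49+1` value follows from simple arithmetic on the image and patch sizes. The helper below is a hypothetical illustration; the 32-pixel patch size is inferred here from the 224×224 resize and the "7x7 patches" comment rather than stated in the text.

```python
def vit_sequence_length(image_size, patch_size, cls_token=True):
    """Number of tokens a ViT sees for a square image cut into square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + (1 if cls_token else 0)

print(vit_sequence_length(224, 32))  # 7 * 7 + 1 = 50
print(vit_sequence_length(224, 16))  # 14 * 14 + 1 = 197
```

The Linformer `seq_len` has to agree with this number, so changing the image size or patch size means updating it as well.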

# loss function

criterion = nn.CrossEntropyLoss()

# optimizer

optimizer = optim.Adam(model.parameters(), lr=lr)

# scheduler

scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
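To see what `StepLR` with `step_size=1` does to the learning rate, the sketch below steps a throwaway optimizer for a few epochs. The dummy parameter and the `demo_` names are assumptions introduced so the example does not touch the real model.

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Throwaway parameter and optimizer, used only to watch the schedule
dummy_param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.Adam([dummy_param], lr=3e-5)
demo_scheduler = StepLR(demo_optimizer, step_size=1, gamma=0.7)

lrs = []
for epoch in range(5):
    lrs.append(demo_optimizer.param_groups[0]["lr"])
    demo_optimizer.step()   # the optimizer is stepped before the scheduler
    demo_scheduler.step()   # multiplies the learning rate by gamma

for epoch, lr_value in enumerate(lrs):
    print(f"epoch {epoch}: lr = {lr_value:.3e}")
```

With these settings the learning rate at epoch n is 3e-5 × 0.7^n, so it shrinks quickly over 20 epochs.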

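#start training

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # assumed device selection
model.to(device)

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
        # backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)
    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)
            val_output = model(data)
            val_loss = criterion(val_output, label)
            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)
    scheduler.step()  # assumed placement: decay the learning rate by gamma once per epoch
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )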
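The accuracy expression used in the loop can be checked by hand on a tiny batch. The logits and labels below are made-up numbers for illustration only.

```python
import torch

# Hypothetical model outputs for 4 images and 2 classes (0 = cat, 1 = dog)
demo_logits = torch.tensor([[2.0, 0.5],
                            [0.1, 1.2],
                            [0.3, 0.9],
                            [1.5, 1.0]])
demo_labels = torch.tensor([0, 1, 0, 0])

demo_predictions = demo_logits.argmax(dim=1)  # index of the largest logit per row
demo_accuracy = (demo_predictions == demo_labels).float().mean()
print(demo_predictions.tolist(), demo_accuracy.item())
```

Three of the four predictions match their labels, so the batch accuracy is 0.75. Averaging these per-batch values, as the loop does, weights every batch equally, which is only a slight approximation when the last batch is smaller than the rest.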

Please check the full code with the explanation here


In this article, we discussed Visual Transformers, which may become mainstream for computer vision tasks such as object detection and image classification. You can also check 15 other paper implementations here with TensorFlow, PyTorch and Jax/Flax.


Krishna Rastogi
Krishna is currently working as an Associate Director at ADaSci. He has 6+ years of experience in research and development and cutting-edge engineering, taking products from idea to deployment. He has expertise in building deep learning computer vision applications using both hardware and software solutions in several domains. His interests are distributed learning and Edge AI.
