With the rise in popularity of image and video analytics, industry and academia have invested heavily in better Convolutional Neural Networks (CNNs), which now outperform many classical computer vision algorithms. We may be heading towards an era of quantum computing that eases the demand for large amounts of compute, but today most vision models are still trained on high-end GPUs/TPUs with massive datasets.
The Transformer architecture (Vaswani et al.) was introduced in 2017 and uses self-attention to accelerate training. It was primarily created to solve core challenges in Natural Language Processing (NLP). Unlike Recurrent Neural Networks (RNNs), which were widely used for NLP but are difficult to parallelize, Transformers outperform RNN architectures with less training time.
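To make the idea of self-attention concrete, below is a minimal sketch of scaled dot-product self-attention in PyTorch. This is a simplified, single-head illustration written for this article; the function name, dimensions and random weights are assumptions, not code from the paper.

# minimal single-head scaled dot-product self-attention (illustrative sketch)
import torch
import torch.nn.functional as F

def self_attention(x, w_q, w_k, w_v):
    # x: (batch, seq_len, dim) sequence of token embeddings
    q = x @ w_q  # queries
    k = x @ w_k  # keys
    v = x @ w_v  # values
    # every token attends to every other token in parallel
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    weights = F.softmax(scores, dim=-1)
    return weights @ v

dim = 128
x = torch.randn(2, 10, dim)                    # 2 sequences of 10 tokens
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))
print(self_attention(x, w_q, w_k, w_v).shape)  # torch.Size([2, 10, 128])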
Introduction
Visual Transformers (VTs) are an active area of research and are pushing past CNN models on several vision tasks. CNN architectures give equal weightage to all pixels and therefore struggle to focus on the most essential features of an image. In the paper by Bichen Wu, Chenfeng Xu, Xiaoliang Dai, Alvin Wan, Peizhao Zhang, Zhicheng Yan, Masayoshi Tomizuka, Joseph Gonzalez, Kurt Keutzer, Peter Vajda: “Visual Transformers: Token-based Image Representation and Processing for Computer Vision”, 2020; [http://arxiv.org/abs/2006.03677 arXiv:2006.03677], the authors show that, in place of convolutions, VTs apply transformers to relate semantic concepts directly in token space.
You can check the PyTorch implementation here.
In this article, we will learn about and implement one of the recent papers on Vision Transformers (ViT) by the Google Research team: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby: “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, 2020; [http://arxiv.org/abs/2010.11929 arXiv:2010.11929].
Source: https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html
How it works
ViT splits an input image into a sequence of fixed-size patches (for example, 16×16 pixels), analogous to the sequence of word embeddings fed to an NLP Transformer. Each patch is flattened into a single vector by concatenating the channels of all its pixels and is then linearly projected to the desired input dimension. Because transformers rely on self-attention and do not assume any particular structure in their input elements, the architecture can learn and relate sparsely distributed information efficiently. In ViT, the spatial relationship between the patches of an image is not given a priori; the model learns it from the training data and encodes it in a positional embedding.
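To illustrate the patching step, here is a minimal sketch (our own illustration, not the paper's code) that splits an image into non-overlapping patches, flattens them, projects them to the model dimension and adds a learnable positional embedding. The sizes below (224×224 input, 16×16 patches, dim=128) are illustrative assumptions.

# splitting an image into flattened patches and embedding them (illustrative sketch)
import torch
import torch.nn as nn

image_size, patch_size, dim = 224, 16, 128
num_patches = (image_size // patch_size) ** 2   # 14 x 14 = 196 patches
patch_dim = 3 * patch_size * patch_size         # 3 channels x 16 x 16 = 768 values per patch

img = torch.randn(1, 3, image_size, image_size)  # dummy RGB image

# split into non-overlapping 16x16 patches and flatten each one
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, num_patches, patch_dim)

# project each flattened patch to the model dimension and add positional embeddings
to_embedding = nn.Linear(patch_dim, dim)
pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
tokens = to_embedding(patches) + pos_embedding
print(tokens.shape)  # torch.Size([1, 196, 128])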
Dataset
https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
We can download the dataset from the above link.
Code
We will be implementing the Vision Transformer with PyTorch.
Install the ViT PyTorch package and Linformer
pip install vit-pytorch linformer
# loading Libraries
from __future__ import print_function  # must come before any other import
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import Linformer
from linformer import Linformer
import glob
from PIL import Image
from itertools import chain
from vit_pytorch.efficient import ViT
from tqdm.notebook import tqdm
# import torch and related libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

# to unzip the datasets
import zipfile
#sklearn to split the data
from sklearn.model_selection import train_test_split
# defining batch size, epochs, learning rate and gamma for training
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7  # for the learning rate scheduler
seed = 42    # random seed, reused when splitting the data
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # used to move the model and batches to the GPU if available
#Load data
os.makedirs('data', exist_ok=True)
train_dir = 'data/train'
test_dir = 'data/test'

# unzipping the dataset
with zipfile.ZipFile('train.zip') as train_zip:
    train_zip.extractall('data')
with zipfile.ZipFile('test.zip') as test_zip:
    test_zip.extractall('data')

# creating train and test lists
train_list = glob.glob(os.path.join(train_dir, '*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

# printing the length of the datasets
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")
#Defining labels
labels = [path.split('/')[-1].split('.')[0] for path in train_list]
# printing a few random images with their labels
random_idx = np.random.randint(1, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))
for idx, ax in enumerate(axes.ravel()):
    img = Image.open(train_list[random_idx[idx]])
    ax.set_title(labels[random_idx[idx]])
    ax.imshow(img)
#Splitting train and validation list
train_list, valid_list = train_test_split(train_list,
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=seed)
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")
# Torch transforms
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
#Loading dataset for training
class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        # the file name encodes the class, e.g. data/train/dog.1234.jpg
        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0

        return img_transformed, label
#defining train, validation and test dataset
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=test_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)
#loading dataloader
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
# Linformer (linear transformer) used as the efficient attention backbone
efficient_transformer = Linformer(
    dim=128,
    seq_len=49 + 1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)
#Visual transformer
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
).to(device)
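Before training, a quick sanity check with a dummy batch can confirm the shapes line up (this check is our addition, not part of the original tutorial). With image_size=224 and patch_size=32, the image yields 7×7 = 49 patches, matching the seq_len=49+1 configured for the Linformer above.

# quick shape check with a random batch (assumes the model defined above)
dummy = torch.randn(8, 3, 224, 224).to(device)  # batch of 8 RGB images
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([8, 2]) -> one logit per class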
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
#start training
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    scheduler.step()  # step the learning rate scheduler once per epoch

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
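After training, the fitted model can be used to predict on the unlabelled test images. The short sketch below is our addition to round off the tutorial; note that CatsDogsDataset returns a placeholder label for the test files, so we only collect the predicted class indices.

# predicting on the test set (labels are unknown, so we only collect predictions)
model.eval()
predictions = []
with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        output = model(data)
        predictions.extend(output.argmax(dim=1).cpu().tolist())
print(predictions[:10])  # 1 = dog, 0 = cat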
Please check the full code with the explanation here.
Conclusion
In this article, we discussed Vision Transformers, which are on their way to becoming mainstream for computer vision tasks such as image classification and object detection. You can also check 15 other paper implementations here, with TensorFlow, PyTorch and Jax/Flax.