
Complete Guide to T2T-ViT: Training Vision Transformers Efficiently with Minimal Data

T2T-ViT employs progressive tokenization that takes the patches of an image and converts them into overlapping tokens over a few iterations


Apart from language modeling tasks, transformers have recently shown good success in computer vision tasks too. The computer vision tasks in which transformers outperform CNNs include image classification, object detection, denoising, medical image segmentation, super-resolution and deraining. The Vision Transformer, famously known as ViT, has found a prominent place in image classification tasks. Other vision models, such as the ResNet50-backed ViT and TransUNet, have been developed on top of the vanilla ViT for task-specific applications.

With large training datasets and compute power, the vanilla ViT outperforms CNN models such as ResNets. However, the vanilla ViT fails to match conventional CNN models when the training data is limited. Hence, a pre-trained ViT, fine-tuned for the task at hand, is preferred for deployment in real-world applications. This drawback limits ViT's performance and usage in applications whose domains differ considerably from the domain on which the employed pre-trained ViT was originally trained.

The two major issues with the ViT model are its tokenization methodology and its attention-backbone design. ViT breaks an input image into non-overlapping patches and stacks them successively to form tokens. This approach prevents the model from learning local structures such as the edges and lines in the input images. On the other hand, the attention-backbone design has redundancies that prevent the model from attaining feature richness. These drawbacks are only partly overcome by training the ViT model on a huge image dataset.
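For intuition, ViT's hard-split tokenization can be reproduced in a few lines of PyTorch. This is a minimal sketch, not the official ViT implementation: non-overlapping patches are cut out and flattened, so neighbouring patches never share pixels.

 import torch
 import torch.nn.functional as F

 # Minimal sketch of ViT-style "hard split" tokenization (illustrative only).
 # A 224x224 RGB image is cut into non-overlapping 16x16 patches,
 # and each patch is flattened into a single token.
 img = torch.randn(1, 3, 224, 224)                    # dummy input image
 patches = F.unfold(img, kernel_size=16, stride=16)   # (1, 3*16*16, 196)
 tokens = patches.transpose(1, 2)                     # (1, 196, 768)
 print(tokens.shape)  # 196 tokens of dimension 768; patches never overlap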

To this end, Li Yuan, Tao Wang, Weihao Yu, Yujun Shi, Zihang Jiang, Francis E.H. Tay and Jiashi Feng of the National University of Singapore, and Yunpeng Chen and Shuicheng Yan of YITU Technology, have introduced T2T-ViT, the Tokens-to-Token Vision Transformer, which is designed to be free of ViT's fundamental issues. This model employs a new form of progressive tokenization that takes the patches of an original image as tokens and converts them into overlapping tokens over a few iterations. This approach enables the model to grasp local feature details such as lines, textures and edges. Before the feature map is fed into each transformer layer, it is tokenized via a tokens-to-token module.

The feature map of ViT at each block has zero values in a few channels. This is due to the improper architecture design of the attention-backbone in ViT. T2T-ViT implements deep-narrow layers in its transformer blocks that capture features considerably better than the vanilla ViT. This allows the model to get by with relatively few parameters. This robust design permits T2T-ViT to be trained from scratch even on mid-sized data such as ImageNet.
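To get a feel for the deep-narrow trade-off, the sketch below builds two generic transformer encoders with PyTorch's built-in layers: one roughly ViT-B-sized (wide and shallow) and one roughly T2T-ViT-14-sized (narrow and deep). The dimensions are indicative only; this is not the official T2T-ViT backbone.

 import torch.nn as nn

 # Illustrative "shallow-wide" vs "deep-narrow" encoders (not the official T2T-ViT code).
 def encoder(depth, width, heads, mlp):
     layer = nn.TransformerEncoderLayer(d_model=width, nhead=heads,
                                        dim_feedforward=mlp, batch_first=True)
     return nn.TransformerEncoder(layer, num_layers=depth)

 shallow_wide = encoder(depth=12, width=768, heads=12, mlp=3072)  # roughly ViT-B-sized
 deep_narrow  = encoder(depth=14, width=384, heads=6,  mlp=1152)  # roughly T2T-ViT-14-sized

 count = lambda m: sum(p.numel() for p in m.parameters())
 print(f"shallow-wide parameters: {count(shallow_wide) / 1e6:.1f}M")
 print(f"deep-narrow  parameters: {count(deep_narrow) / 1e6:.1f}M")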

Comparison of the T2T-ViT model with the vanilla ViT and ResNet in terms of model complexity (parameters), multiply-add operations (MACs) and top-1 accuracy (Source)

T2T-ViT is compute-efficient, requiring only a fraction of the MACs of the vanilla ViT. It also employs around a quarter of the parameters that the vanilla ViT employs. Nevertheless, T2T-ViT achieves an extraordinary top-1 accuracy of 82.5%, while the top version of the vanilla ViT reaches a maximum of 78.1% at an equivalent compute budget of about 10 Giga MACs!
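Numbers like those in the comparison above can be measured for any PyTorch model. The snippet below is a minimal sketch using the third-party thop package (an extra dependency, installed with pip install thop); torchvision's ResNet50 is used here only as a stand-in reference model.

 import torch
 from torchvision.models import resnet50
 from thop import profile  # third-party package: pip install thop

 # Count parameters and MACs for a reference model (ResNet50 as a stand-in).
 model = resnet50()
 dummy = torch.randn(1, 3, 224, 224)
 macs, params = profile(model, inputs=(dummy,))
 print(f"Params: {params / 1e6:.1f}M, MACs: {macs / 1e9:.1f}G")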

The function of a Tokens-to-Token (T2T) module (Source)

A T2T module performs its process in two steps: re-structurization and soft split. During re-structurization, the transformed tokens from the previous transformer block are treated as a flattened sequence of patches and reshaped into a square block of patches, i.e. a 2D feature map. During the soft split, the reshaped feature map is split into overlapping patches, so that each new token carries information about its neighbourhood. Thus the local context information is passed to the next transformer block without loss.
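One tokens-to-token iteration can be sketched with PyTorch's nn.Unfold (an illustrative simplification, not the official T2T module): re-structurization reshapes the token sequence back into a 2D map, and the soft split extracts overlapping windows whose stride is smaller than the kernel size.

 import torch
 import torch.nn as nn

 # Illustrative sketch of one Tokens-to-Token step (not the official implementation).
 def tokens_to_token(tokens, h, w, kernel=3, stride=2, padding=1):
     b, n, c = tokens.shape
     assert n == h * w, "token count must match the spatial grid"
     # re-structurization: token sequence (B, N, C) -> 2D feature map (B, C, H, W)
     feature_map = tokens.transpose(1, 2).reshape(b, c, h, w)
     # soft split: overlapping k x k windows (stride < kernel) become the new tokens,
     # so every new token mixes information from its spatial neighbourhood
     unfold = nn.Unfold(kernel_size=kernel, stride=stride, padding=padding)
     new_tokens = unfold(feature_map)          # (B, C*k*k, N_new)
     return new_tokens.transpose(1, 2)         # (B, N_new, C*k*k)

 # toy usage: a 56x56 grid of 64-dim tokens shrinks to a 28x28 grid of 576-dim tokens
 x = torch.randn(2, 56 * 56, 64)
 print(tokens_to_token(x, 56, 56).shape)       # torch.Size([2, 784, 576])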

The overall architecture of the T2T-ViT for image classification (Source)

Python Implementation

The T2T-ViT architecture requires a PyTorch environment with a single-node, multi-GPU (4 or 8 GPUs) runtime to train, evaluate and infer. It should be noted that training may take several hours depending on the device configuration. The following commands install the dependencies.

 !pip install timm==0.3.4
 !pip install torch
 !pip install torchvision
 !pip install pyyaml 
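Before launching training, a quick sanity check confirms that the runtime actually exposes the GPUs the scripts expect (a minimal check, independent of the T2T-ViT code).

 import torch

 # Verify the PyTorch build and the number of visible GPUs
 print("PyTorch:", torch.__version__)
 print("CUDA available:", torch.cuda.is_available())
 print("GPU count:", torch.cuda.device_count())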

The ImageNet dataset must be downloaded to the working directory (more than 140GB). Alternatively, users may prefer the pre-downloaded copies available on the Kaggle platform, which also supports running models in its own cloud environment.

Once downloaded, extract the ImageNet training dataset (138GB) using the following commands.

 %%bash
 # Extract the training data:
 #
 mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
 tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
 find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
 cd .. 

Extract the ImageNet validation dataset (6.3GB) and arrange the images into class subfolders using the following commands.

 %%bash
 # Extract the validation data and move images to subfolders:
 #
 mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
 wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash 

Download the source code from the official repository using the following command.

!git clone https://github.com/yitu-opensource/T2T-ViT.git


Change the working directory to the downloaded repository and list its contents.

 %cd /content/T2T-ViT/
 !ls 


The following code creates the model and loads pre-trained weights for transfer learning.

 from models.t2t_vit import *
 from utils import load_for_transfer_learning 
 # create the model
 model = T2t_vit_14()
 # load the pre-trained weights
 load_for_transfer_learning(model, '/path/to/pretrained/weights', use_ema=True, strict=False, num_classes=1000) 
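Once the weights are loaded, the model can be tried on a single image. The snippet below is a minimal inference sketch using standard ImageNet preprocessing; sample.jpg is a placeholder path, not a file shipped with the repository.

 import torch
 from PIL import Image
 from torchvision import transforms

 # Standard ImageNet preprocessing (assumed defaults, 224x224 input)
 preprocess = transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]),
 ])

 img = preprocess(Image.open("sample.jpg").convert("RGB")).unsqueeze(0)  # placeholder image
 model.eval()
 with torch.no_grad():
     logits = model(img)                       # model loaded in the previous snippet
 print("Predicted ImageNet class index:", logits.argmax(dim=1).item())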

The pre-trained T2T-ViT checkpoints can also be downloaded from the source page and saved in the working directory. The following command evaluates a downloaded checkpoint on the available validation data.

!CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model T2t_vit_14 -b 100 --eval_checkpoint path/to/checkpoint

Train the built model with the training data using the following command. It should be noted that training may take several hours even on a single node with 8 GPUs.

!CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model T2t_vit_14 -b 64 --lr 5e-4 --weight-decay .05 --amp --img-size 224

Once trained, the model can be tested on unseen data, or the fully trained model can be adapted via transfer learning to downstream image classification tasks such as CIFAR-10. For instance, the following command performs transfer learning for CIFAR-10 image classification.

!CUDA_VISIBLE_DEVICES=0,1 python transfer_learning.py --lr 0.05 --b 64 --num-classes 10 --img-size 224 --transfer-learning True --transfer-model /path/to/pretrained/T2T-ViT-19

Performance of T2T-ViT

T2T-ViT is trained on ImageNet from scratch on a multi-GPU node. The competing models are trained from scratch on ImageNet under identical compute configurations.

Feature maps extracted by different models at different layers/blocks. ViT develops a few invalid (zero-valued) channels, while T2T-ViT extracts fine details of the original image better than ResNet or ViT (Source).

T2T-ViT greatly outperforms well-acclaimed models such as AlexNet, VGG11, Inception v3, ResNet50, DenseNet201, SeNet50, ResNeXt50, ResNet50-Ghost, the vanilla ViT and its variants on top-1 accuracy. T2T-ViT also consumes less compute power in terms of MACs.

Informative attention mapping in T2T-ViT at earlier layers (Source)

Future transformer variants and extensions in the computer vision domain may arrive at more efficient and improved models by incorporating the key modules and approaches of the T2T-ViT architecture.

Further reading

  1. Original research paper
  2. Source code repository
  3. ImageNet dataset