Hands-on guide to using the Vision Transformer for image classification

A Vision Transformer (ViT) is a transformer applied to computer vision tasks, built on the same principles as the transformers used in natural language processing. Internally, a transformer learns by measuring the relationships between pairs of input tokens; in computer vision, patches of an image serve as the tokens.

Vision transformers are among the most popular transformer models in deep learning. Before their introduction, complex computer vision tasks relied on convolutional neural networks. With vision transformers, we gained another powerful model for computer vision, much as BERT and GPT serve complex NLP tasks. In this article, we will learn how to use a vision transformer for an image classification task through a hands-on implementation. The following are the major points covered in this article.

Table of contents

  1. About vision transformers
  2. Implementing vision transformer for image classification

Step 1: Initializing setup



Step 2: Building network

Step 3: Building vision transformer

Step 4: Compiling and training

Let’s start with understanding the vision transformer first.

About vision transformers

As noted in the introduction, a vision transformer treats patches of an image as input tokens and learns by measuring the relationships between token pairs. These relationships are learned through attention mechanisms in the network, applied either in conjunction with a convolutional network or as a replacement for some of its components. The resulting architecture can be applied to image classification tasks. The full procedure of image classification using a vision transformer can be explained by the following image.
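To get a feel for the scale of these pairwise relationships, we can count them for the settings used later in this article (a 72×72 image split into 6×6 patches):

```python
# Each image becomes a sequence of patch tokens; self-attention scores
# every ordered pair of tokens, so one attention map grows quadratically.
image_size = 72  # images are resized to 72x72 before patching
patch_size = 6   # each patch covers a 6x6 pixel square

num_patches = (image_size // patch_size) ** 2  # patch tokens per image
num_pairs = num_patches * num_patches          # entries in one attention map

print(num_patches)  # 144
print(num_pairs)    # 20736
```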

[Figure: the ViT classification pipeline — the image is split into fixed-size patches, the patches are linearly embedded together with position embeddings, fed through a transformer encoder, and classified by an MLP head.]

The image above shows the procedure we need to follow. In this article, we will perform all of these steps using the Keras library.

For this implementation, we will take the following steps. 

Step 1: Initializing setup

In this section, we will perform some basic modelling steps such as importing the dataset, defining hyperparameters, and setting up data augmentation.

Step 1.1: Importing data      

Let’s start by obtaining data. In this procedure, we are going to use the CIFAR-10 dataset provided by the Keras library. In the dataset, we have 50,000 images of size 32×32 in the training dataset and 10,000 images of the same size in the test dataset. We have the following labels with these images:

Index  Label
0      airplane
1      automobile
2      bird
3      cat
4      deer
5      dog
6      frog
7      horse
8      ship
9      truck

We can load this dataset using the following lines of code.

from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()


Checking the shape of the datasets.

print(f"x_train: {x_train.shape} - y_train: {y_train.shape}")
print(f"x_test: {x_test.shape} - y_test: {y_test.shape}")
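The labels in y_train and y_test are integers from 0 to 9. As a small illustrative helper (the class names below follow the standard CIFAR-10 ordering), they can be mapped to readable names with a simple lookup:

```python
# Standard CIFAR-10 class names, indexed by the integer label value.
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

def label_name(label):
    """Return the human-readable class name for an integer label."""
    return cifar10_classes[int(label)]

print(label_name(3))  # cat
print(label_name(9))  # truck
```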


Step 1.2: Defining hyperparameters

In this section, we will define the parameters that will be used throughout the rest of the implementation.

num_classes = 10
input_shape = (32, 32, 3)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # images will be resized to this size
patch_size = 6  # size of the patches extracted from the images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # size of the transformer's MLP layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # size of the dense layers of the final classifier

Looking at these parameters, we can see that training will run for 100 epochs, and that each image will be resized to 72×72 and split into 6×6 patches.
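The patch arithmetic implied by these hyperparameters can be checked directly; with 72×72 images and 6×6 patches, each image yields a 12×12 grid of patches, each flattened to 108 values:

```python
# Verify the patch counts implied by the hyperparameters above.
image_size = 72
patch_size = 6
channels = 3  # RGB

patches_per_side = image_size // patch_size     # 12
num_patches = patches_per_side ** 2             # patches per image
patch_dim = patch_size * patch_size * channels  # flattened patch length

print(num_patches)  # 144
print(patch_dim)    # 108
```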

Now, we will call the important libraries.

Step 1.3: Data augmentation 

In this procedure, we will feed augmented images to the transformer. During augmentation, we will normalize and resize the images, then randomly flip, rotate, and zoom them. This is done with a Keras Sequential model built from the preprocessing layers that Keras provides.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

As the final step of the augmentation setup, we compute the mean and the variance of the training data for normalization by calling adapt on the Normalization layer.
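What this adapt-style normalization amounts to can be sketched in NumPy: standardize with the mean and variance computed from the training data (an illustrative sketch, not the internals of the Keras Normalization layer):

```python
import numpy as np

# Standardize data with statistics "adapted" from a training set, mirroring
# the effect of keras.layers.Normalization after calling .adapt().
rng = np.random.default_rng(0)
train = rng.normal(loc=5.0, scale=2.0, size=(1000, 3))  # stand-in data

mean = train.mean(axis=0)
var = train.var(axis=0)

normalized = (train - mean) / np.sqrt(var + 1e-7)  # small epsilon for safety

print(bool(np.allclose(normalized.mean(axis=0), 0.0, atol=1e-6)))  # True
print(bool(np.allclose(normalized.std(axis=0), 1.0, atol=1e-3)))   # True
```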

Step 1.4: Visualizing images

Let’s see what the images in the dataset look like.

import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")
plt.show()


The output is an example image from the dataset; because the images are only 32×32 pixels, it appears blurry. Now we can proceed to the second step.

Step 2: Building network

In this step, we will build the network components: an MLP block and a layer that splits our images into patches. We will also use a patch encoder that transforms the patches by projecting each one into a vector of size 64. Let’s start by building the MLP block.

Step 2.1: Building MLP network 

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In the code above, we built an MLP block that simply stacks a dense layer followed by a dropout layer for each entry in hidden_units.

Step 2.2: Patch maker

In this step, we will define a layer that converts images into patches. For this, we mainly use TensorFlow’s tf.image.extract_patches function.

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

# Visualize the patches of one resized sample image.
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")
plt.show()


In the above output, we can see that the images have been converted into patches, which the vision transformer will use to learn to classify them.
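For intuition, the same non-overlapping patch extraction can be mimicked with a plain NumPy reshape and transpose, which makes the resulting shapes easy to verify (a sketch, not TensorFlow’s implementation):

```python
import numpy as np

def extract_patches_np(images, patch_size):
    """Split (batch, H, W, C) images into flattened non-overlapping patches."""
    b, h, w, c = images.shape
    ph, pw = h // patch_size, w // patch_size
    x = images.reshape(b, ph, patch_size, pw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # bring the patch grid axes together
    return x.reshape(b, ph * pw, patch_size * patch_size * c)

images = np.arange(2 * 72 * 72 * 3, dtype=np.float32).reshape(2, 72, 72, 3)
patches = extract_patches_np(images, 6)
print(patches.shape)  # (2, 144, 108)
```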

Step 2.3: Patch encoder

This patch encoder will perform the linear transformation of the image patches and add a learnable position embedding to the projected vector.
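In essence, the encoder computes a linear projection of each flattened patch and adds a per-position embedding. A NumPy sketch of that computation (random values stand in for the learned weights):

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, projection_dim = 144, 108, 64

W = rng.normal(size=(patch_dim, projection_dim))     # Dense kernel
b = np.zeros(projection_dim)                         # Dense bias
pos_embedding = rng.normal(size=(num_patches, projection_dim))

patches = rng.normal(size=(num_patches, patch_dim))  # one image's patches
encoded = patches @ W + b + pos_embedding            # projection + position

print(encoded.shape)  # (144, 64)
```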

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

After building this network we are ready to build a vision transformer model.

Step 3: Building vision transformer   

In this section, we will build the blocks of the vision transformer. As discussed and implemented above, the augmented data will pass through the patch maker block and then through the patch encoder block. In each transformer block, we will apply a self-attention layer to the patch sequences. The output of the transformer blocks then goes through a classification head, which produces the final outputs. Let’s look at the code below.

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Augment the data, create patches, and encode them.
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Stack of transformer blocks with pre-layer normalization.
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])  # skip connection
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])  # second skip connection

    # Classification head.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
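At the core of each transformer block above is self-attention over the patch sequence. A single-head NumPy sketch of scaled dot-product attention shows the essential computation (illustrative only; Keras’s MultiHeadAttention adds multiple heads, learned output projections, and dropout):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # stable softmax
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over a token sequence."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = softmax(q @ k.T / np.sqrt(k.shape[-1]))  # (tokens, tokens)
    return scores @ v

rng = np.random.default_rng(0)
tokens, dim = 144, 64  # 144 patch tokens, 64-dim projections
x = rng.normal(size=(tokens, dim))
Wq, Wk, Wv = (rng.normal(size=(dim, dim)) * 0.1 for _ in range(3))

out = self_attention(x, Wq, Wk, Wv)
print(out.shape)  # (144, 64)
```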

Using the above function, we can define a vision-transformer classifier that combines the data augmentation, patch-making, and patch-encoding components defined earlier. The encoded patches serve as the image representation fed to the transformer blocks, and the Flatten layer reshapes the final representation for the classification head.

Step 4: Compiling and training

In this section, we will compile and train the model we created, and then evaluate it in terms of accuracy.

Step 4.1: Compiling the model

Using the lines of code below, we can compile the model.

optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate, weight_decay=weight_decay
)

model = create_vit_classifier()
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

In the compilation, we used the AdamW optimizer (Adam with decoupled weight decay) with sparse categorical cross-entropy loss.
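The key difference from plain Adam is that AdamW applies weight decay directly to the weights, decoupled from the gradient-based update. A single-step NumPy sketch of one common formulation (illustrative, not the exact tfa implementation):

```python
import numpy as np

def adamw_step(w, g, m, v, t, lr=0.001, wd=0.0001,
               b1=0.9, b2=0.999, eps=1e-7):
    """One AdamW update: Adam moment estimates plus decoupled weight decay."""
    m = b1 * m + (1 - b1) * g        # first-moment estimate
    v = b2 * v + (1 - b2) * g ** 2   # second-moment estimate
    m_hat = m / (1 - b1 ** t)        # bias correction
    v_hat = v / (1 - b2 ** t)
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * w)
    return w, m, v

w = np.ones(4)
g = np.full(4, 0.5)  # a constant gradient for illustration
w, m, v = adamw_step(w, g, np.zeros(4), np.zeros(4), t=1)
print(w.shape)  # (4,)
```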

Step 4.2: Training

Training of the transformer can be done using the following lines of codes:

history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
)


Running this starts the training, which may take a significant amount of time, so it is recommended to enable a GPU. In Google Colab, the GPU setting can be found under Runtime > Change runtime type.

Step 4.3: Checking accuracy

Let’s check the accuracy of the vision transformer in the image classification task.

_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")


In the above output, we can see that the model reaches 84.21% accuracy and 99.24% top-5 accuracy on the test set.
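Top-5 accuracy counts a prediction as correct when the true label appears among the five highest-scoring classes. A minimal NumPy sketch of the metric (illustrative; Keras computes it internally via SparseTopKCategoricalAccuracy):

```python
import numpy as np

def top_k_accuracy(logits, labels, k=5):
    """Fraction of samples whose true label is among the k highest logits."""
    topk = np.argsort(logits, axis=1)[:, -k:]  # indices of the k top classes
    hits = [label in row for label, row in zip(labels, topk)]
    return float(np.mean(hits))

logits = np.array([
    [0.1, 0.9, 0.2, 0.3, 0.4, 0.5, 0.0, 0.05, 0.06, 0.07],  # label 1: hit
    [0.9, 0.1, 0.2, 0.3, 0.4, 0.5, 0.05, 0.6, 0.7, 0.8],    # label 6: miss
])
labels = [1, 6]

print(top_k_accuracy(logits, labels, k=5))  # 0.5
```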

Final words

In this article, we introduced the vision transformer and saw how it can be used for image classification. We implemented a vision transformer on the CIFAR-10 dataset, following all the steps shown in the architecture image, and achieved a good result on the task.



Yugesh Verma
Yugesh is a graduate in automobile engineering and worked as a data analyst intern. He completed several Data Science projects. He has a strong interest in Deep Learning and writing blogs on data science and machine learning.
