Hands-on guide to using the Vision Transformer for image classification

Vision Transformer (ViT) is a transformer adapted to computer vision that works on the same principles as the transformers used in natural language processing. Internally, a transformer learns by measuring the relationships between pairs of input tokens; in computer vision, we can use patches of an image as the tokens.

Vision transformers are among the popular transformer models in deep learning. Before their introduction, we had to rely on convolutional neural networks for complex computer vision tasks. With vision transformers, we gained one more powerful model for computer vision, much as we have BERT and GPT for complex NLP tasks. In this article, we will learn how we can use a vision transformer for an image classification task through a hands-on implementation. Following are the major points to be covered in this article.

Table of contents

  1. About vision transformers
  2. Implementing vision transformer for image classification

Step 1: Initializing setup

Step 2: Building network



Step 3: Building vision transformer

Step 4: Compiling and training


Let’s start with understanding the vision transformer first.

About vision transformers

The relationship between patch tokens is learned by applying attention in the network. This can be done either in conjunction with a convolutional network or by replacing some components of convolutional networks. Such networks can be applied to image classification tasks. The full procedure of image classification using a vision transformer can be explained by the following image.


In the above image, we can see the procedure we are required to follow. In this article, we are going to discuss how we can perform all these steps using the Keras library. 

For this implementation, we will take the following steps. 

Step 1: Initializing setup

In this section, we will perform some basic modelling procedures such as importing the dataset, defining hyperparameters, and data augmentation.

Step 1.1: Importing data      

Let’s start by obtaining the data. We are going to use the CIFAR-10 dataset provided by the Keras library. The training set contains 50,000 images of size 32×32 and the test set contains 10,000 images of the same size. The images carry the following ten labels:

Index	Label
0	airplane
1	automobile
2	bird
3	cat
4	deer
5	dog
6	frog
7	horse
8	ship
9	truck

We can load this dataset using the following lines of code.

from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()


Checking the shape of the datasets.

print(f"x_train: {x_train.shape} - y_train: {y_train.shape}")
print(f"x_test: {x_test.shape} - y_test: {y_test.shape}")


Step 1.2: Defining hyperparameters

In this section, we will define some of the parameters that we will use with the other sub-processes.  

num_classes = 10
input_shape = (32, 32, 3)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # we will resize input images to this size
patch_size = 6  # size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]  # size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # size of the dense layers of the final classifier

Considering the above parameters: we will train for 100 epochs, resize each 32×32 image to 72×72, and split each resized image into 6×6 patches, giving (72/6)² = 144 patches per image.
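As a quick sanity check, the patch-related numbers can be derived directly. This small sketch mirrors the hyperparameter names above:

```python
# Derived quantities implied by the hyperparameters above.
image_size = 72   # images are resized to 72x72
patch_size = 6    # each patch is 6x6 pixels
channels = 3      # RGB

num_patches = (image_size // patch_size) ** 2   # patches per image
patch_dim = patch_size * patch_size * channels  # length of a flattened patch

print(num_patches, patch_dim)  # 144 108
```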

Next, we will import the libraries required for the rest of the implementation.

Step 1.3: Data augmentation 

In this procedure, we will feed augmented images to the transformer. The augmentation normalizes and resizes the images and then randomly flips and zooms them. This is done with a Keras Sequential model built from the preprocessing layers provided by Keras.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)

In the final step of the augmentation, we compute the mean and the variance of the training data, which the Normalization layer uses to standardise the inputs.

data_augmentation.layers[0].adapt(x_train)
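What the adapted Normalization layer does can be illustrated with plain NumPy. This is a minimal sketch of the standardisation it applies, not the Keras implementation itself:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in "training data": 1,000 samples with mean ~5 and std ~2 per channel.
x = rng.normal(5.0, 2.0, size=(1000, 3)).astype(np.float32)

# adapt() stores the dataset mean and variance ...
mean = x.mean(axis=0)
var = x.var(axis=0)

# ... and the layer then applies (x - mean) / sqrt(var) to every input.
z = (x - mean) / np.sqrt(var + 1e-6)

print(z.mean(axis=0).round(4), z.std(axis=0).round(4))
```

After adaptation, the transformed data has approximately zero mean and unit variance per channel.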

Step 1.4: Visualizing images

Let’s see how the images in the dataset look.

import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")
plt.show()


The output shows a sample image from the dataset; since the images are only 32×32 pixels, it is not very sharp. Now we can proceed to our second step.

Step 2: Building network

In this step, we will build the components of the network: an MLP block and a layer that separates our images into patches. We will also use a patch encoder that projects each patch into a vector of size 64. Let’s start by building the MLP block.

Step 2.1: Building MLP network 

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In the above code, we have built an MLP block that simply alternates a dense layer (with GELU activation) and a dropout layer for each entry in hidden_units.
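The dense layers above use the GELU activation (tf.nn.gelu). Its exact form is x·Φ(x), where Φ is the standard normal CDF; a small NumPy sketch:

```python
import math
import numpy as np

def gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

# Unlike ReLU, GELU lets small negative values pass through slightly.
x = np.array([-2.0, 0.0, 1.0, 2.0])
print(gelu(x).round(4))
```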

Step 2.2: Patch maker

In this step, we define a layer that converts images into patches. For this, we mainly use TensorFlow’s tf.image.extract_patches function.

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")


In the above output, we can see that the image has been converted into patches, from which the vision transformer will learn to classify images.
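For non-overlapping patches (stride equal to the patch size, as used here), the same extraction can be sketched with NumPy reshapes. This is an illustrative stand-in that mirrors the shape of what tf.image.extract_patches produces, not the TensorFlow implementation:

```python
import numpy as np

image_size, patch_size, channels = 72, 6, 3
img = np.arange(image_size * image_size * channels, dtype=np.float32).reshape(
    image_size, image_size, channels
)

n = image_size // patch_size  # 12 patches along each side
# Split rows and columns into blocks, then flatten each 6x6x3 patch.
patches = (
    img.reshape(n, patch_size, n, patch_size, channels)
    .transpose(0, 2, 1, 3, 4)
    .reshape(n * n, patch_size * patch_size * channels)
)

print(patches.shape)  # (144, 108)
```

Each 72×72×3 image thus becomes a sequence of 144 tokens, each a 108-dimensional vector.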

Step 2.3: Patch encoder

This patch encoder will perform the linear transformation of the image patches and add a learnable position embedding to the projected vector.

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

After building this network we are ready to build a vision transformer model.
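Numerically, the patch encoder is just a shared linear projection plus a per-position vector. A NumPy sketch with random stand-in weights (the real Dense and Embedding weights are learned during training):

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, projection_dim = 144, 108, 64

patches = rng.normal(size=(num_patches, patch_dim)).astype(np.float32)

# Linear projection: one (108, 64) weight matrix shared by all patches.
W = rng.normal(size=(patch_dim, projection_dim)).astype(np.float32)
projected = patches @ W

# Learnable position embedding: one 64-d vector per patch index, added
# elementwise so the transformer can tell the otherwise order-less patches apart.
pos_embedding = rng.normal(size=(num_patches, projection_dim)).astype(np.float32)
encoded = projected + pos_embedding

print(encoded.shape)  # (144, 64)
```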

Step 3: Building vision transformer   

In this section, we will assemble the blocks of the vision transformer. As implemented above, the augmented data will go through the patch maker block and then through the patch encoder block. In each transformer block, we will apply a self-attention layer to the patch sequences. The output of the transformer blocks will go through a classification head that produces the final outputs. Let’s see this in the code below.

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

Using the above function, we can define a classifier based on the vision transformer that combines the data augmentation, patch making, and patch encoding defined earlier. The encoded patches serve as the image representation fed to the transformer blocks, and the Flatten layer reshapes the final representation for the classification head.
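The heart of each transformer block is the self-attention layer. A single-head sketch in NumPy (with hypothetical random weights; the actual MultiHeadAttention layer learns these projections and runs several heads in parallel):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
seq_len, dim = 144, 64  # num_patches x projection_dim

x = rng.normal(size=(seq_len, dim)).astype(np.float32)

# Query, key, and value projections (random stand-in weights).
Wq, Wk, Wv = (rng.normal(size=(dim, dim)).astype(np.float32) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv

# Scaled dot-product attention: each patch attends to every other patch.
weights = softmax(q @ k.T / np.sqrt(dim))  # (144, 144), rows sum to 1
out = weights @ v                          # (144, 64) mixed representation

print(out.shape)
```

Each output row is a weighted mixture of all patch representations, which is how the relationships between patch pairs are measured.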

Step 4: Compiling and training

In this section, we will compile and train the model we have created, and then evaluate it in terms of accuracy.

Step 4.1: Compiling the model

Using the lines of code below, we can compile the model.

optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)

model = create_vit_classifier()
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

In the compilation, we have used the AdamW optimizer with sparse categorical cross-entropy loss computed from the logits, tracking both accuracy and top-5 accuracy.
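The loss being minimised can be written out by hand. A NumPy sketch of sparse categorical cross-entropy computed from logits (illustrative, not the Keras implementation):

```python
import numpy as np

def sparse_categorical_crossentropy(logits, labels):
    # Numerically stable log-softmax, then pick out the log-probability
    # of each example's true class and average the negatives.
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])
labels = np.array([0, 1])  # true class indices, as in y_train
loss = sparse_categorical_crossentropy(logits, labels)
print(float(loss))
```

The "sparse" variant takes integer class indices directly, which is why we never one-hot encode y_train.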

Step 4.2: Training

Training of the transformer can be done using the following lines of codes:

history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
)


Running this starts the training, which may take a significant amount of time, so it is recommended to enable the GPU. In Google Colab, the GPU setting can be found under Runtime → Change runtime type.

Step 4.3: Checking accuracy

Let’s check the accuracy of the vision transformer in the image classification task.

_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")


Here in the above output, we can see that our model reaches an accuracy of 84.21% and a top-5 accuracy of 99.24%.
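The top-5 metric counts a prediction as correct when the true label appears among the five highest-scoring classes. A NumPy sketch (an illustrative helper, not the Keras metric):

```python
import numpy as np

def top_k_accuracy(logits, labels, k=5):
    # A prediction is correct if the true label is among the k top-scoring classes.
    topk = np.argsort(logits, axis=-1)[:, -k:]
    return float(np.mean([label in row for label, row in zip(labels, topk)]))

logits = np.array([[0.1, 0.9, 0.3],
                   [0.8, 0.1, 0.5]])
labels = np.array([1, 2])

print(top_k_accuracy(logits, labels, k=1))  # 0.5 -- only the first example's argmax matches
print(top_k_accuracy(logits, labels, k=2))  # 1.0 -- class 2 is the runner-up in the second example
```

This is why the top-5 figure is always at least as high as the plain accuracy.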

Final words

In this article, we introduced the vision transformer and saw how it can be used for image classification. We implemented a vision transformer for image classification on the CIFAR-10 dataset, following all the steps shown in the procedure image, and achieved a very good result.


Yugesh Verma
Yugesh is a graduate in automobile engineering and worked as a data analyst intern. He completed several Data Science projects. He has a strong interest in Deep Learning and writing blogs on data science and machine learning.
