Complete Guide to ViT-AugReg: Descriptive Predictions with a PyTorch Image Model

This article explores Transformer-based image models and how they work. It then uses Vision Transformer (ViT) checkpoints trained with AugReg (data augmentation and regularization) to build an image label predictor, first with the vit_jax code and then with the PyTorch Image Models (timm) library.

We live in an era in which vast amounts of data are generated every day. It is fair to call it the era of big data: every area of science and industry now depends on large datasets and the technologies built around them. This abundance also poses unprecedented challenges for analysis and interpretation, which is why there is an urgent need for new machine learning and artificial intelligence methods that can make proper use of the data. Deep learning (DL) is one such method, and it currently receives a great deal of attention. DL can be described as a family of learning algorithms that let our systems learn complex prediction models, and it has been applied successfully to many problems.

Deep learning models offer a new learning paradigm in artificial intelligence (AI) and machine learning (ML). Recent breakthroughs in image analysis and speech recognition have generated massive interest in the field, since applications in many other data-rich domains now seem possible. The mathematical and computational methodology behind deep learning can be challenging, especially for interdisciplinary scientists. Yet its core architectures underpin most deep learning models in use today and belong in every data scientist's toolbox, because these building blocks can be composed flexibly into new, application-specific network architectures.

Data analysis needs vary from company to company, so a data model must always be designed around the requirements at hand. Predictive modeling is a major branch of data analytics that uses data mining and probabilistic methods to predict outcomes. Each model is built from a number of predictors, that is, variables likely to influence future outcomes. Once data has been collected for a specific kind of prediction, an analytical model is formulated; it may be a simple linear equation or a complex neural network, implemented in suitable software. If additional data becomes available, the model is revised. Predictive modeling also uses regression algorithms and statistical methods to estimate the probability of an event, draws on detection theory, and is widely employed in fields related to AI.

PyTorch is an optimized tensor library for deep learning that can run computations on both GPUs and CPUs. It is an open-source machine learning library for Python, developed by Facebook's AI Research team, and is among the most widely used machine learning libraries alongside TensorFlow and Keras.

What Are Image Transformers?

The Image Transformer is a model based entirely on the self-attention mechanism, in which the encoder produces a representation for each channel of each pixel of the source image. Image Transformer models are usually trained on images from standard datasets such as ImageNet and require comparatively modest resources to train. Many applications of image models require conditioning on additional information of various kinds: on images in enhancement or reconstruction tasks such as super-resolution, inpainting and denoising, or on text when synthesizing images from natural language descriptions.

In video tasks, Transformer-based image generation models can predict future frames from the previous frames and the actions taken. The Image Transformer treats pixel intensities either as discrete categories or as ordinal values; which setting works better depends on the distribution of the image data. Both the encoder and the decoder of the Image Transformer use stacks of self-attention and position-wise feed-forward layers, and the decoder additionally uses an attention mechanism that takes the encoder's representation as input.
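To make the two pixel representations concrete, here is a minimal NumPy sketch. It is an illustration, not code from the Image Transformer paper; the 8-bit assumption and the function names `as_categories` and `as_ordinal` are hypothetical. As categories, each channel value becomes one of 256 classes with no notion of distance between them; as ordinal values, intensities keep their natural order, so 200 sits closer to 201 than to 0.

```python
import numpy as np

NUM_LEVELS = 256  # assumption: 8-bit images with intensities 0..255


def as_categories(img: np.ndarray) -> np.ndarray:
    """Encode each channel value as one of 256 discrete classes (one-hot)."""
    # Output shape is (H, W, C, 256); class indices carry no notion of distance.
    return np.eye(NUM_LEVELS, dtype=np.float32)[img]


def as_ordinal(img: np.ndarray) -> np.ndarray:
    """Keep intensities as ordered scalars, rescaled to [0, 1]."""
    return img.astype(np.float32) / (NUM_LEVELS - 1)


# A tiny 2x2 RGB image with 8-bit values.
img = np.array([[[0, 128, 255], [10, 20, 30]],
                [[200, 201, 202], [255, 0, 0]]], dtype=np.uint8)

cat = as_categories(img)
ordv = as_ordinal(img)
print(cat.shape)  # (2, 2, 3, 256)

# As categories, 200 is exactly as far from 201 as it is from 0 ...
d_close = np.linalg.norm(cat[1, 0, 0] - cat[1, 0, 1])  # 200 vs 201
d_far = np.linalg.norm(cat[1, 0, 0] - cat[0, 0, 0])    # 200 vs 0
# ... while as ordinal values, 200 is much closer to 201 than to 0.
print(d_close == d_far,
      abs(ordv[1, 0, 0] - ordv[1, 0, 1]) < abs(ordv[1, 0, 0] - ordv[0, 0, 0]))
```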

For unconditional and class-conditional image generation, the Image Transformer is used in a decoder-only configuration. Each self-attention layer computes a d-dimensional representation for each position, that is, for each channel of each image pixel. To recompute the representation for a given position, the layer first compares the position's current representation with the representations of the other positions, obtaining an attention distribution over them. This distribution then determines how much each of the other positions' representations contributes to the next representation of the position at hand.

Vision Transformers (ViT) have been shown to achieve highly competitive performance across a wide range of computer vision tasks, such as image classification, object detection and semantic image segmentation. However, when trained on smaller datasets, ViT models have been found to rely more heavily on model regularization and data augmentation, "AugReg" for short.

Image Source: Original Paper
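The self-attention step described above (compare a position with all others, normalize the scores into an attention distribution, then take a weighted sum) can be sketched as single-head scaled dot-product self-attention in NumPy. This is a simplified illustration, not the Image Transformer's or ViT's implementation: the projection matrices are random stand-ins for learned weights, there is no masking or local attention window, and the sizes are arbitrary. Each row of `x` stands for one position with a d-dimensional representation.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention.

    x: (num_positions, d) current representations, one row per position.
    Returns the new representations and the attention distributions.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    d = q.shape[-1]
    # Compare each position's query with every position's key.
    scores = q @ k.T / np.sqrt(d)
    # Each row becomes a probability distribution over positions.
    attn = softmax(scores, axis=-1)
    # Weight the positions' values by that distribution.
    return attn @ v, attn


rng = np.random.default_rng(0)
num_positions, d = 6, 8  # e.g. 2 pixels x 3 channels, d = 8 (illustrative sizes)
x = rng.normal(size=(num_positions, d))
# Random stand-ins for learned projection weights.
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (6, 8) (6, 6)
```

Each row of `attn` sums to 1, so every new representation is a convex combination of the value vectors of all positions.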

Getting Started with the Code for ViT-AugReg

In this article, we will generate a descriptive prediction for an image from a pets dataset using the vision_transformer library. We will predict the dog breed shown in an image and label it using a Vision Transformer. The following code is inspired by the library's creators, whose GitHub repository can be accessed here.

Installing the Library

To create this prediction model, we will first install the vision_transformer library. The following code can be used to do so:

# Clone the vision_transformer repository (skipped if it already exists).
![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer

Next, install the remaining dependencies:

# Install dependencies.
!pip install -qr vision_transformer/vit_jax/requirements.txt

Next, import the vision_transformer modules, including the AugReg configurations:

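import sys
# Make the cloned repository importable.
if './vision_transformer' not in sys.path:
  sys.path.append('./vision_transformer')

# Automatically reload edited modules (IPython/Colab magics).
%load_ext autoreload
%autoreload 2

from vit_jax import checkpoint
from vit_jax import models
from vit_jax import train
from vit_jax.configs import augreg as augreg_config
from vit_jax.configs import models as models_config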

Import the dependencies for analysis:

import glob
import os
import random
import shutil
import time
from absl import logging
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
pd.options.display.max_colwidth = None

Loading the Data

Now we will load the index table of AugReg checkpoints from the vit_models Google Cloud Storage bucket.

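# Load the master table from Cloud Storage.
with tf.io.gfile.GFile('gs://vit_models/augreg/index.csv') as f:
  df = pd.read_csv(f)

# Report how many rows were loaded.
print(f'loaded {len(df):,} rows')

# Number of distinct pretrained and fine-tuned checkpoints.
print(len(set(df.filename)), len(set(df.adapt_filename)))

# For each model, keep the ImageNet-21k pretrained checkpoint with the best
# final validation score (assumes 'ds' and 'name' columns in index.csv).
best_filenames = set(
    df.query('ds=="i21k"')
    .groupby('name')
    .apply(lambda g: g.sort_values('final_val').iloc[-1])
    .filename
)
# Select all fine-tuned checkpoints derived from these models.
best_df = df.loc[df.filename.apply(lambda filename: filename in best_filenames)]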
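The selection step keeps, for each model, the pretrained checkpoint with the highest `final_val`. The sketch below shows the same idea on a small made-up table; the values and the `name` grouping column are hypothetical, not real rows of `index.csv`. It uses `groupby(...)['final_val'].idxmax()`, which returns the row label of each group's best score without sorting every group, and `Series.isin` in place of the `apply` lambda.

```python
import pandas as pd

# Hypothetical stand-in for the checkpoint index (not real data).
df = pd.DataFrame({
    'name': ['Ti/16', 'Ti/16', 'S/16', 'S/16', 'S/16'],
    'filename': ['ti-a', 'ti-b', 's-a', 's-b', 's-c'],
    'final_val': [0.61, 0.64, 0.70, 0.73, 0.71],
})

# Row label of the best final_val within each model group.
best_idx = df.groupby('name')['final_val'].idxmax()
best_filenames = set(df.loc[best_idx, 'filename'])
print(sorted(best_filenames))  # ['s-b', 'ti-b']

# Same filter as in the article: keep rows whose filename is one of the best.
best_df = df[df['filename'].isin(best_filenames)]
```

In the real index, several fine-tuned rows can share one pretrained `filename`, so the final filter typically keeps more rows than there are models.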

Now that the checkpoint index is loaded and filtered, we can create the predictor model.

Creating The Predictor Model

We will start creating the predictor model by loading the pets image dataset. The following code can be used to do so:

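# Not shown in the article: `filename` (a fine-tuned checkpoint name taken from
# best_df.adapt_filename), `tfds_name` (the dataset it was fine-tuned on),
# `model_config` and `resolution` must be defined before this cell.

# Load the image dataset the checkpoint was fine-tuned on.
ds, ds_info = tfds.load(tfds_name, with_info=True)

# Get a model instance (no weights yet), then load the checkpoint weights.
model = models.VisionTransformer(
    num_classes=ds_info.features['label'].num_classes, **model_config)
params = checkpoint.load(f'gs://vit_models/augreg/{filename}.npz')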

Next, we take a single example from the test split of the pets dataset to run our prediction on:

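# Take the first example from the test split.
d = next(iter(ds['test']))

# Simple preprocessing: scale pixels to [0, 1] and resize to the model's resolution.
def pp(img, sz):
  img = tf.cast(img, float) / 255.0
  img = tf.image.resize(img, [sz, sz])
  return img

# Display the preprocessed image.
plt.imshow(pp(d['image'], resolution));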


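# Run the ViT-AugReg model on a batch containing the single image.
logits, = model.apply({'params': params}, [pp(d['image'], resolution)], train=False)

# Plot the logit for each label (apply tf.nn.softmax to get probabilities).
plt.figure(figsize=(10, 4))
plt.bar(list(map(ds_info.features['label'].int2str, range(len(logits)))), logits)
plt.xticks(rotation=90);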

Output:

As we can see, the model predicts the dog breed to be Leonberger. Let's compare the result with a reference image of a Leonberger.

Image Source

As we can observe, our predictor model appears to have predicted the dog breed label correctly for our sample image!
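The bar chart shows raw logits. To read off the predicted label numerically and turn the scores into probabilities, one can apply a softmax and keep the top-k entries. The sketch below uses NumPy with made-up logits and a hypothetical four-name label list; in the notebook, you would pass the `logits` array from the cell above and the full list of breed names from `ds_info.features['label'].int2str`.

```python
import numpy as np


def top_k(logits, labels, k=3):
    """Return the k most likely (label, probability) pairs for one example."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # softmax, shifted for stability
    probs /= probs.sum()
    order = np.argsort(probs)[::-1][:k]    # indices of the largest probabilities
    return [(labels[i], float(probs[i])) for i in order]


# Hypothetical labels and logits, for illustration only.
labels = ['beagle', 'leonberger', 'pug', 'samoyed']
logits = [1.0, 4.0, 0.5, 2.0]

for label, p in top_k(logits, labels):
    print(f'{label:12s} {p:.3f}')
```

Because softmax is monotonic, the top label is the same as the argmax of the raw logits; the probabilities only add a sense of how confident the prediction is.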

The same checkpoint can also be loaded into PyTorch Image Models (timm), a popular PyTorch library of image models. Let's try it with timm:

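# Install and import the timm (PyTorch Image Models) library.
!pip install timm
import timm
import torch

# Create the matching timm architecture with the dataset's number of classes.
timm_model = timm.create_model(
    'vit_small_r26_s32_384', num_classes=ds_info.features['label'].num_classes)

# timm loads .npz checkpoints from a local file, so copy it from Cloud Storage first.
if not tf.io.gfile.exists(f'{filename}.npz'):
  tf.io.gfile.copy(f'gs://vit_models/augreg/{filename}.npz', f'{filename}.npz')
timm.models.load_checkpoint(timm_model, f'{filename}.npz')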

Next, preprocess the image for PyTorch and run it through the timm model:

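# Reuse pp(), then convert from TensorFlow's (H, W, C) layout to PyTorch's (N, C, H, W).
def pp_torch(img, sz):
  img = pp(img, sz)
  img = img.numpy().transpose([2, 0, 1])  # (H, W, C) -> (C, H, W)
  return torch.tensor(img[None])          # add a batch axis of size 1

# Run inference in evaluation mode without gradient tracking.
timm_model.eval()
with torch.no_grad():
  logits, = timm_model(pp_torch(d['image'], resolution)).detach().numpy()

# Visualize the timm results the same way as before.
plt.figure(figsize=(10, 4))
plt.bar(list(map(ds_info.features['label'].int2str, range(len(logits)))), logits)
plt.xticks(rotation=90);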

We can observe that our model still gives us a correctly predicted label.
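Checking only the top label is a coarse test. Since both frameworks load the same checkpoint, their logits should agree closely, up to small floating-point differences. A minimal NumPy sketch of such a check, assuming the JAX and timm logits were kept in two separate variables (the notebook cells above both write to `logits`); the helper name `compare_logits`, the tolerance and the example arrays are made up for illustration:

```python
import numpy as np


def compare_logits(a, b, atol=1e-3):
    """Compare two logit vectors for the same image from different frameworks."""
    a, b = np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64)
    same_top1 = int(a.argmax()) == int(b.argmax())
    max_abs_diff = float(np.abs(a - b).max())
    return same_top1, max_abs_diff, bool(np.allclose(a, b, atol=atol))


# Hypothetical outputs: small numerical differences are expected between
# frameworks (different kernels, float32 arithmetic).
jax_logits = [0.2, 5.1, -1.3, 0.7]
timm_logits = [0.2004, 5.0998, -1.3001, 0.7002]
print(compare_logits(jax_logits, timm_logits))
```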


In this article, we explored how Transformer-based image models work and used Vision Transformer checkpoints trained with AugReg to build an image label predictor, first with the vit_jax code and then with the PyTorch Image Models (timm) library. The full implementation is available as a Colab notebook, which can be accessed using the link here.

Happy Learning!



Victor Dey
Victor is an aspiring Data Scientist & is a Master of Science in Data Science & Big Data Analytics. He is a Researcher, a Data Science Influencer and also an Ex-University Football Player. A keen learner of new developments in Data Science and Artificial Intelligence, he is committed to growing the Data Science community.
