PyTorch Code for Self-Attention Computer Vision

As discussed in one of our articles, self-attention is steadily gaining a prominent place in applications ranging from sequence modeling in natural language processing to medical image segmentation. It has replaced conventional recurrent and convolutional neural networks in many applications, achieving new state-of-the-art results in the respective fields. Transformers, along with their variants and extensions, are built around self-attention mechanisms.

Self-Attention Computer Vision, available as the Python package self_attention_cv, is a PyTorch-based library that aims to be a one-stop solution for self-attention-based requirements. It includes a variety of self-attention layers, blocks, and pre-trained models that can be used in any custom architecture. Rather than building self-attention layers or blocks from scratch, users can assemble models with it in no time. Heavy pre-trained models such as TransUNet and ViT can also be incorporated into custom models and can finish training in minimal time, even in a CPU environment. According to its contributors, Adaloglou Nicolas and Sergios Karagiannakos, the library is still under development and is being updated with the latest models and architectures.

In the rest of this article, we explore the self-attention-based layers, blocks, and pre-trained models available in the library using simple, randomly generated inputs. The setup requires a PyTorch environment and Python 3.6+ to train and evaluate the models.

The following command installs the self_attention_cv library and its dependencies. The leading ! runs it from a Jupyter or Colab notebook cell; omit it in a regular shell.

!pip install self_attention_cv

Since the library and its modules run on top of the PyTorch framework, we need to import the framework.

import torch

Self-attention-based layers, blocks, and models are provided as modules of the self_attention_cv library and can be imported as needed.

Multi-head Attention

According to the authors of the paper Attention Is All You Need:

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values. The weight assigned to each value is computed by a compatibility function of the query with the corresponding key. 
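To make this definition concrete, here is a minimal sketch of scaled dot-product attention in plain PyTorch. It is independent of self_attention_cv, and the function name is our own. The compatibility function is a dot product scaled by the square root of the key dimension, and the output is a softmax-weighted sum of the values.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """q: [tokens_q, d], k: [tokens_k, d], v: [tokens_k, d_v]."""
    # Compatibility of every query with every key, scaled by sqrt(d).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # [tokens_q, tokens_k]
    weights = torch.softmax(scores, dim=-1)                      # each row sums to 1
    # Output: weighted sum of the values, one row per query.
    return weights @ v, weights

torch.manual_seed(0)
q = torch.rand(4, 8)   # 4 queries of dimension 8
k = torch.rand(6, 8)   # 6 keys of dimension 8
v = torch.rand(6, 16)  # 6 values of dimension 16
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([4, 16])
```

In self-attention, q, k, and v are all linear projections of the same token sequence.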

A single self-attention layer and a multi-head self-attention layer (Source).

In the encoder-decoder attention of the original Transformer, queries come from the previous decoder layer, while the memory keys and values come from the encoder's output. In self-attention, queries, keys, and values are all computed from the same input sequence. These inputs are vectors (stacked into matrices) representing text tokens in text processing, image patches in image processing, and sequences of frames in video processing. A multi-head self-attention layer runs several single self-attention layers (heads) in parallel, then concatenates and linearly projects their outputs. Transformers rely heavily on multi-head self-attention at every stage of their architecture. The following code demonstrates the multi-head self-attention module with randomly generated tokens, each of dimension 64.

 from self_attention_cv import MultiHeadSelfAttention
 model = MultiHeadSelfAttention(dim=64)
 x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
 # tokens x tokens mask; a boolean dtype is used because recent PyTorch
 # versions reject float masks in masked_fill (assumed to be used internally)
 mask = torch.zeros(10, 10, dtype=torch.bool)
 mask[5:8, 5:8] = True
 y = model(x, mask)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first token/patch in the first batch \n')
 print(y.detach().numpy()[0][0])
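For intuition about what such a module computes, the sketch below is a hypothetical, from-scratch multi-head self-attention layer. TinyMHSA is our own name and this is not the library's implementation. Queries, keys, and values are projections of the same input, each head attends independently, and the heads are concatenated and projected back to dim. The mask convention (True means masked out) is an assumption of this sketch.

```python
import torch
import torch.nn as nn

class TinyMHSA(nn.Module):
    """Hypothetical minimal multi-head self-attention, for illustration only."""

    def __init__(self, dim, heads=8):
        super().__init__()
        assert dim % heads == 0, "dim must be divisible by heads"
        self.heads, self.dim_head = heads, dim // heads
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=False)  # q, k, v from the same input
        self.out = nn.Linear(dim, dim)                      # mixes the concatenated heads

    def forward(self, x, mask=None):
        b, t, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # [b, t, d] -> [b, heads, t, dim_head]: each head attends independently
        q, k, v = (z.view(b, t, self.heads, self.dim_head).transpose(1, 2) for z in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.dim_head ** 0.5  # [b, heads, t, t]
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))     # True = masked out
        attn = torch.softmax(scores, dim=-1)
        y = (attn @ v).transpose(1, 2).reshape(b, t, d)           # concatenate heads
        return self.out(y)

model = TinyMHSA(dim=64, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)
mask[5:8, 5:8] = True
y = model(x, mask)
print(y.shape)  # torch.Size([16, 10, 64])
```

Splitting dim across heads keeps the total cost close to that of a single head of width dim while letting each head learn a different attention pattern.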

Axial Attention

Axial attention factorizes self-attention over a multidimensional input into a sequence of attention operations, each along a single axis (for an image, along its height and then its width). It was introduced in Axial Transformers, autoregressive models for high-dimensional data such as high-resolution images. The following code demonstrates the axial attention block with a randomly generated 256-channel feature map of size 64 by 64.

 # Axial Attention
 from self_attention_cv import AxialAttentionBlock
 model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
 x = torch.rand(1, 256, 64, 64)  # [batch, channels, height, width]
 y = model(x)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first channel of the first image \n')
 print(y.detach().numpy()[0][0]) 
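As an illustrative sketch of the axial idea, the hypothetical module below applies PyTorch's nn.MultiheadAttention first along the width of each row and then along the height of each column. It is not the library's AxialAttentionBlock, which may differ, for example in its use of position encodings.

```python
import torch
import torch.nn as nn

class AxialSketch(nn.Module):
    """Hypothetical sketch: attention along the width axis, then along the height axis."""

    def __init__(self, channels, heads=8):
        super().__init__()
        self.width_attn = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.height_attn = nn.MultiheadAttention(channels, heads, batch_first=True)

    def forward(self, x):  # x: [b, c, h, w]
        b, c, h, w = x.shape
        # Along width: each of the b*h rows is a sequence of w tokens.
        rows = x.permute(0, 2, 3, 1).reshape(b * h, w, c)
        rows, _ = self.width_attn(rows, rows, rows)
        x = rows.reshape(b, h, w, c)
        # Along height: each of the b*w columns is a sequence of h tokens.
        cols = x.permute(0, 2, 1, 3).reshape(b * w, h, c)
        cols, _ = self.height_attn(cols, cols, cols)
        return cols.reshape(b, w, h, c).permute(0, 3, 2, 1)  # back to [b, c, h, w]

x = torch.rand(1, 32, 16, 16)  # small input to keep the sketch fast
y = AxialSketch(channels=32, heads=4)(x)
print(y.shape)  # torch.Size([1, 32, 16, 16])
```

Each position attends to w + h positions instead of all h·w, so the cost drops from O((hw)²) to O(hw(h + w)), while two passes still let information flow between any two positions.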

Bottleneck Attention block

Bottleneck Transformers (BoTNet) replace the spatial 3x3 convolutions in the last bottleneck blocks of a ResNet with multi-head self-attention layers and have been applied to multiple computer vision tasks. The whole bottleneck block is available as a module in the library. The following code demonstrates the Bottleneck block with a randomly generated 512-channel feature map of size 32 by 32.

 from self_attention_cv.bottleneck_transformer import BottleneckBlock
 x = torch.rand(1, 512, 32, 32)
 bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4, out_channels=1024, pooling=True)
 y = bottleneck_block(x)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first row of the first output channel, first image \n')
 print(y.detach().numpy()[0][0][0]) 
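To show the core idea, here is a hypothetical simplified bottleneck block in which global multi-head self-attention over all feature-map positions takes the place of the spatial convolution between two 1x1 convolutions. It omits the batch normalization and 2D relative position encodings of the actual BoTNet design, and it is not the library's BottleneckBlock.

```python
import torch
import torch.nn as nn

class MHSA2d(nn.Module):
    """Hypothetical stand-in for the spatial 3x3 conv: attention over all pixels."""

    def __init__(self, channels, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)

    def forward(self, x):  # x: [b, c, h, w]
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)        # [b, h*w, c]: one token per pixel
        out, _ = self.attn(tokens, tokens, tokens)   # global attention
        return out.transpose(1, 2).reshape(b, c, h, w)

class BottleneckSketch(nn.Module):
    """1x1 reduce -> MHSA -> 1x1 expand, plus a residual shortcut (no BN, no pos. enc.)."""

    def __init__(self, in_channels, mid_channels, out_channels, heads=4):
        super().__init__()
        self.reduce = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.mhsa = MHSA2d(mid_channels, heads)
        self.expand = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
        # 1x1 projection on the shortcut when the channel count changes
        self.shortcut = (nn.Conv2d(in_channels, out_channels, kernel_size=1)
                         if in_channels != out_channels else nn.Identity())
        self.act = nn.ReLU()

    def forward(self, x):
        y = self.act(self.reduce(x))
        y = self.act(self.mhsa(y))
        y = self.expand(y)
        return self.act(y + self.shortcut(x))

x = torch.rand(1, 512, 8, 8)
y = BottleneckSketch(512, 128, 1024)(x)
print(y.shape)  # torch.Size([1, 1024, 8, 8])
```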

Vanilla Transformer Encoder

The encoder of the original Transformer architecture is available as the TransformerEncoder module. The following code demonstrates its usage with randomly generated tokens of dimension 64.

 # Transformer Encoder
 from self_attention_cv import TransformerEncoder
 model = TransformerEncoder(dim=64, blocks=6, heads=8)
 x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
 # boolean tokens x tokens mask (recent PyTorch versions reject float masks in masked_fill)
 mask = torch.zeros(10, 10, dtype=torch.bool)
 mask[5:8, 5:8] = True
 y = model(x, mask)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first token/patch in the first batch \n')
 print(y.detach().numpy()[0][0]) 
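For comparison, PyTorch ships its own encoder, nn.TransformerEncoder. The sketch below builds a 6-layer, 8-head encoder of width 64 with it. The values dim_feedforward=256 and dropout=0.0 are our own choices and are not necessarily equivalent to the library's defaults. For a boolean mask, PyTorch treats True as "not allowed to attend".

```python
import torch
import torch.nn as nn

# One encoder layer = multi-head self-attention + position-wise feed-forward block,
# each wrapped with a residual connection and layer normalization.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=8, dim_feedforward=256,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)  # 6 stacked copies of `layer`

x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)
mask[5:8, 5:8] = True       # True: these query/key pairs cannot attend to each other
y = encoder(x, mask=mask)
print(y.shape)  # torch.Size([16, 10, 64])
```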

Vision Transformer

Vision Transformer (ViT) became popular across a wide range of computer vision tasks, achieving state-of-the-art performance in many applications at the time of its publication. Though some newer architectures outperform ViT, many of them build on it. The basic ViT is available as a module, so it can be used directly in any custom architecture. The following code demonstrates the module's usage with randomly generated 3-channel color images of size 256 by 256 in a 10-class classification problem.

 from self_attention_cv import ViT
 model = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10,dim=512)
 x = torch.rand(2, 3, 256, 256)
 y = model(x) # [2,10]
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first image \n')
 print(y.detach().numpy()[0]) 
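The first step inside a ViT is turning an image into a sequence of tokens. The sketch below shows the standard patch embedding in plain PyTorch with the sizes used above; it is illustrative, not the library's code. A 256 by 256 image cut into 16 by 16 patches yields (256 / 16)² = 256 tokens, plus a learnable [CLS] token and learnable position embeddings. The transformer encoder itself is omitted, and a linear head on the [CLS] token produces the 10 class logits.

```python
import torch
import torch.nn as nn

img_dim, patch_dim, in_channels, dim, num_classes = 256, 16, 3, 512, 10
num_patches = (img_dim // patch_dim) ** 2  # 16 * 16 = 256 patches

# A conv whose kernel size equals its stride cuts the image into non-overlapping
# patches and applies the same linear projection to each patch.
patch_embed = nn.Conv2d(in_channels, dim, kernel_size=patch_dim, stride=patch_dim)
cls_token = nn.Parameter(torch.zeros(1, 1, dim))                     # learnable [CLS] token
pos_emb = nn.Parameter(torch.randn(1, num_patches + 1, dim) * 0.02)  # learnable positions
head = nn.Linear(dim, num_classes)                                   # classification head

x = torch.rand(2, in_channels, img_dim, img_dim)
tokens = patch_embed(x).flatten(2).transpose(1, 2)                   # [2, 256, 512]
tokens = torch.cat([cls_token.expand(x.shape[0], -1, -1), tokens], dim=1) + pos_emb
# ... a transformer encoder would process `tokens` here (omitted for brevity) ...
logits = head(tokens[:, 0])                                          # read from [CLS]
print(tokens.shape, logits.shape)  # torch.Size([2, 257, 512]) torch.Size([2, 10])
```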

Vision Transformer with ResNet50 backbone

A hybrid Vision Transformer that uses a ResNet50 backbone to extract features before the transformer layers performs well on many computer vision tasks. The following code demonstrates the corresponding module's usage with randomly generated 3-channel color images of size 256 by 256 in a 10-class classification problem.

 from self_attention_cv import ResNet50ViT
 model = ResNet50ViT(img_dim=256, pretrained_resnet=False, 
                         blocks=6, num_classes=10, 
                         dim_linear_block=256, dim=256)
 x = torch.rand(2, 3, 256, 256)
 y = model(x) # [2,10]
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first image \n')
 print(y.detach().numpy()[0]) 
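The hybrid design differs from plain ViT mainly in where the tokens come from: instead of raw pixel patches, each token is one cell of a CNN feature map. The sketch below uses a hypothetical four-layer strided convolutional stem as a lightweight stand-in for ResNet50, reducing a 256 by 256 image to a 16 by 16 map. This stem is an assumption for illustration, not the library's backbone.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for ResNet50: four stride-2 convs, overall stride 16.
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),    # 256 -> 128
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),  # 128 -> 64
    nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.ReLU(), # 64 -> 32
    nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.ReLU(), # 32 -> 16
)
proj = nn.Linear(256, 256)  # CNN channels -> transformer width (dim=256 in the example above)

x = torch.rand(2, 3, 256, 256)
fmap = stem(x)                                  # [2, 256, 16, 16]
tokens = proj(fmap.flatten(2).transpose(1, 2))  # [2, 256, 256]: one token per feature-map cell
print(fmap.shape, tokens.shape)
```

From here on, the tokens are processed exactly as in a plain ViT; the CNN supplies local features that the transformer then relates globally.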

TransUNet

TransUNet was a state-of-the-art architecture for medical image segmentation at the time of writing, and it is available as a module in the self_attention_cv library. The following code demonstrates the module's usage with randomly generated 3-channel color images of size 128 by 128. The model outputs one score map per class at the same spatial resolution as the input, here [2, 5, 128, 128] for 2 images and 5 classes.

 from self_attention_cv.transunet import TransUnet
 x = torch.rand(2, 3, 128, 128)
 model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
                   vit_dim_linear_mhsa_block=512, classes=5)
 y = model(x) # [2, 5, 128, 128]
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Scores for the first class of the first image \n')
 print(y.detach().numpy()[0][0]) 
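The output holds unnormalized per-pixel class scores. A minimal sketch of how such logits are typically turned into a predicted mask and a training loss could look like this; it uses a random stand-in tensor instead of the model above and hypothetical random labels.

```python
import torch
import torch.nn as nn

batch, num_classes, h, w = 2, 5, 128, 128
# Random stand-in for the TransUnet output: unnormalized scores [batch, classes, H, W].
logits = torch.randn(batch, num_classes, h, w)

# Per-pixel class probabilities and the predicted label map.
probs = torch.softmax(logits, dim=1)  # softmax over the class dimension
pred_mask = probs.argmax(dim=1)       # [2, 128, 128], values in 0..4

# Loss against hypothetical ground-truth labels of shape [batch, H, W].
target = torch.randint(0, num_classes, (batch, h, w))
loss = nn.CrossEntropyLoss()(logits, target)  # applies log-softmax over dim=1 internally
print(pred_mask.shape, loss.item())
```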

1D Absolute Positional Embeddings

Self-attention on its own has no notion of token order, so positional information is supplied through positional embeddings, which come in two forms: absolute and relative. Position-aware self-attention models are reported to improve both performance and memory efficiency. The Self-Attention Computer Vision library has separate modules for absolute and relative positional embeddings for 1D and 2D sequential data. The following code demonstrates a 1-dimensional absolute positional embedding for a sequence of 20 tokens, with 3 heads and a per-head dimension of 64.

 from self_attention_cv.pos_embeddings import AbsPosEmb1D
 model = AbsPosEmb1D(tokens=20, dim_head=64)
 # batch heads tokens dim_head
 x = torch.rand(2, 3, 20, 64)
 y = model(x)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first token in the first head, first batch \n')
 print(y.detach().numpy()[0][0][0]) 
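One common way to use absolute positional embeddings inside attention is to keep one learnable vector per position and score each query against every position. That this is what modules like AbsPosEmb1D compute is an assumption of this sketch. The code below computes these position logits with torch.einsum and adds them to the content logits q·kᵀ before the softmax.

```python
import torch

batch, heads, tokens, dim_head = 2, 3, 20, 64
scale = dim_head ** -0.5

# Hypothetical learnable table: one vector per absolute position,
# shared across heads (an assumption of this sketch).
abs_pos_emb = torch.nn.Parameter(torch.randn(tokens, dim_head) * scale)

q = torch.rand(batch, heads, tokens, dim_head)
k = torch.rand(batch, heads, tokens, dim_head)

# Position logits: how strongly query i responds to absolute position j.
pos_logits = torch.einsum("bhid,jd->bhij", q, abs_pos_emb)      # [2, 3, 20, 20]
content_logits = torch.einsum("bhid,bhjd->bhij", q, k) * scale  # [2, 3, 20, 20]
attn = torch.softmax(content_logits + pos_logits, dim=-1)       # rows sum to 1
print(pos_logits.shape)  # torch.Size([2, 3, 20, 20])
```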

1D Relative Positional Embeddings

Relative positional embeddings have been shown to outperform absolute positional embeddings in neural machine translation. The following code demonstrates a 1-dimensional relative positional embedding for 20 tokens, with 3 heads and a per-head dimension of 64.

 from self_attention_cv.pos_embeddings import RelPosEmb1D
 model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
 x = torch.rand(2, 3, 20, 64)
 y = model(x)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first token in the first head, first batch \n')
 print(y.detach().numpy()[0][0][0]) 
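Relative embeddings index a learnable table by the offset j - i between key and query positions rather than by absolute position. As a result, every pair of tokens at the same distance shares the same vector. The sketch below builds the offset index matrix explicitly and gathers from the table, similar in spirit to Shaw et al.'s relative position representations. The library's internal implementation may differ.

```python
import torch

tokens, dim_head = 20, 64
# Hypothetical table: one learnable vector per offset j - i in [-(tokens-1), tokens-1].
rel_table = torch.nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * dim_head ** -0.5)

pos = torch.arange(tokens)
rel_idx = pos[None, :] - pos[:, None] + tokens - 1  # [tokens, tokens], shifted to be >= 0
rel_emb = rel_table[rel_idx]                        # [tokens, tokens, dim_head]

q = torch.rand(2, 3, tokens, dim_head)              # [batch, heads, tokens, dim_head]
rel_logits = torch.einsum("bhid,ijd->bhij", q, rel_emb)  # [2, 3, 20, 20]
print(rel_logits.shape)
```

Because only offsets matter, the pair (3, 5) and the pair (10, 12) use the same embedding.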

2D Relative Positional Embeddings

The following code demonstrates the 2-dimensional relative positional embedding module for a 32 by 32 feature map flattened into 1,024 tokens, with 4 heads and a per-head dimension of 128.

 from self_attention_cv.pos_embeddings import RelPosEmb2D
 dim = 32  # spatial dim of the feat map
 model = RelPosEmb2D(
     feat_map_size=(dim, dim),
     dim_head=128)
 x = torch.rand(2, 4, dim*dim, 128)
 y = model(x)
 print('Shape of output is: ', y.shape)
 print('-'*70)
 print('Output corresponding to the first patch in the first head, first batch \n')
 print(y.detach().numpy()[0][0][0]) 
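In 2D, a common factorization, used for example in Bottleneck Transformers, splits each relative offset into a row offset and a column offset, each with its own learnable table, and sums the two logit terms. The sketch below illustrates this on a 4 by 4 map to keep the gathered tensors small. It is an assumption about the technique, not the library's RelPosEmb2D code.

```python
import torch

h = w = 4        # small map for readability (the example above uses 32 x 32)
dim_head = 128
# Hypothetical separate tables for row offsets and column offsets.
rel_h = torch.nn.Parameter(torch.randn(2 * h - 1, dim_head) * dim_head ** -0.5)
rel_w = torch.nn.Parameter(torch.randn(2 * w - 1, dim_head) * dim_head ** -0.5)

ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
ys, xs = ys.flatten(), xs.flatten()    # (row, col) of each of the h*w tokens
dy = ys[None, :] - ys[:, None] + h - 1 # [h*w, h*w] row-offset indices
dx = xs[None, :] - xs[:, None] + w - 1 # [h*w, h*w] column-offset indices

q = torch.rand(2, 4, h * w, dim_head)  # [batch, heads, tokens, dim_head]
logits = (torch.einsum("bhid,ijd->bhij", q, rel_h[dy])
          + torch.einsum("bhid,ijd->bhij", q, rel_w[dx]))  # [2, 4, 16, 16]
print(logits.shape)
```

Two tables of size 2h - 1 and 2w - 1 replace a single table over all (2h - 1)(2w - 1) 2D offsets. Gathering a full [h·w, h·w, dim_head] tensor as done here becomes memory-heavy for a 32 by 32 map, which is one reason practical implementations use more memory-efficient indexing tricks.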

Copyright Analytics India Magazine Pvt Ltd
