
Compute Relevancy Of Transformer Networks Via Novel Interpretable Transformer


Self-attention models, specifically Transformers, have taken the computer vision field by storm, be it OpenAI’s DALL-E or Google’s ViT models. This creates a need for tools that can interpret and visualize the decision process behind Transformer models. Such visualizations can be used to debug models and verify that they are fair and unbiased. A new approach for computing token relevance in Transformer models was proposed in the paper “Transformer Interpretability Beyond Attention Visualization” by Hila Chefer, Shir Gur, and Lior Wolf. The method assigns local relevance based on the Deep Taylor Decomposition and then propagates these relevancy scores through the layers. This propagation traverses attention layers and skip connections; both mix two activation maps, which has posed unique challenges to existing approaches.

Approach & Algorithm

The main building block of Transformer models is the self-attention layer, which assigns a pairwise attention value between every two tokens. When visualizing Transformer models, it is common to treat these attentions as relevancy scores, usually for a single attention layer. The rollout method instead recursively computes the token attentions in each layer of a model, assuming that attentions are combined linearly into subsequent contexts. Given a Transformer model with L layers, rollout computes the attention from all positions in layer li to all positions in layer lj, where i < j. It produces better results than methods that rely on a single attention layer, but it makes simplistic assumptions and often highlights irrelevant tokens. Methods that use the pairwise attentions themselves to visualize the decision process are fundamentally limited: they end up ignoring most of the attention scores, and components of the network other than the attention maps are not considered at all.
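To make the rollout idea concrete, here is a minimal sketch (not the authors' implementation) that multiplies head-averaged attention matrices across layers, adding an identity term to account for skip connections; attn_per_layer is a hypothetical list of attention tensors captured from a forward pass.

 import torch

 def attention_rollout(attn_per_layer):
     # attn_per_layer: list of per-layer attention tensors, each of shape
     # (num_heads, seq_len, seq_len), e.g. collected with forward hooks
     rollout = None
     for attn in attn_per_layer:
         a = attn.mean(dim=0)                              # average over heads -> (s, s)
         a = a + torch.eye(a.size(-1), device=a.device)    # identity term for the skip connection
         a = a / a.sum(dim=-1, keepdim=True)               # re-normalize rows
         rollout = a if rollout is None else torch.matmul(a, rollout)
     return rollout                                        # (s, s) token-to-token attention flow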

Another approach to interpreting Transformers is attribution propagation, based on the Deep Taylor Decomposition (DTD) framework. These methods recursively decompose the network’s decision into the contributions of the previous layers, all the way back to the elements of the network’s input. Layer-wise Relevance Propagation (LRP) is one such method: it propagates relevance from the predicted class backwards to the input image. However, Transformers apply non-linearities other than ReLU, producing both positive and negative features; LRP assumes ReLU activations and therefore fails. Furthermore, most existing Transformer interpretability methods are not class-specific in practice, i.e., they return the same visualization regardless of the class one tries to visualize.
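For intuition, below is a minimal sketch of the generic LRP epsilon-rule for a single linear layer (not the paper’s modified formulation): relevance assigned to a layer’s outputs is redistributed to its inputs in proportion to each input’s contribution to the pre-activation.

 import torch

 def lrp_linear(a, weight, relevance_out, eps=1e-6):
     # a: (in_features,) input activations of the layer
     # weight: (out_features, in_features) layer weights
     # relevance_out: (out_features,) relevance of the layer's outputs
     z = weight @ a                                   # pre-activations
     stab = eps * torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
     s = relevance_out / (z + stab)                   # stabilised relevance ratios
     c = weight.t() @ s                               # redistribute to the inputs
     return a * c                                     # (in_features,) input relevance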

Interpretable Transformer

The new method proposed in the paper has three phases:

  1. Calculate relevance for each attention matrix using a modified formulation of LRP that only considers the elements that have a positive weighted relevance.
  2. Back-propagate the gradients of the visualized class with respect to each attention matrix. These gradients are used alongside the relevancy scores and are integrated throughout the attention graph to iteratively remove the negative contributions.
  3. Aggregate the layers using the rollout method.

Two operators in Transformer models involve mixing two feature-map tensors: skip connections and attention modules. Given two input tensors u and v, both operators require relevance to be propagated through the two tensors (the paper derives the exact propagation rules for each). And while LRP ordinarily yields only positive relevance values, these operations produce both positive and negative values, which is why the gradient-based integration of the second phase is needed to discard the negative contributions; a rough sketch of how the phases combine is given below.
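The sketch below illustrates, following the paper’s description of these phases (not the authors’ exact implementation), how the class-specific gradients and relevance scores of each attention map might be combined and then aggregated rollout-style; attn_grads and attn_relevances are hypothetical per-layer tensors obtained from the backward pass and the LRP phase.

 import torch

 def aggregate_relevance(attn_grads, attn_relevances):
     # attn_grads / attn_relevances: lists of per-layer tensors of shape
     # (num_heads, seq_len, seq_len): gradients of the target class w.r.t. each
     # attention map, and the LRP relevance of that map
     result = None
     for grad, rel in zip(attn_grads, attn_relevances):
         weighted = (grad * rel).clamp(min=0)         # keep only positive contributions
         weighted = weighted.mean(dim=0)              # average over attention heads
         a_bar = torch.eye(weighted.size(-1), device=weighted.device) + weighted
         result = a_bar if result is None else torch.matmul(a_bar, result)
     return result                                    # C: (s, s) relevance matrix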

The explanation produced by this method is a matrix C of size s x s, where s denotes the length of the token sequence fed to the Transformer model. Each row of C is a relevance map for one token with respect to the other tokens. The row C[CLS] ∈ Rs yields the relevance map for the class being visualized: it contains a score assessing the impact of each token on the final classification token. From this row, only the tokens that correspond to the actual input patches are kept, and they are reshaped to the patch-grid size to obtain the final relevance map. For DeiT with 224 x 224 inputs and 16 x 16 patches, this grid is 14 x 14 (196 patch tokens plus the [CLS] token).

Using bilinear interpolation, this map is then upsampled to the original image’s dimensions.

Interpreting the DeiT Transformer model 

The following code has been taken from one of the official example notebooks available here.

  1. Clone the Transformer-Explainability GitHub repository, navigate into the newly created Transformer-Explainability directory and install the requirements.
 !git clone https://github.com/hila-chefer/Transformer-Explainability.git
 import os
 os.chdir('./Transformer-Explainability')
 !pip install -r requirements.txt 

You’ll have to restart the runtime after this. Make sure to navigate into the newly created directory again after restarting the runtime; otherwise, you’ll encounter ImportErrors.

  2. Import necessary libraries and classes.
 from PIL import Image
 import torchvision.transforms as transforms
 import matplotlib.pyplot as plt
 import torch
 import numpy as np
 import cv2
 from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_LRP
 from baselines.ViT.ViT_explanation_generator import LRP 
  3. Download the ImageNet class labels and create an index-to-class labels dictionary.
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt 
with open("imagenet_classes.txt", "r") as f:
     index_to_class = {i: s.strip() for i, s in enumerate(f.readlines())} 
  4. Load a pre-trained DeiT model and define the image preprocessing transforms.
 # initialize ViT pretrained with DeiT
 model = vit_LRP(pretrained=True).cuda()
 model.eval()
 attribution_generator = LRP(model)
 normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
     normalize,
 ]) 
  5. Create two helper functions: one for visualizing the relevance mask over images, and a second one for applying softmax to the model’s output logits and printing the top-5 predicted classes.
 def show_cam_on_image(img, mask):
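     # overlay a JET-colormapped relevance mask on the (0-1 ranged) input image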
     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
     heatmap = np.float32(heatmap) / 255
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
     return cam

 def print_top_classes(predictions, **kwargs):    
     # Print Top-5 predictions
     prob = torch.softmax(predictions, dim=1)
     class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
     max_str_len = 0
     class_names = []
     for cls_idx in class_indices:
         class_names.append(index_to_class[cls_idx])
         if len(index_to_class[cls_idx]) > max_str_len:
             max_str_len = len(index_to_class[cls_idx])
     print('Top 5 classes:')
     for cls_idx in class_indices:
         output_string = '\t{} : {}'.format(cls_idx, index_to_class[cls_idx])
         output_string += ' ' * (max_str_len - len(index_to_class[cls_idx])) + '\t\t'
         output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
         print(output_string) 
  6. Create the function for interpreting the predictions of the DeiT model. generate_LRP is the only novel, paper-specific method used in this function; you can find its implementation here.
 def generate_visualization(original_image, class_index=None):
     transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
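     # reshape the 196 patch relevance scores into the 14x14 patch grid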
     transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
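     # upsample 14x14 -> 224x224 with bilinear interpolation (scale factor 16)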
     transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
     transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()
     transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
     image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
     image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
     vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
     vis =  np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
     return vis 
  7. Visualize the relevance of image patches for particular predictions in a class-specific manner.

Perform inference on the image to get the class index of the objects in the image.

 image = Image.open('samples/catdog.png')
 dog_cat_image = transform(image)
 output = model(dog_cat_image.unsqueeze(0).cuda())
 print_top_classes(output) 

Visualize the relevance of image patches for different classes.

 fig, axs = plt.subplots(1, 3, figsize=(21,7))
 axs[0].imshow(image);
 axs[0].axis('off');

 # dog - generate visualization for class 243: 'bull mastiff' - the predicted class
 #by default the predicted  class is visualized
 dog = generate_visualization(dog_cat_image)
 axs[1].imshow(dog);
 axs[1].axis('off');

 # cat - generate visualization for class 282 : 'tiger cat'
 cat = generate_visualization(dog_cat_image, class_index=282)
 axs[2].imshow(cat);
 axs[2].axis('off'); 

Let’s try another example:

 image = Image.open('samples/dogcat2.png')
 dog_cat_image = transform(image)
 output = model(dog_cat_image.unsqueeze(0).cuda())
 print_top_classes(output)
 fig, axs = plt.subplots(1, 3,figsize=(21,7))
 axs[0].imshow(image);
 axs[0].axis('off');

 # golden retriever - the predicted class
 dog = generate_visualization(dog_cat_image)
 axs[1].imshow(dog);
 axs[1].axis('off');

 # generate visualization for class 285: 'Egyptian cat'
 cat = generate_visualization(dog_cat_image, class_index=285)
 axs[2].imshow(cat);
 axs[2].axis('off'); 

Here is the Colab Notebook containing the above code.

Last Epoch

Comparison of different methods for interpreting a transformer model's decision process.

This article discussed a new method for interpreting Transformer models. Multiple factors have prevented techniques developed for interpreting other kinds of neural networks from being applied to Transformers: non-positive activation functions, the frequent use of skip connections, and the challenge of modelling the matrix multiplication used in self-attention. The new method provides specific solutions to each of these challenges and obtains state-of-the-art results compared to other Transformer interpretability methods such as LRP and Grad-CAM.

Reference

For a more in-depth understanding of the new method for interpreting Transformers, please refer to the paper “Transformer Interpretability Beyond Attention Visualization” and the official Transformer-Explainability GitHub repository.
