Compute Relevancy Of Transformer Networks Via Novel Interpretable Transformer

Self-attention models, specifically Transformers, have taken the computer vision field by storm, be it OpenAI’s DALL-E or Google’s ViT models. This creates a need for tools that can interpret and visualize the decision process behind Transformer models. Such visualizations can be used to debug models and to verify that they are fair and unbiased. A new approach for computing token relevance for Transformer models was proposed in the paper “Transformer Interpretability Beyond Attention Visualization” by Hila Chefer, Shir Gur, and Lior Wolf. The method assigns local relevance based on the Deep Taylor Decomposition and then propagates these relevancy scores through the layers. This propagation passes through attention layers and skip connections; both mix two activation maps and have posed unique challenges to existing approaches.

Approach & Algorithm

The main building block of Transformer models is the self-attention layer, which assigns a pairwise attention value between every two tokens. It is common to treat these attention values as relevancy scores when trying to visualize Transformer models, usually for a single attention layer. The rollout method instead recursively combines the token attentions of every layer of the model, assuming that attentions are combined linearly into subsequent contexts. Given a Transformer model with L layers, the rollout method computes the attention from all positions in layer li to all positions in layer lj, where i < j. It produces better results than methods that use a single attention layer, but it relies on simplistic assumptions and often highlights irrelevant tokens. Methods that use the pairwise attentions themselves to visualize the decision process are fundamentally flawed: they end up ignoring most of the attention scores, and the other layers of the network are not considered at all.
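
For intuition, here is a minimal sketch of the rollout idea, not the paper’s code: it assumes each layer’s head-averaged attention map is already available as an s × s NumPy array, and the identity term models the skip connection around each attention block.

 import numpy as np

 def attention_rollout(attentions):
     # attentions: list of s x s head-averaged attention maps,
     # ordered from the first layer to the last.
     s = attentions[0].shape[0]
     rollout = np.eye(s)
     for attn in attentions:
         # Add the identity to account for the skip connection,
         # then re-normalize the rows so they remain distributions.
         attn_aug = 0.5 * attn + 0.5 * np.eye(s)
         attn_aug = attn_aug / attn_aug.sum(axis=-1, keepdims=True)
         # Linearly mix this layer's attention with the accumulated one.
         rollout = attn_aug @ rollout
     return rollout  # rollout[i, j]: contribution of token j to token i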

Another approach for interpreting Transformers is attribution propagation, based on the Deep Taylor Decomposition (DTD) framework. These methods recursively decompose the network’s decision into the contributions of the previous layers, all the way back to the elements of the network’s input. Layerwise Relevance Propagation (LRP) is one such method; it propagates relevance from the predicted class backwards to the input image. However, Transformers apply non-linearities other than ReLU, resulting in both positive and negative features, while LRP assumes ReLU activations and therefore fails. Furthermore, most existing Transformer interpretability methods are not class-specific in practice, i.e., they return the same visualization regardless of the class one tries to visualize.
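
For intuition, the sketch below shows a generic LRP-style rule for a single linear layer y = xW: each output’s relevance is redistributed to the inputs in proportion to their contributions. This is an illustration of the general idea, not the paper’s exact rule, and eps is only a numerical stabilizer. When activations can be negative, as in Transformers, these contribution shares can be negative too, which is exactly where the plain rule breaks down.

 import torch

 def lrp_linear(x, W, R_out, eps=1e-9):
     # x: (d_in,) layer input, W: (d_in, d_out) weights,
     # R_out: (d_out,) relevance assigned to the layer's outputs.
     z = x @ W                      # pre-activations of the layer
     contrib = x[:, None] * W       # contribution of input j to output i
     # Each output's relevance is split according to these contributions.
     R_in = (contrib / (z + eps)) @ R_out
     return R_in                    # (d_in,) relevance of the inputs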

Interpretable Transformer

The new method proposed in the paper has three phases:

  1. Calculating relevance for each attention matrix using a modified formulation of LRP that only considers the elements that have a positive weighted relevance.

There are two operators in Transformer models that mix two feature-map tensors: skip connections and attention modules. Both require relevance to be propagated through both of their input tensors. Given two tensors u and v, the paper derives propagation rules that split the incoming relevance between u and v according to their contributions to the output.

While LRP normally yields only positive relevance values, these binary operations produce both positive and negative values, which the subsequent phases must account for.

  2. The gradients of the visualized class’s score with respect to each attention matrix are back-propagated. These gradients are combined with the relevancy scores and integrated throughout the attention graph to iteratively remove negative contributions.
  3. The layers are then aggregated using the rollout method (a condensed sketch of phases 1–3 follows this list).
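
The sketch below shows how the three phases fit together. It assumes that, for every attention block, the relevance of the attention map (phase 1) and the gradient of the target class score with respect to that map (phase 2) have already been computed, each of shape (heads, s, s); it is meant to mirror the gradient-weighted, positive-only rollout described above, not to replace the repository’s implementation.

 import torch

 def aggregate_relevance(relevances, gradients):
     # relevances, gradients: lists of (heads, s, s) tensors,
     # one pair per attention block, ordered first to last.
     s = relevances[0].shape[-1]
     C = torch.eye(s)
     for R, grad in zip(relevances, gradients):
         # Weight the relevance by the class-specific gradient and keep
         # only positive contributions, then average over the heads.
         A_bar = (grad * R).clamp(min=0).mean(dim=0)
         # The identity models the skip connection around the block;
         # matrix multiplication rolls the layers out.
         C = (torch.eye(s) + A_bar) @ C
     return C  # (s, s); the [CLS] row is the class-specific relevance map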

The explanation produced by this method is a matrix C of size s × s, where s denotes the length of the sequence (image patches or string tokens) fed to the Transformer. Each row of C is a relevance map for one token given the other tokens. For classification, the row C[CLS] ∈ R^s yields the relevance map for a particular class: it scores the impact of each token on the [CLS] classification token. Only the entries that correspond to actual input tokens are kept, and this sequence of scores is reshaped to the patch-grid size to obtain the final relevance map.

Using bilinear interpolation, this map is then upsampled to the spatial dimensions of the original image.
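
As a small preview of the generate_visualization function in step 6 below, the [CLS] relevance row (one score per patch, i.e. 196 values for a 224 × 224 DeiT input with 16 × 16 patches) is reshaped to the 14 × 14 patch grid and upsampled; the values here are placeholders.

 import torch
 import torch.nn.functional as F

 cls_relevance = torch.rand(196)                   # placeholder patch scores
 patch_grid = cls_relevance.reshape(1, 1, 14, 14)  # 224 / 16 = 14 patches per side
 relevance_map = F.interpolate(patch_grid, scale_factor=16, mode='bilinear')
 print(relevance_map.shape)                        # torch.Size([1, 1, 224, 224])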

Interpreting the DeiT Transformer model 

The following code is adapted from one of the official example notebooks in the Transformer-Explainability repository.

  1. Clone the Transformer-Explainability GitHub repository, navigate into the newly created Transformer-Explainability directory and install the requirements.
 !git clone https://github.com/hila-chefer/Transformer-Explainability.git
 import os
 os.chdir('./Transformer-Explainability')
 !pip install -r requirements.txt 

You’ll have to restart the runtime after this. Make sure to navigate into the newly created directory again after restarting the runtime, otherwise you’ll encounter ImportErrors.
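
For example, after the restart you can re-enter the directory with something like this (assuming the repository was cloned into the current working directory):

 import os
 os.chdir('./Transformer-Explainability')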

  2. Import necessary libraries and classes.
 from PIL import Image
 import torchvision.transforms as transforms
 import matplotlib.pyplot as plt
 import torch
 import numpy as np
 import cv2
 from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_LRP
 from baselines.ViT.ViT_explanation_generator import LRP 
  3. Download the ImageNet class labels and create an index-to-class labels dictionary.
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt 
with open("imagenet_classes.txt", "r") as f:
     index_to_class = {i: s.strip() for i, s in enumerate(f.readlines())} 
  4. Load a pre-trained DeiT model and define the image preprocessing transforms.
 # initialize ViT pretrained with DeiT
 model = vit_LRP(pretrained=True).cuda()
 model.eval()
 attribution_generator = LRP(model)
 normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
     normalize,
 ]) 
  5. Create two helper functions: one for overlaying the relevance mask on the image, and one that applies softmax to the model’s output and prints the top-5 predicted classes.
 def show_cam_on_image(img, mask):
     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
     heatmap = np.float32(heatmap) / 255
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
     return cam

 def print_top_classes(predictions, **kwargs):    
     # Print Top-5 predictions
     prob = torch.softmax(predictions, dim=1)
     class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
     max_str_len = 0
     class_names = []
     for cls_idx in class_indices:
         class_names.append(index_to_class[cls_idx])
         if len(index_to_class[cls_idx]) > max_str_len:
             max_str_len = len(index_to_class[cls_idx])
     print('Top 5 classes:')
     for cls_idx in class_indices:
         output_string = '\t{} : {}'.format(cls_idx, index_to_class[cls_idx])
         output_string += ' ' * (max_str_len - len(index_to_class[cls_idx])) + '\t\t'
         output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])
         print(output_string) 
  6. Create the function for interpreting the DeiT model’s prediction process. generate_LRP is the only paper-specific method used in this function; you can find its implementation in the Transformer-Explainability repository.
 def generate_visualization(original_image, class_index=None):
     transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
     transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
     transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
     transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()
     transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
     image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
     image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
     vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
     vis =  np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
     return vis 
  7. Visualize the relevance of image patches for particular predictions in a class-specific manner.

Perform inference on the image to get the class index of the objects in the image.

 image = Image.open('samples/catdog.png')
 dog_cat_image = transform(image)
 output = model(dog_cat_image.unsqueeze(0).cuda())
 print_top_classes(output) 

Visualize the relevance of image patches for different classes.

 fig, axs = plt.subplots(1, 3, figsize=(21,7))
 axs[0].imshow(image);
 axs[0].axis('off');

 # dog - generate visualization for class 243: 'bull mastiff' - the predicted class
 #by default the predicted  class is visualized
 dog = generate_visualization(dog_cat_image)
 axs[1].imshow(dog);
 axs[1].axis('off');

 # cat - generate visualization for class 282 : 'tiger cat'
 cat = generate_visualization(dog_cat_image, class_index=282)
 axs[2].imshow(cat);
 axs[2].axis('off'); 

Let’s try another example:

 image = Image.open('samples/dogcat2.png')
 dog_cat_image = transform(image)
 output = model(dog_cat_image.unsqueeze(0).cuda())
 print_top_classes(output)
 fig, axs = plt.subplots(1, 3,figsize=(21,7))
 axs[0].imshow(image);
 axs[0].axis('off');

 # golden retriever - the predicted class
 dog = generate_visualization(dog_cat_image)
 axs[1].imshow(dog);
 axs[1].axis('off');

 # generate visualization for class 285: 'Egyptian cat'
 cat = generate_visualization(dog_cat_image, class_index=285)
 axs[2].imshow(cat);
 axs[2].axis('off'); 

Here is the Colab Notebook containing the above code.

Last Epoch

Comparison of different methods for interpreting a transformer model's decision process.

This article discussed a new method for interpreting Transformer models. Multiple factors have prevented methods developed for interpreting other kinds of neural networks from being applied to Transformers: non-positive activation functions, the frequent use of skip connections, and the challenge of modelling the matrix multiplication used in self-attention. The new method provides specific solutions to each of these challenges and obtains state-of-the-art results compared with other Transformer interpretability methods such as LRP and Grad-CAM.

References

For a more in-depth understanding of the new method for interpreting Transformers, please refer to the following resources:

  1. “Transformer Interpretability Beyond Attention Visualization” by Hila Chefer, Shir Gur, and Lior Wolf
  2. Official code and example notebooks: https://github.com/hila-chefer/Transformer-Explainability

