Convolutional neural networks (CNNs) have proven to be very effective at tasks such as image classification, face recognition and document analysis. But as these models become more accurate and more complex, they also become harder to interpret. A face recognition model can involve many layers and millions of trainable parameters, which makes it difficult to read, debug and trust. CNNs can appear to be black boxes that map inputs to outputs with great accuracy while offering little intuition about how they work.
As a deep learning engineer, it is your responsibility to make sure the model is working correctly. Suppose you are given the task of classifying different birds. The dataset contains images of birds with plants and trees in the background. If the network looks at the plants and trees instead of the bird, there is a good chance it will misclassify the image and miss the features that identify the bird. How do we know our model is looking at the right thing? In this article, we discuss how to reduce the risk of treating a CNN as a black box and how to check whether it relies on the features that actually matter for classification or recognition.
What will we discuss in this article?
- What is Grad-CAM?
- How to use Grad-CAM?
- How does Grad-CAM visualize the region of interest of a CNN model?
What is Grad-CAM?
One way to check what a CNN is actually looking at is to visualize it with Grad-CAM. Gradient-weighted Class Activation Mapping (Grad-CAM) produces a heat map that highlights the regions of an image that matter most for a target class (for example, bird or elephant), using the gradients of that class's score with respect to the feature maps of the final convolutional layer.
We take the feature maps of the final convolutional layer and weight every channel by the gradient of the class score with respect to that channel, averaged over all spatial positions. The weighted map shows which regions of the input image activate the channels that matter most for the class. Grad-CAM requires no retraining and no change to the existing architecture.
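The weighting described above fits in a few lines of NumPy. The sketch below is illustrative and assumes you already have the feature maps A of shape (H, W, K) and the gradients of the class score with respect to them as plain arrays; `grad_cam_map` is a hypothetical name, not a library function. Following the Grad-CAM formulation, it averages the gradients over the spatial positions to get one weight per channel, takes the weighted sum of the channels, applies a ReLU and scales the result to [0, 1].

```python
import numpy as np


def grad_cam_map(feature_maps: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Hypothetical helper: combine feature maps A and gradients dy_c/dA.

    Both arrays have shape (H, W, K); the result is an (H, W) map in [0, 1].
    """
    # alpha_k: average the gradients over the spatial axes, one weight per channel
    alpha = grads.mean(axis=(0, 1))          # shape (K,)
    # Weighted sum over channels: sum_k alpha_k * A[:, :, k]
    cam = feature_maps @ alpha               # shape (H, W)
    # ReLU keeps only the regions that increase the class score
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1]; the epsilon avoids dividing by zero for an all-zero map
    return cam / (cam.max() + 1e-8)


# Random arrays stand in for real block5_conv3 activations and gradients
rng = np.random.default_rng(0)
A = rng.random((14, 14, 512))
dA = rng.standard_normal((14, 14, 512))
cam = grad_cam_map(A, dA)
print(cam.shape)  # (14, 14)
```

The implementation later in this article averages the weighted channels instead of summing them; the two differ only by a constant factor of 1/K, which disappears once the map is normalized.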
Implementation
We begin with a pre-trained model, VGG16, whose weights were trained on ImageNet. ImageNet is a very large collection of annotated photographs; the classification subset used to train VGG16 has 1000 classes. In this example, we highlight the class 'sunglasses, dark glasses, shades' and apply Grad-CAM to it.
Loading the Pre-Trained CNN Model
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image
import numpy as np
import cv2
import keras.backend as K
from skimage import io
model = VGG16(weights="imagenet")
Loading and Preparing the Image
Now choose any image from the internet that contains sunglasses. I have chosen an image of a Tony Stark action figure here. Since VGG16 expects 224×224 inputs, we resize the image to 224×224, convert it to an array, add a batch dimension and apply VGG16's preprocessing.
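# Read the image as an RGB uint8 array (scikit-image returns RGB, not OpenCV's BGR)
sunglasses = io.imread("https://ae01.alicdn.com/kf/HTB1wnD8bcrrK1RjSspaq6AREXXaR/MK100-Tony-Stark-Doll-Head-Carved-Glasses-Seamless-Flexible-Body-1-6-Action-Figure-Scale-Model.jpg_q50.jpg")
# VGG16 expects 224x224 inputs
sunglasses = cv2.resize(sunglasses, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
x = image.img_to_array(sunglasses)  # float32 array of shape (224, 224, 3)
x = np.expand_dims(x, axis=0)       # add a batch axis: (1, 224, 224, 3)
x = preprocess_input(x)             # VGG16 preprocessing (RGB -> BGR, subtract ImageNet means)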
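`preprocess_input` is easy to treat as a black box. The Keras documentation for VGG16 states that it converts images from RGB to BGR and zero-centers each color channel with respect to the ImageNet dataset, without scaling. The NumPy sketch below mirrors that description to make the step concrete; `vgg16_style_preprocess` is a hypothetical name, the mean values are assumed to be the ImageNet channel means Keras uses for this mode, and in real code you should keep calling `preprocess_input`.

```python
import numpy as np

# Per-channel ImageNet means in B, G, R order (assumed to match the constants
# used by Keras' "caffe"-style preprocessing for VGG16)
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def vgg16_style_preprocess(batch_rgb: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for VGG16 preprocessing of an (N, H, W, 3) RGB batch."""
    x = batch_rgb.astype(np.float32)
    x = x[..., ::-1]                 # RGB -> BGR by reversing the channel axis
    return x - IMAGENET_BGR_MEAN     # zero-center each channel; no scaling to [0, 1]


pixel = np.array([[[[10, 20, 30]]]], dtype=np.uint8)  # one RGB pixel, shape (1, 1, 1, 3)
print(vgg16_style_preprocess(pixel))
```

Because the channels are reversed before the means are subtracted, the blue mean (103.939) is applied to the value that was originally the red channel's position at the end of the array.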
Making Predictions
We run the model on the image to get its predictions and take the output of the final convolutional layer, block5_conv3. ImageNet has 1000 classes, and 'sunglasses, dark glasses, shades' is class index 837.
model.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
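# Prediction
preds = model.predict(x)
print(decode_predictions(preds, top=3)[0])  # top-3 (wnid, label, probability) tuples
# Symbolic tensor for the score of class 837 ("sunglasses, dark glasses, shades")
class_output = model.output[:, 837]
# Last convolutional layer of VGG16, whose 14x14x512 feature maps Grad-CAM uses
last_conv_layer = model.get_layer("block5_conv3")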
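The code above hard-codes class 837. Before explaining a class, it is worth checking what the model actually predicts; `decode_predictions` does this with human-readable labels. The NumPy sketch below shows the underlying index logic on a fake prediction vector; `top_k_classes` is a hypothetical helper, and with the real model you would pass it the `preds` array returned by `model.predict`.

```python
import numpy as np


def top_k_classes(preds: np.ndarray, k: int = 5) -> list:
    """Hypothetical helper: (class_index, probability) pairs for the k best classes.

    `preds` is a single-image prediction of shape (1, num_classes).
    """
    scores = preds[0]
    # argsort sorts ascending, so reverse it and keep the first k indices
    top = np.argsort(scores)[::-1][:k]
    return [(int(i), float(scores[i])) for i in top]


# Fake 1000-class prediction in which class 837 clearly wins
fake_preds = np.full((1, 1000), 0.0001)
fake_preds[0, 837] = 0.9
fake_preds[0, 836] = 0.05
print(top_k_classes(fake_preds, k=2))  # [(837, 0.9), (836, 0.05)]
```

Using the top-ranked index, for example `int(np.argmax(preds[0]))`, instead of a fixed 837 makes Grad-CAM explain whichever class the model actually chose.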
Visualizing the Region of Interest of the CNN Model
We now have everything we need for the visualization. First, compute the gradients of the class score with respect to the feature maps of the last convolutional layer. Then average the gradients over the batch and spatial axes to get one weight per channel, and multiply each channel of the feature map by its weight.
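# Note: K.gradients and K.function need graph mode (TF 1.x, or TF 2.x with
# tf.compat.v1.disable_eager_execution() called before the model is built)
grads = K.gradients(class_output, last_conv_layer.output)[0]
print(grads.shape)
# One weight per channel: average the gradients over the batch, height and width axes
pooled_grads = K.mean(grads, axis=(0, 1, 2))
print(pooled_grads.shape)
# Evaluate the pooled gradients and the feature maps for our image
iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])
pooled_grads_value, conv_layer_output_value = iterate([x])
# Weight each of the 512 channels by its pooled gradient
for i in range(pooled_grads_value.shape[0]):
    conv_layer_output_value[:, :, i] *= pooled_grads_value[i]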
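The Python loop over the 512 channels can also be written with NumPy broadcasting, because a vector of shape (512,) lines up with the last axis of a (14, 14, 512) array. The sketch below uses random stand-ins for `conv_layer_output_value` and `pooled_grads_value` (the names `conv_out` and `weights` are illustrative) and checks that both versions give the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins for the (14, 14, 512) feature map and the 512 pooled gradients
conv_out = rng.random((14, 14, 512)).astype(np.float32)
weights = rng.standard_normal(512).astype(np.float32)

# Loop version: scale one channel at a time
looped = conv_out.copy()
for i in range(weights.shape[0]):
    looped[:, :, i] *= weights[i]

# Broadcast version: the (512,) vector multiplies along the last axis
broadcast = conv_out * weights

print(np.allclose(looped, broadcast))  # True
```

The broadcast form avoids a Python-level loop and reads closer to the formula, which matters little at this size but helps when batching many images.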
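We then average the weighted feature maps along the channel dimension, which gives a 14×14 heat map. Negative values are clipped to zero (a ReLU), so only regions that increase the class score remain, and the map is normalized to lie between 0 and 1. Finally, the heat map is resized to the image size, colored with a colormap and overlaid on the original image.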
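# Average the weighted channels to get a 14x14 heat map
heatmap = np.mean(conv_layer_output_value, axis=-1)
print(conv_layer_output_value.shape)  # (14, 14, 512)
print(heatmap.shape)                  # (14, 14)
# ReLU: keep only regions with a positive influence on the class
heatmap = np.maximum(heatmap, 0)
# Normalize to [0, 1]; the small epsilon avoids division by zero for an all-zero map
heatmap /= np.max(heatmap) + 1e-8
# Upsample to the image size and convert to an 8-bit color map (OpenCV returns BGR)
heatmap = cv2.resize(heatmap, (sunglasses.shape[1], sunglasses.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
# The image was read as RGB, so convert it to BGR before blending with the BGR heat map
superimposed_img = cv2.addWeighted(cv2.cvtColor(sunglasses, cv2.COLOR_RGB2BGR), 0.5, heatmap, 0.5, 0)
# cv2_imshow only works in Google Colab; elsewhere, save the result with cv2.imwrite
from google.colab.patches import cv2_imshow
cv2_imshow(superimposed_img)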
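The last step maps a 14×14 heat map onto a 224×224 image, so each heat-map cell covers a 16×16 patch of pixels (224 / 14 = 16). The code above uses `cv2.resize`, whose default interpolation is bilinear, and `cv2.addWeighted`, which computes the per-pixel weighted sum src1 * alpha + src2 * beta + gamma. The NumPy sketch below makes both ideas explicit under simplifying assumptions: it uses nearest-neighbour repetition instead of bilinear interpolation, so the 16×16 blocks stay visible, and the helper names `upsample_nearest` and `blend` are hypothetical.

```python
import numpy as np


def upsample_nearest(cam: np.ndarray, factor: int) -> np.ndarray:
    """Hypothetical helper: nearest-neighbour upsampling by repeating each cell."""
    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)


def blend(img: np.ndarray, overlay: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Hypothetical helper: alpha * img + (1 - alpha) * overlay, as uint8.

    The same kind of weighted sum cv2.addWeighted(img, alpha, overlay, 1 - alpha, 0)
    computes, rounded and clipped to the 0-255 range.
    """
    mixed = alpha * img.astype(np.float32) + (1.0 - alpha) * overlay.astype(np.float32)
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)


# A 14x14 map upsampled by 16 matches VGG16's 224x224 input
cam = np.random.default_rng(0).random((14, 14))
big = upsample_nearest(cam, 16)
print(big.shape)  # (224, 224)

gray = np.full((224, 224, 3), 200, dtype=np.uint8)
red = np.zeros((224, 224, 3), dtype=np.uint8)
red[..., 2] = 255  # pure red in OpenCV's BGR channel order
print(blend(gray, red)[0, 0])  # [100 100 228]
```

Bilinear resizing, as used in the walkthrough, gives a smoother heat map than this block-wise version, but both assign every pixel the importance of the heat-map cell it falls in or near.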
Conclusion
In the overlay above, the heat map is concentrated on the sunglasses, which suggests the network bases its prediction on the right region of the image rather than on the background. Grad-CAM is not only useful for visualization; it can also be an effective tool for debugging a model and for guiding fine-tuning toward better results.