
Grad-CAM from Scratch with PyTorch Hooks


A car stops suddenly. Worryingly, there is no stop sign in sight. The engineers can only guess why the car’s neural network became confused. It could be a tumbleweed rolling across the street, a car coming down the other lane or the red billboard in the background. To find the real reason, they turn to Grad-CAM [1].

Grad-CAM is an explainable AI (XAI) technique that helps reveal why a convolutional neural network (CNN) made a particular decision. The method produces a heatmap that highlights the regions in an image that are most important for a prediction. For our self-driving car example, this could show whether the pixels from the tumbleweed, car or billboard caused the car to stop.

Now, Grad-CAM is one of many XAI methods for Computer Vision. Due to its speed, flexibility and reliability, it has quickly become one of the most popular. It has also inspired many related methods. So, if you are interested in XAI, it is worth understanding exactly how this method works. To do that, we will be implementing Grad-CAM from scratch using Python.

Specifically, we will be relying on PyTorch hooks. As you will see, these allow us to dynamically extract gradients and activations from a network during forward and backwards passes. These practical skills will allow you to implement not only Grad-CAM but also other gradient-based XAI methods. See the full project on GitHub.

The theory behind Grad-CAM

Before we get to the code, it is worth touching on the theory behind Grad-CAM. If you want a deep dive, then check out the video below. If you want to learn about other methods, then see this free XAI for Computer Vision course.

To summarise, when creating Grad-CAM heatmaps, we start with a trained CNN. We then do a forward pass through this network with a single sample image. This produces activations in every convolutional layer of the network. We call these activations feature maps ($A^k$). Each layer’s feature maps are a collection of 2D matrices that capture different features detected in the sample image.

With Grad-CAM, we are typically interested in the maps from the last convolutional layer of the network. When we apply the method to VGG16, you will see that this layer has 512 feature maps. We use them because they contain the features with the richest semantic information while still retaining spatial information. In other words, they tell us what was used for a prediction and where in the image it came from.

The problem is that these maps also contain features that are important for other classes. To mitigate this, we follow the process shown in Figure 1. Once we have the feature maps ($A^k$), we weight each one by how important it is to the class of interest, $c$. We do this using $a_k^c$, the average gradient of the score for class $c$ ($y^c$) w.r.t. the elements of feature map $k$. We then sum the weighted maps element-wise. For VGG16, you will see we go from 512 maps of 14×14 pixels to a single 14×14 map.

Figure 1: element-wise summation of the weighted feature maps from the last convolutional layer in a CNN (source: author)

The gradient for an individual element ($\frac{\partial y^c}{\partial A_{ij}^k}$) tells us how much the score will change with a small change in that element. So, a large average gradient indicates that the feature map as a whole is important for the class and should contribute more to the final heatmap. As a result, when we weight and sum the maps, the ones that contain features for other classes will likely contribute less.
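
Written out, the weight for feature map $k$ is the spatial average of these element gradients. Following the Grad-CAM paper [1] (which writes the weight as $\alpha_k^c$; we keep the $a_k^c$ used here), with $Z$ the number of elements in a feature map:

$$ a_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k} $$

For the last convolutional layer of VGG16, each map is 14×14, so $Z = 196$.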

The final steps are to apply the ReLU activation function, which sets all negative elements to zero, and then to upsample the result with interpolation so the heatmap has the same dimensions as the sample image. The map before upsampling is summarised by the formula below. You might recognise it from the Grad-CAM paper [1].

$$ L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\left( \sum_{k} a_k^c A^k \right) $$
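
To see the formula in action, here is a minimal NumPy sketch on made-up numbers: two 2×2 feature maps and their gradients. The arrays are arbitrary toy values, not taken from VGG16. The sketch computes $a_k^c$ as the mean gradient of each map, weights and sums the maps, then applies ReLU.

```python
import numpy as np

# Toy feature maps A^k: K=2 maps, each 2x2 (values chosen arbitrarily)
A = np.array([
    [[1.0, 0.0],
     [2.0, 1.0]],
    [[0.0, 3.0],
     [1.0, 0.0]],
])

# Toy gradients dy^c/dA^k_ij, same shape as A
dA = np.array([
    [[0.2, 0.2],
     [0.2, 0.2]],
    [[-0.1, -0.3],
     [-0.1, -0.1]],
])

# a_k^c: average gradient over the spatial dimensions of each map
a = dA.mean(axis=(1, 2))  # shape (K,)

# Weighted sum over k, then ReLU to keep only positive evidence
cam = np.maximum(np.sum(a[:, None, None] * A, axis=0), 0)

print(a)
print(cam)
```

The first map gets a weight of 0.2 and the second a weight of -0.15. At the top-right pixel, the second map dominates, so the weighted sum is negative (-0.45) and ReLU sets it to zero.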

Grad-CAM from Scratch

Don’t worry if the theory is not completely clear. We will walk through it step by step as we apply the method from scratch. You can find the full project on GitHub. To start, we have our imports below. These are all common imports for computer vision problems.

import matplotlib.pyplot as plt
import numpy as np

import cv2
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision import models, transforms

import urllib.request

Load pretrained model from PyTorch

We’ll be applying Grad-CAM to VGG16 pretrained on ImageNet. To help, we have the two functions below. The first prepares an image for input into the model. The normalisation values are the per-channel mean and standard deviation of the ImageNet training images, and the 224×224 size is also standard for ImageNet models.

def preprocess_image(img_path):

    """Load and preprocess images for PyTorch models."""

    img = Image.open(img_path).convert("RGB")

    # Transforms used by ImageNet models
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return transform(img).unsqueeze(0)
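
As a quick sanity check on the normalisation, the NumPy sketch below applies the same per-channel (x - mean) / std arithmetic that transforms.Normalize performs, using a pure white and a pure black pixel. It assumes pixel values have already been scaled to [0, 1] by transforms.ToTensor.

```python
import numpy as np

# ImageNet per-channel statistics, as used in preprocess_image
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# After ToTensor, pixel values lie in [0, 1]
white = np.array([1.0, 1.0, 1.0])
black = np.zeros(3)

# Normalize applies (x - mean) / std independently to each channel
white_norm = (white - mean) / std
black_norm = (black - mean) / std

print(white_norm)
print(black_norm)
```

After normalisation, the values are roughly centred on zero, ranging from about -2.1 to 2.6.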

The model predicts scores for all 1,000 ImageNet classes. The second function formats the output of the model so we display only the classes with the highest predicted probabilities.

def display_output(output,n=5):

    """Display the top n categories predicted by the model."""
    
    # Download the categories
    url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    urllib.request.urlretrieve(url, "imagenet_classes.txt")

    with open("imagenet_classes.txt", "r") as f:
        categories = [s.strip() for s in f.readlines()]

    # Show top categories per image
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.topk(probabilities, n)

    for i in range(top_prob.size(0)):
        print(categories[top_catid[i]], top_prob[i].item())

    return top_catid[0]
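
The ranking logic in display_output can be checked on a toy tensor. The sketch below uses four made-up logits in place of the 1,000 outputs of VGG16. Because softmax is monotonic, the top-k classes are the same whether we rank probabilities or raw scores; softmax only changes the values we print.

```python
import torch

# Hypothetical logits for a batch of one image over four classes
output = torch.tensor([[2.0, 1.0, 0.1, 3.0]])

# Same steps as display_output: softmax over the class dimension, then top-k
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top_prob, top_catid = torch.topk(probabilities, 2)

print(top_catid.tolist())           # indices of the two most likely classes
print(top_prob.tolist())            # their probabilities, in descending order
print(probabilities.sum().item())   # probabilities sum to 1
```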

We now load the pretrained VGG16 model (line 2), move it to a GPU if one is available (lines 5-8) and set it to evaluation mode (line 11). You can see a snippet of the printed model architecture in Figure 2. VGG16 is made of 16 weighted layers. Here, you can see the last 2 of its 13 convolutional layers and the 3 fully connected layers.

# Load the pre-trained model (e.g., VGG16)
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)  # replaces the deprecated pretrained=True

# Move the model to a GPU (Apple MPS or CUDA) if one is available, else use the CPU
device = torch.device('mps' if torch.backends.mps.is_available()
                      else 'cuda' if torch.cuda.is_available() 
                      else 'cpu')
model.to(device)

# Set the model to evaluation mode
model.eval()

The names you see in Figure 2 are important. Later, we will use them to reference a specific layer in the network to access its activations and gradients. Specifically, we will use model.features[28], the final convolutional layer in the network. As you can see in the snapshot, this layer outputs 512 feature maps.

Figure 2: snapshot of final layers of the VGG16 network (source: author)

Forward pass with sample image

We will be explaining a prediction from this model. To do this, we need a sample image to feed into the model. We download one from Wikimedia Commons (lines 2-3). We then load it (lines 5-6), crop it to have equal height and width (line 7) and display it (lines 9-10). In Figure 3, you can see we are using an image of a whale shark in an aquarium.

# Load a sample image from the web
img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a1/Male_whale_shark_at_Georgia_Aquarium.jpg/960px-Male_whale_shark_at_Georgia_Aquarium.jpg"
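urllib.request.urlretrieve(img_url, "sample_image.jpg")

img_path = "sample_image.jpg"
img = Image.open(img_path).convert("RGB")
img = img.crop((320, 0, 960, 640))  # Crop to 640x640

plt.imshow(img)
plt.axis("off")

# Overwrite the file with the crop so preprocess_image(img_path) sees the same square image
img.save(img_path)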
Figure 3: male whale shark in aquarium (source: Wikimedia commons) (license: CC BY-SA 2.5)

ImageNet has no dedicated class for whale sharks, so it will be interesting to see what the model predicts. To do this, we start by processing our image (line 2) and moving it to the GPU (line 3). We then do a forward pass to get a prediction (line 6) and display the top 5 probabilities (line 7). You can see these in Figure 4.

# Preprocess the image
img_tensor = preprocess_image(img_path)
img_tensor = img_tensor.to(device)

# Forward pass
predictions = model(img_tensor)
display_output(predictions,n=5)

Given the available classes, these seem reasonable. They are all marine life and the top two are sharks. Now, let’s see how we can explain this prediction. We want to understand what regions of the image contribute the most to the highest predicted class — hammerhead.

Figure 4: top 5 predicted classes of the example image of the whale shark using VGG16 (source: author)

PyTorch hooks naming conventions

Grad-CAM heatmaps are created using both activations from a forward pass and gradients from a backwards pass. To access these, we will use PyTorch hooks. These are functions that PyTorch calls automatically when a module runs, giving you access to its inputs and outputs (or their gradients) so you can save them. We won’t do it here, but hooks even allow you to alter these values. For example, Guided Backpropagation can be applied by using a backwards hook to ensure that only positive gradients are propagated.

You can see some examples of these functions below. A forwards_hook will be called during a forward pass. It will be registered on a given module (i.e. layer). By default, the function receives three arguments — the module, its input and its output. Similarly, a backwards_hook is triggered during a backwards pass with the module and gradients of the input and output.

# Example of a forwards hook function
def forwards_hook(module, input, output):
    """Parameters:
            module (nn.Module): The module where the hook is applied.
            input (tuple of Tensors): Input to the module.
            output (Tensor): Output of the module."""
    ...

# Example of a backwards hook function 
def backwards_hook(module, grad_in, grad_out):
    """Parameters:
            module (nn.Module): The module where the hook is applied.
            grad_in (tuple of Tensors): Gradients w.r.t. the input of the module.
            grad_out (tuple of Tensors): Gradients w.r.t. the output of the module."""
    ...
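
To see what these arguments actually contain, here is a small sketch that registers both kinds of hook on a single nn.Linear layer (a stand-in chosen only because its shapes are easy to read) and records what each hook receives during one forward and one backwards pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # toy module; any nn.Module accepts hooks
calls = []

def forwards_hook(module, input, output):
    # `input` is a tuple of the positional arguments passed to forward()
    calls.append(("forward", type(input).__name__,
                  tuple(input[0].shape), tuple(output.shape)))

def backwards_hook(module, grad_in, grad_out):
    # Both are tuples of tensors; an entry can be None if no gradient is needed
    calls.append(("backward",
                  tuple(tuple(g.shape) for g in grad_in if g is not None),
                  tuple(grad_out[0].shape)))

h_fwd = layer.register_forward_hook(forwards_hook)
h_bwd = layer.register_full_backward_hook(backwards_hook)

x = torch.randn(2, 4, requires_grad=True)  # batch of 2 inputs
layer(x).sum().backward()                  # triggers both hooks

h_fwd.remove()
h_bwd.remove()

for call in calls:
    print(call)
```

For a batch of two 4-dimensional inputs, the forward hook sees an input tuple holding a (2, 4) tensor and a (2, 3) output. The backwards hook receives grad_in with a (2, 4) gradient w.r.t. the input and grad_out with a (2, 3) gradient w.r.t. the output. With register_full_backward_hook, grad_in only covers the module’s inputs, not its weights.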

To avoid confusion, let’s clarify the parameter names used by these functions. Take a look at the overview of the standard backpropagation procedure for a convolutional layer in Figure 5. This layer consists of a set of kernels, $K$, and biases, $b$. The other parts are the:

  • input – a set of feature maps or an image
  • output – a set of feature maps
  • grad_in – the gradient of the loss w.r.t. the layer’s input
  • grad_out – the gradient of the loss w.r.t. the layer’s output

We have labelled these with the same names as the arguments passed to the hook functions that we apply later.

Figure 5: Backpropagation for a convolutional layer in a deep learning model. The blue arrows show the forward pass and the red arrows show the backwards pass. (source: author)

Keep in mind, we won’t use the gradients in the same way as in backpropagation. Usually, we use the gradients of the loss for a batch of images to update $K$ and $b$. Here, we only need grad_out for a single sample image, and we back-propagate from the class score rather than from a loss. grad_out then gives us the gradients of the score w.r.t. the elements of the layer’s feature maps. In other words, these are the gradients we use to weight the feature maps.
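
We can check this on a toy network. In the sketch below (a hypothetical four-layer CNN, not VGG16), a forward hook keeps the convolution output and calls retain_grad() on it, while a backwards hook stores grad_out[0]. After back-propagating from the top class score, the two gradients should match: grad_out is the gradient of the score w.r.t. the layer’s feature maps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny CNN standing in for VGG16: conv -> ReLU -> pool -> classifier
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
model = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(4, 5))

saved = {}

def save_output(module, input, output):
    output.retain_grad()  # ask autograd to keep d(score)/d(output) on this tensor
    saved["A"] = output

def save_gradient(module, grad_in, grad_out):
    saved["grad_out"] = grad_out[0]

h_fwd = conv.register_forward_hook(save_output)
h_bwd = conv.register_full_backward_hook(save_gradient)

x = torch.randn(1, 3, 8, 8)
score = model(x)[0].max()  # top class score, as in the article
score.backward()

h_fwd.remove()
h_bwd.remove()

print(saved["grad_out"].shape)  # same shape as the feature maps
print(torch.allclose(saved["grad_out"], saved["A"].grad))
```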

Activations with PyTorch forward hook

Our VGG16 network has been created using ReLU activations with inplace=True. These overwrite their input tensor with the output to save memory, so the original values are lost. This causes problems when applying hooks. A saved activation that shares memory with the layer output can be silently overwritten by the next in-place ReLU, and PyTorch raises an error if a module’s inputs or outputs are modified in place while a full backward hook is registered on it. So we use the code below to replace all ReLU functions with inplace=False ones. This will not change the output of the model, but it will increase its memory usage.

# Replace all in-place ReLU activations with out-of-place ones
def replace_relu(model):

    for name, child in model.named_children():
        if isinstance(child, torch.nn.ReLU):
            setattr(model, name, torch.nn.ReLU(inplace=False))
            print(f"Replacing ReLU activation in layer: {name}")
        else:
            replace_relu(child)  # Recursively apply to submodules

# Apply the modification to the VGG16 model
replace_relu(model)
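
The sketch below illustrates the overwriting problem with a plain tensor. On the CPU, .numpy() shares memory with the tensor, much like the output.detach().cpu().numpy() used in the forward hook below, so an in-place ReLU changes both. The out-of-place version leaves its input untouched.

```python
import torch
import torch.nn as nn

# In-place: the input tensor (and any NumPy view of it) is overwritten
x = torch.tensor([-1.0, 0.5, -2.0, 3.0])
saved_view = x.numpy()  # shares memory with x on the CPU
nn.ReLU(inplace=True)(x)
print(x.tolist(), saved_view.tolist())

# Out-of-place: a new tensor is returned and the input keeps its values
x2 = torch.tensor([-1.0, 0.5, -2.0, 3.0])
y2 = nn.ReLU(inplace=False)(x2)
print(x2.tolist(), y2.tolist())
```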

Below we have our first hook function, save_activations. This appends the output of a module (line 6) to a list of activations (line 2). In our case, we will only register the hook on one module (i.e. the last convolutional layer), so this list will only contain one element. Notice how we format the output (line 6). We detach it from the computational graph, since a tensor that requires gradients cannot be converted to a NumPy array directly. We then move it to the CPU, convert it to a NumPy array and squeeze out the batch dimension.

# List to store activations
activations = []

# Function to save activations
def save_activations(module, input, output):
    activations.append(output.detach().cpu().numpy().squeeze())

To use the hook function, we register it on the last convolutional layer — model.features[28]. This is done using the register_forward_hook function.

# Register the hook to the last convolutional layer
hook = model.features[28].register_forward_hook(save_activations)

Now, when we do a forward pass (line 2), the save_activations hook function will be called for this layer. In other words, its output will be saved to the activations list.

# Forward pass through the model to get activations
prediction = model(img_tensor)

Finally, it is good practice to remove the hook function when it is no longer needed (line 2). This means the forward hook function will not be triggered if we do another forward pass.

# Remove the hook after use
hook.remove()  
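
If an error is raised between registering and removing a hook, the hook stays attached and keeps firing on later passes. A common pattern, sketched here on a toy nn.Linear layer, is to wrap the pass in try/finally so hook.remove() always runs.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)  # toy stand-in for model.features[28]
activations = []

def save_activations(module, input, output):
    activations.append(output.detach())

hook = layer.register_forward_hook(save_activations)
try:
    layer(torch.ones(1, 2))  # hook fires: one activation saved
finally:
    hook.remove()            # runs even if the forward pass raises

layer(torch.ones(1, 2))      # hook no longer registered, nothing appended
print(len(activations))
```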

The shape of these activations is (512, 14, 14). In other words, we have 512 feature maps and each map is 14×14 pixels. You can see some examples of these in Figure 6. Some of these maps may contain features important for other classes or those that decrease the probability of the predicted class. So let’s see how we can find gradients to help identify the most important maps.

act_shape = np.shape(activations[0])
print(f"Shape of activations: {act_shape}") # (512, 14, 14)
Figure 6: example of activated feature maps from the last convolutional layer of the network (source: author)

Gradients with PyTorch backwards hooks

To get gradients, we follow a similar process to before. The key difference is that we now use register_full_backward_hook to register the save_gradient function (line 7). This ensures it is called during a backwards pass. Importantly, we do the backwards pass (line 16) from the output for the class with the highest score (line 13). This is equivalent to back-propagating a gradient of 1 for this class and 0 for all other classes. In other words, we get the gradients of the hammerhead score w.r.t. the elements of the feature maps. Note that output contains the raw scores (logits) before the softmax, which matches how $y^c$ is defined in the Grad-CAM paper [1].

gradients = []

def save_gradient(module, grad_in, grad_out):
    gradients.append(grad_out[0].cpu().numpy().squeeze())

# Register the backward hook on a convolutional layer
hook = model.features[28].register_full_backward_hook(save_gradient)

# Forward pass
output = model(img_tensor)

# Pick the class with highest score
score = output[0].max()

# Backward pass from the score
score.backward()

# Remove the hook after use
hook.remove()
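
Back-propagating from a single class score is equivalent to back-propagating the whole output with a one-hot upstream gradient. The sketch below checks this on a hypothetical nn.Linear stand-in by comparing the input gradients from both approaches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(6, 4)  # hypothetical stand-in for a classifier
x = torch.randn(1, 6, requires_grad=True)

# Option 1: back-propagate from the top class score, as in the article
out = net(x)
out[0].max().backward()
grad_from_score = x.grad.clone()
c = int(out[0].argmax())

# Option 2: back-propagate the whole output with a one-hot upstream gradient
x.grad = None
out = net(x)
one_hot = torch.zeros_like(out)
one_hot[0, c] = 1.0
out.backward(gradient=one_hot)

print(torch.allclose(grad_from_score, x.grad))
```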

We will have a gradient for every element of the feature maps. So, again, the shape is (512, 14, 14). Figure 7 visualises some of these. You can see some tend to have higher values. However, we are not so concerned with the individual gradients. When we create a Grad-CAM heatmap, we will use the average gradient of each feature map.

grad_shape = np.shape(gradients[0])
print(f"Shape of gradients: {grad_shape}") # (512, 14, 14)
Figure 7: gradients of the score w.r.t. the elements of feature maps in the last convolutional layer (source: author)

Finally, before we move on, it is good practice to reset the model’s gradients (line 2). This is particularly important if you plan to run the code for multiple images, as gradients accumulate with each backwards pass.

# Reset gradients
model.zero_grad() 
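
To see what accumulates, here is a small sketch on a toy nn.Linear layer. Two backwards passes without a reset add up in the parameters’ .grad attribute, and zero_grad() clears them. The grad_out values captured by our hook are computed fresh on each backwards pass; it is the .grad attributes that accumulate.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)  # toy stand-in for the network's parameters
x = torch.randn(1, 3)

layer(x).sum().backward()
first = layer.weight.grad.clone()

layer(x).sum().backward()  # second pass without resetting
accumulated = layer.weight.grad.clone()
print(torch.allclose(accumulated, 2 * first))  # the two passes were summed

layer.zero_grad()          # reset, as model.zero_grad() does for every parameter
print(layer.weight.grad)   # None in PyTorch 2.x (zeros in older versions)
```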

Creating Grad-CAM heatmaps

First, we find the mean gradient of each feature map. These are the weights $a_k^c$, and for VGG16 there are 512 of them. Plotting a histogram of them (Figure 8), you can see most are close to 0. In other words, these maps don’t have much impact on the predicted score. A few have a clear negative or positive impact, and it is these feature maps that will receive the largest weights in magnitude.

# Step 1: aggregate the gradients
gradients_aggregated = np.mean(gradients[0], axis=(1, 2))
Figure 8: histogram of average gradients (source: author)

We combine all the activations by doing element-wise summation (lines 2-4). When we do this, we weight each feature map by its average gradient (line 3). In the end, we will have one 14×14 array.

# Step 2: weight the activations by the aggregated gradients and sum them up
weighted_activations = np.sum(activations[0] * 
                              gradients_aggregated[:, np.newaxis, np.newaxis], 
                              axis=0)
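
The broadcasting with np.newaxis performs a weighted sum over the first (feature map) axis. As a sketch on random stand-in arrays with the same shapes as ours, the same contraction can also be written with np.einsum or np.tensordot, which some readers may find easier to read.

```python
import numpy as np

rng = np.random.default_rng(0)
acts = rng.standard_normal((512, 14, 14))  # stand-in for activations[0]
weights = rng.standard_normal(512)         # stand-in for gradients_aggregated

# Broadcasting version used above
broadcast = np.sum(acts * weights[:, np.newaxis, np.newaxis], axis=0)

# The same contraction over k, written two other ways
via_einsum = np.einsum("kij,k->ij", acts, weights)
via_tensordot = np.tensordot(weights, acts, axes=1)

print(broadcast.shape)
print(np.allclose(broadcast, via_einsum), np.allclose(broadcast, via_tensordot))
```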

These weighted activations will contain both positive and negative pixels. We can consider the negative pixels to be suppressing the predicted score. In other words, an increase in the value of these regions tends to decrease the score. Since we are only interested in positive contributions (regions that support the class prediction), we apply a ReLU activation to the summed map (line 2). You can see the difference in the heatmaps in Figure 9.

# Step 3: ReLU summed activations
relu_weighted_activations = np.maximum(weighted_activations, 0)
Figure 9: relu of weighted activations (source: author)

You can see the heatmap in Figure 9 is quite coarse (14×14 pixels). It would be more useful if it had the same dimensions as the input image. This is why the last step in creating a Grad-CAM heatmap is to upsample it to the size of the input image (lines 2-4). In this case, that is 224×224. Note that cv2.resize takes the target size as (width, height), which is why we pass img_tensor.size(3) before img_tensor.size(2).

# Step 4: Upsample the heatmap to the original image size
upsampled_heatmap = cv2.resize(relu_weighted_activations, 
                               (img_tensor.size(3), img_tensor.size(2)), 
                               interpolation=cv2.INTER_LINEAR)

print(np.shape(upsampled_heatmap))  # Should be (224, 224)
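
As an alternative to OpenCV, the upsampling could also be done in PyTorch with F.interpolate (F is imported above but otherwise unused). This is a swapped-in technique, not the one used for the figures here, so values may differ slightly from cv2.resize. Note that size is given as (height, width), the reverse of the (width, height) order cv2.resize expects.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Stand-in for relu_weighted_activations: a coarse 14x14 heatmap
coarse = np.random.default_rng(0).random((14, 14)).astype(np.float32)

# F.interpolate expects (N, C, H, W), so add batch and channel dimensions
t = torch.from_numpy(coarse)[None, None]
up = F.interpolate(t, size=(224, 224), mode="bilinear", align_corners=False)

upsampled = up[0, 0].numpy()
print(upsampled.shape)
```

Bilinear interpolation only blends neighbouring values, so the upsampled map stays within the range of the coarse one.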

Figure 10 gives us our final visualisation. We display the sample image (lines 5-7) next to the heatmap (lines 10-15). For the latter, we create a clear visualisation with the help of Canny edge detection (line 10). This gives us an edge map (i.e. outline) of the sample image. We can then overlay the heatmap on top of this (line 14).

# Step 5: visualise the heatmap
fig, ax = plt.subplots(1, 2, figsize=(8, 8))

# Input image
resized_img = img.resize((224, 224))
ax[0].imshow(resized_img)
ax[0].axis("off")

# Edge map for the input image
edge_img = cv2.Canny(np.array(resized_img), 100, 200)
ax[1].imshow(255-edge_img, alpha=0.5, cmap='gray')

# Overlay the heatmap 
ax[1].imshow(upsampled_heatmap, alpha=0.5, cmap='coolwarm')
ax[1].axis("off")
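
Putting the steps together, the sketch below wraps them into a single reusable function. It is our own consolidation, not code from the project’s repository: grad_cam and its parameters are hypothetical names, it uses torch operations instead of NumPy and bilinear F.interpolate instead of cv2.resize, and it assumes a single image and no in-place ReLUs. The usage at the end runs it on a tiny random CNN so it works without downloads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model, target_layer, img_tensor, class_idx=None):
    """Sketch: Grad-CAM heatmap of shape (H, W) for a single image.

    Assumes img_tensor has shape (1, C, H, W), target_layer outputs
    (1, K, h, w) feature maps and the model has no in-place ReLUs.
    """
    store = {}

    def save_activations(module, input, output):
        store["A"] = output.detach()

    def save_gradient(module, grad_in, grad_out):
        store["dA"] = grad_out[0].detach()

    h_fwd = target_layer.register_forward_hook(save_activations)
    h_bwd = target_layer.register_full_backward_hook(save_gradient)
    try:
        model.zero_grad()
        output = model(img_tensor)
        if class_idx is None:
            class_idx = int(output[0].argmax())  # top predicted class
        output[0, class_idx].backward()          # gradient of the raw class score
    finally:
        h_fwd.remove()
        h_bwd.remove()

    A = store["A"][0]                                # (K, h, w) feature maps
    a = store["dA"][0].mean(dim=(1, 2))              # (K,) average gradients a_k^c
    cam = F.relu((a[:, None, None] * A).sum(dim=0))  # weighted sum + ReLU -> (h, w)

    # Upsample to the input's (height, width) with bilinear interpolation
    cam = F.interpolate(cam[None, None], size=tuple(img_tensor.shape[-2:]),
                        mode="bilinear", align_corners=False)[0, 0]
    return cam.cpu().numpy(), class_idx


# Usage on a tiny random CNN (hypothetical; no pretrained weights needed)
torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
tiny = nn.Sequential(conv, nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(), nn.Linear(8, 10)).eval()
x = torch.randn(1, 3, 32, 32)
heatmap, cls = grad_cam(tiny, conv, x)
print(heatmap.shape, cls)
```

With the objects from this article, the call would presumably be grad_cam(model, model.features[28], img_tensor), after running replace_relu(model).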

Looking at our Grad-CAM heatmap, there is some noise. However, it appears the model is relying on the tail fin and, to a lesser extent, the pectoral fin to make its predictions. It is starting to make sense why the model classified this shark as a hammerhead. Perhaps both animals share these characteristics.

Figure 10: input image (left) and Grad-CAM heatmap overlaid on an edge map (right) (source: author)

For further investigation, we apply the same process to an actual image of a hammerhead (Figure 11). In this case, the model appears to rely on the same features. This is a bit concerning. Would we not expect the model to use the shark’s most distinctive feature, its hammer-shaped head? Ultimately, this may lead VGG16 to confuse different types of sharks.

Figure 11: an additional example image (source: Wikimedia Commons) (license: CC BY 2.0)

With this example, we see how Grad-CAM can highlight potential flaws in our model. We can not only get a model’s predictions but also understand how it made them. We can also check whether the features it uses might lead to unexpected predictions down the line. This can potentially save us a lot of time, money and, in the case of more consequential applications, lives!

If you want to learn more about XAI for CV, check out one of these articles, or see this free XAI for CV course.


I hope you enjoyed this article! See the course page for more XAI courses. You can also find me on Bluesky | Threads | YouTube | Medium

References

[1] Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh and Dhruv Batra. Grad-CAM: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE International Conference on Computer Vision (ICCV), pages 618–626, 2017.
