feature map-CAM 和 利用pytorch-hook注册实现CAM可视化

发布时间 2023-03-25 14:07:52作者: 努力的孔子

什么是CAM

CAM的全称是Class Activation MappingClass Activation Map,即类激活映射类激活图

论文《Learning Deep Features for Discriminative Localization》发现了CNN分类模型的一个有趣的现象:

CNN的最后一层卷积输出的特征图,对其通道进行加权叠加后,其激活值(ReLU激活后的非零值)所在的区域,即为图像中的物体所在区域。

而将这一叠加后的单通道特征图覆盖到输入图像上,即可高亮图像中物体所在位置区域。

该文章作者将实现这一现象的方法命名为类激活映射,并将特征图叠加在原始输入图像上生成的新图片命名为类激活图

 

 

Hook注册实现CAM

使用pytorch的hook注册, 取出网络某中间层feature map

(为啥用hook? 因为pytorch是动态图结构, 计算后的节点会被释放. 想要取出某中间结构, 需手动注册获取),

结合weighted_softmax, 点乘得到CAM(Class Activation Mapping)和heatmap.

 

以 resnet18 为例

 

import numpy as np
import cv2
import torch
from PIL import Image
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F


def hook_feature(module, input, output):  # hook注册, 响应图提取
    print("hook input", input[0].shape)
    features_blobs.append(output.data.cpu().numpy())

def returnCAM(feature_conv, weight_softmax, class_idx, size_upsample):
    # 生成CAM图: 输入是feature_conv和weight_softmax
    bz, nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        # feature_conv和weight_softmax 点乘(.dot)得到cam
        print(weight_softmax[idx].shape, feature_conv.reshape((nc, h * w)).shape)   # (512,) (512, 49)
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam


if __name__ == '__main__':
    size_upsample = (224, 224)
    # 1. imput image process
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.Resize(size_upsample),
        transforms.ToTensor(),
        normalize])
    img_name = 'image1.jpg'
    img_pil = Image.open(img_name)
    img = cv2.imread(img_name)
    img_tensor = preprocess(img_pil)
    img_variable = Variable(img_tensor.unsqueeze(0))

    # 2. 导入res18 pretrain, 也可自行定义net结构然后导入.pth
    net = models.resnet18(pretrained=True)
    # net = models.resnet18(pretrained=False)
    # net.load_state_dict(torch.load('./resnet18-f37072fd.pth'), strict=True)
    net.eval()
    # print(net)

    # 3. 获取特定层的feature map
    # 3.1. hook the feature extractor
    features_blobs = []
    finalconv_name = 'layer4'   # 最后一个卷积模块
    # 对layer4层注册, 把layer4层的输出加入features
    net._modules.get(finalconv_name).register_forward_hook(hook_feature)
    print(net._modules)

    # 3.2. 得到weight_softmax
    params = list(net.parameters())  # 将参数变换为列表 按照weights bias 排列 池化无参数
    print(params)
    weight_softmax = np.squeeze(params[-2].data.numpy())  # 提取softmax 层的参数 (weights,-1是bias)
    print('weight_softmax.shape', weight_softmax.shape)     # (1000, 512)

    # 4. imput img inference
    logit = net(img_variable)
    h_x = F.softmax(logit, dim=1).data.squeeze()
    probs, idx = h_x.sort(0, True)
    probs = probs.numpy()
    idx = idx.numpy()
    print(idx.shape, idx[2])        # (1000,) 465

    # features_blobs[0], weight_softmax点乘得到CAM
    CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[2], idx[3]], size_upsample)

    # 将图片和CAM拼接在一起展示定位结果结果
    img = cv2.resize(img, size_upsample)
    height, width, _ = img.shape
    # 生成热度图
    heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)
    cv2.imwrite('./heatmap.jpg', heatmap)
    result = heatmap * 0.3 + img * 0.5
    cv2.imwrite('./CAM.jpg', result)

 

 

 

 

 

参考资料:

https://mp.weixin.qq.com/s/3mz7RyfBdOmY8WyZtr739w  pytorch-hook注册: 生成feature map可视化热力图

https://www.jianshu.com/p/fd2f09dc3cc9  CAM系列(一)之CAM(原理讲解和PyTorch代码实现)