CNN可视化技术 -- CAM & Grad-CAM详解及pytorch简洁实现

2023-11-18

前言

CNN中的特征可视化大体可分为两类:

  • 细节信息:ZFNet中使用的deconvolution,改进的guide backpropagation
  • 信息的重要性区分:类激活图(CAM),改进的Grad-CAM

第一类方法只显示了在深层特征中保留了哪些信息,而没有突出显示这些信息的相对重要性。第二类方法则具有一定的解释性,例如在分类任务中,通过CAM能够解释模型究竟是通过重点学习输入图像中的哪些信息来判断类别的。

1. CAM(Class Activation Map)

Network in Network中提出了用全局平均池化(GAP)替代全连接层以加强特征映射与类别之间的联系,更具可解释性。受该思想启发,CAM可视化技术应运而出。生成CAM的流程如下图所示(论文原图):
在这里插入图片描述

可以看出,生成CAM的步骤非常简单,但是对网络结构有要求(网络末端为GAP+FC这样的结构,并且FC只有一层,用于输出类别概率)。假设分类任务采用的是VGG网络,此时生成CAM的步骤为:

  1. 将VGG中的前两个FC替换为GAP,重新训练;
  2. 获取最后一个卷积层输出的特征图 [ f 1 , f 2 , . . . , f n ] [f_1, f_2, ..., f_n] [f1,f2,...,fn],以及全连接层的权重 [ w 1 , w 2 , . . . , w n ] [w_1, w_2, ..., w_n] [w1,w2,...,wn]
  3. 计算 C A M = ∑ i = 1 n w i f i CAM=\sum_{i=1}^{n}w_if_i CAM=i=1nwifi

不难发现,若网络结构不符合要求,按照上述方法计算CAM需要修改网络结构和重新训练。针对该问题,后续研究中提出了Gard-CAM。

2. Grad-CAM

由上述CAM的计算方法可知,生成CAM的关键是获取特征图的权重。基于对原始CAM的改进,Grad-CAM通过求网络输出的类别置信度对特征图的偏导来获取权重,适用于任意网络,并且能够可视化任意层的类激活图(通常选择最后一个卷积层,因为其包含了丰富的高级语义和空间信息)。
在这里插入图片描述

  • 生成Grad-CAM的步骤如下:
  1. 图片送入网络,前向传播,获取最后一个卷积层的特征图 A k A^k Ak(可选,任意层均可, k k k为通道index);
  2. 反向传播,获取网络输出的类别 c c c 的概率 y c y^c yc关于 A k A^k Ak的梯度 ∂ y c ∂ A k \frac{\partial y^c}{\partial A^k} Akyc
  3. 计算权重 α k c = 1 Z ∑ i ∑ j ∂ y c ∂ A i , j k \alpha^{c}_{k}=\frac{1}{Z}\sum\limits_{i}\sum\limits_{j}\frac{\partial y^c}{\partial A^k_{i,j}} αkc=Z1ijAi,jkyc
  4. 计算Grad-CAM: L G r a d − C A M c = R e L U ( ∑ k α k c A k ) L_{Grad-CAM}^{c}=ReLU(\sum\limits_{k}\alpha^{c}_{k}A^k) LGradCAMc=ReLU(kαkcAk)
  • 求偏导的意义:参考知乎中的文章,偏导表示输出关于输入的变化率,也就是特征图上变化一个单位,得到的输出变化多少单位。可以反映出输出 y c y^c yc关于 A i , j k A^k_{i,j} Ai,jk的敏感程度,如果梯度大,则非常敏感,表示该位置更有可能属于类别 c c c

3. PyTorch中的hook机制

  • PyTorch中设计hook的目的:在不改变网络代码、不在forward中返回某一层的输出的情况下,获取网络中某一层在前向传播或反向传播过程的输入和输出,并对其进行相关操作(例如:特征图可视化,梯度裁剪)。

4. Grad-CAM的PyTorch简洁实现

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import torchvision.models as models
from torchvision.transforms import Compose, Normalize, ToTensor

class GradCAM():
    '''
    Grad-cam: Visual explanations from deep networks via gradient-based localization
    Selvaraju R R, Cogswell M, Das A, et al. 
    https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html
    '''
    def __init__(self, model, target_layers, use_cuda=True):
        super(GradCAM).__init__()
        self.use_cuda = use_cuda
        self.model = model
        self.target_layers = target_layers
        
        self.target_layers.register_forward_hook(self.forward_hook)
        self.target_layers.register_full_backward_hook(self.backward_hook)
        
        self.activations = []
        self.grads = []
        
    def forward_hook(self, module, input, output):
        self.activations.append(output[0])
        
    def backward_hook(self, module, grad_input, grad_output):
        self.grads.append(grad_output[0].detach())
        
    def calculate_cam(self, model_input):
        if self.use_cuda:
            device = torch.device('cuda')
            self.model.to(device)                 # Module.to() is in-place method 
            model_input = model_input.to(device)  # Tensor.to() is not a in-place method
        self.model.eval()
        
        # forward
        y_hat = self.model(model_input)
        max_class = np.argmax(y_hat.cpu().data.numpy(), axis=1)
        
        # backward
        model.zero_grad()
        y_c = y_hat[0, max_class]
        y_c.backward()
        
        # get activations and gradients
        activations = self.activations[0].cpu().data.numpy().squeeze()
        grads = self.grads[0].cpu().data.numpy().squeeze()
        
        # calculate weights
        weights = np.mean(grads.reshape(grads.shape[0], -1), axis=1)
        weights = weights.reshape(-1, 1, 1)
        cam = (weights * activations).sum(axis=0)
        cam = np.maximum(cam, 0) # ReLU
        cam = cam / cam.max()
        return cam
    
    @staticmethod
    def show_cam_on_image(image, cam):
        # image: [H,W,C]
        h, w = image.shape[:2]
        
        cam = cv2.resize(cam, (h,w))
        cam = cam / cam.max()
        heatmap = cv2.applyColorMap((255*cam).astype(np.uint8), cv2.COLORMAP_JET) # [H,W,C]
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        
        image = image / image.max()
        heatmap = heatmap / heatmap.max()
        
        result = 0.4*heatmap + 0.6*image
        result = result / result.max()
        
        plt.figure()
        plt.imshow((result*255).astype(np.uint8))
        plt.colorbar(shrink=0.8)
        plt.tight_layout()
        plt.show()
        
    @staticmethod
    def preprocess_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
        preprocessing = Compose([
        	ToTensor(),
        	Normalize(mean=mean, std=std)
            ])
        return preprocessing(img.copy()).unsqueeze(0) 


if __name__ == '__main__':
    image = cv2.imread('both.png') # (224,224,3)
    input_tensor = GradCAM.preprocess_image(image)
    model = models.resnet18(pretrained=True)
    grad_cam = GradCAM(model, model.layer4[-1], 224)
    cam = grad_cam.calculate_cam(input_tensor)
    GradCAM.show_cam_on_image(image, cam)
  • 测试结果
    在这里插入图片描述
    (https://github.com/jacobgil/pytorch-grad-cam/blob/master/examples/both.png)
    在这里插入图片描述

参考资料

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CNN可视化技术 -- CAM & Grad-CAM详解及pytorch简洁实现 的相关文章

随机推荐

  • 解决win10升级到win11,打不开安全中心的问题(亲测有效,已修复)

    相信很多人也碰上过这种问题 升级到了win 11 但是安全中心打不开了 报错 需要使用新应用以打开此windowsdefender链接 但是微软的应用商店并没有这个软件 然后我实验了一种方法 1 去微软的应用商店 Microsoft Sto
  • mysql中如何操作varchar类型的日期进行比较、排序等操作

    在mysql使用过程中 日期一般都是以datetime timestamp等格式进行存储的 但有时会因为特殊的需求或历史原因 日期的存储格式是varchar 那么我们该如何处理这个varchar格式的日期数据呢 使用函数 STR TO DA
  • SSM框架基于JSP犬舍寄养系统

    项目介绍 SSM框架基于JSP犬舍寄养系统的设计与实现 高清视频演示 SSM框架基于JSP犬舍寄养系统的设计与实现 安装视频演示 SSM框架基于JSP犬舍寄养系统的设计与实现 系统说明 1 前台功能模块 首先注册会员 登录进平台 然后选择自
  • 【PyTorch学习】(三)自定义Datasets

    torchvision datasets源码地址 https github com pytorch vision blob master torchvision datasets 前两篇从搭建经典的ResNet DenseNet入手简单的了
  • 【LVGL 学习】样式(style)风格学习

    概述 在 LVGL 中 样式都是以对象的方式存在 一个对象可以描述一种样式 每个控件都可以独立添加样式 创建的样式之间互不影响 可以使用 lv style t 类型创建一个样式并初始化 static lv style t style lv
  • 数据结构算法:写一个递归算法来实现字符串的逆序存放

    题目要求 写一个递归算法来实现字符串的逆序存储 要求不另设存储空间 首先我们定一个一个存放字符串的结构体 typedef struct String ElemType data int length String PString 创建字符串
  • 解决IE浏览器报错,对象不支持“assign”属性或方法

    报错页面 报错代码 解决后 解决代码 function doTest if typeof Object assign function Object assign function target use strict if target n
  • 线性回归分析

    文章目录 一 高尔顿数据集进行线性回归分析 1 1 父母平均身高和儿子身高线性回归分析 1 2 父亲身高和儿子身高线性回归分析 1 3 母亲身高和儿子身高线性回归分析 二 Anscombe四重奏数据集进行线性回归分析 一 高尔顿数据集进行线
  • MFC读取Excel(一)

    软件 vs2013 程序功能 MFC读取Excel里的第一个单元格的值 步骤 第一步 创建基于对话框的MFC工程 第二步 添加库 添加Excel类库 在工程名上右键 选择 添加 类 或者点击菜单栏的 项目 gt 添加类 选择 TypeLib
  • Windows下常用的快捷方式

    罗列出windows下我最常用的快捷键 逐步补充 打开我的电脑 win e
  • Kubesphere部署三高商城组织架构说明

    KubeSphere部署三高商城组织架构说明 一 创建企业空间 1 使用ws manager用户登录KubeSphere web控制器 创建企业空间 2 登出控制台 然后以 ws admin 身份重新登录 在企业空间设置中 选择企业空间成员
  • uniapp开发支付宝小程序之上传小程序

    市面上很多关于微信小程序通过uinapp开发的文档 支付宝的文档较少 这里做一下补充 为后浪提供参考 一给窝里giaogiao 通过hbuilder编码小程序后不能直接在支付宝开发者工具中上传 应该先通过HBuilder编译一下在操作 步骤
  • ERP中HR模块的操作与设计--开源软件诞生26

    赤龙ERP的EHR功能讲解 第26篇 用日志记录 开源软件 的诞生 进入地址 点亮星星 祈盼着一个鼓励 博主开源地址 码云 https gitee com redragon redragon erp GitHub https github
  • 51单片机——ADC模数转换、DAC数模转换PWM C语言入门编程

    目录 ADC XPT2046 1 ADC模数转换 数码管上显示AD模块采集电位器的阻值 热敏的温度值 光敏的光值 DAC PWM 1 DAC数模转换 DAC PWM 模块上的指示灯DA1呈呼吸灯效果 由暗变亮再由亮变暗 ADC ADC an
  • 聊聊技术专家谈阿里云史诗级故障

    序言 什么是技术专家 其实也是很懂 是做的时间足够长呢 还是说经历的厂比较多 还是说纸上谈兵比较牛逼 专家嘛 大家都懂的 只会弹别人 喔 是谈别人 原来不是弹 有本事技术专家谈谈自己呗 风言风语 阿里云出现史诗级故障 处理的时间足够长 然后
  • 过去式加ed的发音_「初中英语语法大全」不规则动词过去式和过去分词巧记方法...

    动词的过去式和过去分词是初中英语教学中的重点 而有些动词的不规则变化是这些重点中的难点 但这些不规则变化也不是毫无规律可循的 现将初中英语中一些常用的不规则动词变化介绍如下 一 原形 过去式和过去分词的词形和读音都相同的单词 结尾字母一般是
  • 计算机视觉项目实战(一)、图像滤波和图像混合 Image Filtering and Hybrid Images

    图像滤波和图像混合 Image Filtering and Hybrid Images 项目要求 项目原理 主要函数 my imfilter 函数解释 输入参数 输出参数 主要实现步骤 gen hybrid image 函数解释 输入参数
  • java中JSONArray 遍历方式

    第一种 java8 遍历JSONArray 拼接字符串 public static void main String args JSONArray jSONArray new JSONArray JSONObject jb new JSON
  • Linux中退出编辑模式的命令

    vi 文件 回车后就进入进入编辑模式 按 o 进行编辑 编辑结束 shift 退出编辑模式 然后输入退出命令 1 保存不退出 w 保存文件但不退出vi 编辑 w 强制保存 不退出vi 编辑 w file 将修改另存到file中 不退出vi
  • CNN可视化技术 -- CAM & Grad-CAM详解及pytorch简洁实现

    文章目录 前言 1 CAM Class Activation Map 2 Grad CAM 3 PyTorch中的hook机制 4 Grad CAM的PyTorch简洁实现 参考资料 前言 CNN中的特征可视化大体可分为两类 细节信息 ZF