Pytorch CAM特征可视化

2023-11-06

背景

       类别激活映射(Class Activation Mapping, CAM)用于对深度学习特征可视化,通过特征响应定位图像的关键部位,为深度学习可解释性提供了一种方法,ACM以热力图的方式展示了图像局部响应的强弱信息,对应于更强的位置具有更好的特征识别能力。

论文链接:Learning Deep Features for Discriminative Localization

CAM基本原理:

    定义类别分数 S_c = \sum_kw_k^c \sum_{x,y}f_k(x,y) = \sum_{x,y}\sum_kw_k^cf_k(x,y),其中f_k(x,y)表示最后一个卷积层第k通道的输出,w_k^c为第k个通道对应的类别c的权重,定义CAM对第C类的映射M_c,则有M_c(x,y) = \sum_kw_k^cf_k(x,y)

CAM相关方法:Grad-CAM: https://arxiv.org/pdf/1610.02391.pdf、Grad-CAM++: https://arxiv.org/pdf/1610.02391.pdf

基于Resnet50的特征可视化代码:

import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

os.environ["KMP_DUPLICATE_LIB_OK"]="True"

def draw_cam(model, img_path, save_path, transform=None, visheadmap=False):
    img = Image.open(img_path).convert('RGB')
    if transform is not None:
        img = transform(img)
    img = img.unsqueeze(0)
    model.eval()
    x = model.conv1(img)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    features = x                #1x2048x7x7
    print(features.shape)
    output = model.avgpool(x)   #1x2048x1x1
    print(output.shape)
    output = output.view(output.size(0), -1)
    print(output.shape)         #1x2048
    output = model.fc(output)   #1x1000
    print(output.shape)
    def extract(g):
        global feature_grad
        feature_grad = g
    pred = torch.argmax(output).item()
    pred_class = output[:, pred]
    features.register_hook(extract)
    pred_class.backward()
    greds = feature_grad
    pooled_grads = torch.nn.functional.adaptive_avg_pool2d(greds, (1, 1))
    pooled_grads = pooled_grads[0]
    features = features[0]
    for i in range(2048):
        features[i, ...] *= pooled_grads[i, ...]
    headmap = features.detach().numpy()
    headmap = np.mean(headmap, axis=0)
    headmap /= np.max(headmap)

    if visheadmap:
        plt.matshow(headmap)
        # plt.savefig(headmap, './headmap.png')
        plt.show()

    img = cv2.imread(img_path)
    headmap = cv2.resize(headmap, (img.shape[1], img.shape[0]))
    headmap = np.uint8(255*headmap)
    headmap = cv2.applyColorMap(headmap, cv2.COLORMAP_JET)
    superimposed_img = headmap*0.4 + img
    cv2.imwrite(save_path, superimposed_img)

if __name__ == '__main__':
     model = models.resnet50(pretrained=True)
     transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
     draw_cam(model, './1.jpg', './cam_1.png', transform=transform, visheadmap=True)

效果展示:

 

项目地址:sourceCode

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch CAM特征可视化 的相关文章

  • PyTorch 中的截断反向传播(代码检查)

    我正在尝试在 PyTorch 中实现随时间截断的反向传播 对于以下简单情况K1 K2 我下面有一个实现可以产生合理的输出 但我只是想确保它是正确的 当我在网上查找 TBTT 的 PyTorch 示例时 它们在分离隐藏状态 将梯度归零以及这些
  • 如何检查 PyTorch 是否正在使用 GPU?

    如何检查 PyTorch 是否正在使用 GPU 这nvidia smi命令可以检测 GPU 活动 但我想直接从 Python 脚本内部检查它 这些功能应该有助于 gt gt gt import torch gt gt gt torch cu
  • Pytorch 分析器显示两个不同网络的卷积平均执行时间不同

    我有两个网络 我正在对它们进行分析以查看哪些操作占用了大部分时间 我注意到CUDA time avg为了aten conv2d不同网络的操作有所不同 这也增加了一个数量级 在我的第一个网络中 它是22us 而对于第二个网络则是3ms 我的第
  • Cuda和pytorch内存使用情况

    我在用Cuda and Pytorch 1 4 0 当我尝试增加batch size 我遇到以下错误 CUDA out of memory Tried to allocate 20 00 MiB GPU 0 4 00 GiB total c
  • 二维数组的按行 numpy.isin [重复]

    这个问题在这里已经有答案了 我有两个数组 A np array 3 1 4 1 1 4 B np array 0 1 5 2 4 5 2 3 5 是否可以使用numpy isin二维数组按行排列 我想检查一下是否A i j is in B
  • 为什么测试时一定要用DataParallel?

    在GPU上训练 num gpus设置为1 device ids list range num gpus model NestedUNet opt num channel 2 to device model nn DataParallel m
  • torch-1.1.0-cp37-cp37m-win_amd64.whl 在此平台上不受支持的滚轮

    我在开发 RNN 时需要使用 pyTorch 每当我尝试安装它时 我都会收到一条错误消息 指出 torch 1 1 0 cp37 cp37m win amd32 whl 在此平台上不受支持 pip3安装https download pyto
  • torchvision.transforms.Normalize 是如何操作的?

    我不明白如何标准化Pytorch works 我想将平均值设置为0和标准差1跨越张量中的所有列x形状的 2 2 3 一个简单的例子 gt gt gt x torch tensor 1 2 3 4 5 6 7 8 9 10 11 12 gt
  • 为什么 RNN 需要两个偏置向量?

    In Pytorch RNN 实现 http pytorch org docs master nn html highlight rnn torch nn RNN 有两个偏差 b ih and b hh 为什么是这样 它与使用一种偏差有什么
  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • LSTM 错误:AttributeError:“tuple”对象没有属性“dim”

    我有以下代码 import torch import torch nn as nn model nn Sequential nn LSTM 300 300 nn Linear 300 100 nn ReLU nn Linear 300 7
  • BatchNorm 动量约定 PyTorch

    Is the 批归一化动量约定 http pytorch org docs master modules torch nn modules batchnorm html 默认 0 1 与其他库一样正确 例如Tensorflow默认情况下似乎
  • Blenderbot 微调

    我一直在尝试微调 HuggingFace 的对话模型 Blendebot 我已经尝试过官方拥抱脸网站上给出的传统方法 该方法要求我们使用 trainer train 方法来完成此操作 我使用 compile 方法尝试了它 我尝试过使用 Py
  • Pytorch Tensor 如何获取元素索引? [复制]

    这个问题在这里已经有答案了 我有 2 个名为x and list它们的定义如下 x torch tensor 3 list torch tensor 1 2 3 4 5 现在我想获取元素的索引x from list 预期输出是一个整数 2
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 如何计算cifar10数据的平均值和标准差

    Pytorch 使用以下值作为 cifar10 数据的平均值和标准差 变换 Normalize 0 5 0 5 0 5 0 5 0 5 0 5 我需要理解计算背后的概念 因为这些数据是 3 通道图像 我不明白什么是相加的 什么是除什么的等等
  • ValueError:使用火炬张量时需要解压的值太多

    对于神经网络项目 我使用 Pytorch 并使用 EMNIST 数据集 已经给出的代码加载到数据集中 train dataset dsets MNIST root data train True transform transforms T

随机推荐

  • 计算机网络笔记(一)

    什么是计算机网络 什么是计算机网络 计算机网络就是互连 互联互通 的 自治 无主从关系 的计算机集合 那么 距离远 数据大如何保证互连 通过交换网络互连主机 什么 是 Internet 组成 计算机设备 通信链路 分组交换 数据包转发分组
  • linux unix域socket_python3从零学习-5.8.1、socket—底层网络接口

    源代码 Lib socket py 这个模块提供了访问BSD 套接字 的接口 在所有现代Unix系统 Windows macOS和其他一些平台上可用 这个Python接口是用Python的面向对象风格对Unix系统调用和套接字库接口的直译
  • Kaldi-MFCC模块源码主流程分析

    那么趁着这个机会 研究一下kaldi源码中MFCC部分的内容 不说废话 我们从 compute mfcc feats cc开始讲解 这里是个main函数 需要携带参数 具体使用样例如下 1 compute mfcc feats 其实看到这里
  • JVM完整笔记

    这是我在看课程 黑马程序员JVM完整教程 过程中记的笔记 我觉得该课程总时不长 并且理论 实战是一个入门JVM的好课程 若你看完该课程可以看下面几个参看书进一步深入了解JVM 深入理解Java虚拟机 第二版 实战Java虚拟机 深入JAVA
  • Java基础3--Java流程控制语句

    Java基础3 Java流程控制语句 文章目录 Java基础3 Java流程控制语句 Java循环语句 while循环 do while循环 for循环 增强for循环 Java条件语句 if语句 if else语句 if多分支语句 Jav
  • 【FreeRTOS开发问题】FreeRTOS内存溢出

    FreeRTOS内存溢出 如下图所示 FreeRTOS编译完成后可以看到 系统提示无法分配内存到堆 Objects Template axf Error L6406E No space in execution regions with A
  • Error Microsoft Visual C++ 14.0 is required 最佳解决方法,亲测有效

    这种pip安装不上的包 1 找whl包下载安装 去Python安装包大全中 https www lfd uci edu gohlke pythonlibs 去下载 对应后缀为 whl 的安装包进行安装 后缀为 whl 的安装包进行安装的方法
  • java永久区_Java方法区和永久代

    目前有三大Java虚拟机 HotSpot oracle JRockit IBM J9 JRockit是oracle发明的 用于其WebLogic服务器 IBM JVM是IBM发明的用于其Websphere服务器 因此在某行开发的时候 他们用
  • linux日志打到垃圾箱,shell输出的那个垃圾桶——/dev/null

    昨晚花费一整晚在知乎回答了一个关于shell里面的重定向输出到 dev null的问题 果断今晚也同步发在这里 反正也没人看 以下来自一个重度linux使用患者不请自来的回答 先用简单的语言回答题主的问题 shell程序中 2 gt dev
  • USB学习之一:USB协议基础

    USB开发者论坛http www usb org USB专区 http group ednchina com 93 1 1USB的特点 在USB1 0和USB1 1版本中 只支持1 5Mb s的低速 low speeed 模式 和12Mb
  • 探索健康养老的“最后一公里” 附下载地址

    目前中国机构养老市场参与者主要包括 房地产开发商 保险公司以及一些专业的养老服务企业 其中房地产开发商和保险公司凭借丰富的开发经验和充足的 资金流在市场上处于领先地位 目前市 场主流的机构养老项目的营利模式主要 分三类 即 非销售类 销售类
  • 《Stable Diffusion WebUI折腾实录》在Windows完成安装, 从社区下载热门模型,批量生成小姐姐图片

    环境 操作系统 Windows11 显卡 RTX2060 6GB 显存 安装Python 下载 Python3 10 6 https www python org ftp python 3 10 6 python 3 10 6 amd64
  • 马来西亚旅游不可不去的世外桃源

    全马最漂亮的8大冷门 世外桃源 美到您都不相信这些地方竟然在马来西亚 1 Pulau Besar 柔佛州 情侣来这旅游或蜜月 真的最适合不过了 想要找一个宁静 舒服 温暖的海边度假吗 那么PulauBesar 或许适合你 这里没有其他海边来
  • numpy.arrange函数知识大全

    numpy arrange函数知识大全 numpy arrange函数作用 numpy arrange函数作用 numpy arrange函数的作用是生成带起点和终点的特定步长的排列 根据函数的参数的个数分为以下几种情况 1 只有一个参数
  • 使用Android studio 查看其它app的布局的结构

    日常开发过程中 难免会遇到一些比较好看的布局 这时候我们就想学习一下别人的布局结构 以便参考 如果是前端开发的话 直接用Chrome可以查看别人布局的结构 如果是android的就比较麻烦一些 不过也是可以的 只需要简单的两步 下面来演示一
  • 逻辑判断

    一 论证推理 1 1 基本原理 论证的基本原理 话题一致 例如 甲论证 中国足球不行 这个论点时 乙说 你行你上 这就是典型的话题不一致 因为 我足球行不行 和 国足行不行 是没有关系的 1 2 解题步骤 所有的论据都是为论点服务的 1 明
  • vs+qt添加qtOpengl时,要小写

    奇怪了 一开始添加模块时 写成了Opengl 是qmake错误 改成opengl就可以了
  • 一个简单的测试案例

    题目 有一个处理单价为5角钱的饮料的自动售货机软件测试用例的设计 其规格说明如下 若投入5角钱或1元钱的硬币 押下 橙汁 或 啤酒 的按钮 则相应的饮料就送出来 若售货机没有零钱找 则一个显示 零钱找完 的红灯亮 这时在投入1元硬币并押下按
  • Python · 无限画板(零)· 简介

    项目 GitHub 地址 免费线上示例产品 该示例产品的源代码 封面图对应的项目的源代码 需求 方案 在上一篇文章 用 Python 打造 AIGC 的 操作系统 里 我提到过这个 Python 无限画板的项目 carefree drawb
  • Pytorch CAM特征可视化

    背景 类别激活映射 Class Activation Mapping CAM 用于对深度学习特征可视化 通过特征响应定位图像的关键部位 为深度学习可解释性提供了一种方法 ACM以热力图的方式展示了图像局部响应的强弱信息 对应于更强的位置具有