(202301)pytorch图像分类全流程实战Task6:可解释性分析、显著性分析

2023-11-09

 Task6:可解释性分析、显著性分析

B站up同济子豪兄图像分类系列的学习(大佬的完整代码在GitHub开源)  

2022年人工智能依旧飞速发展,从传统机器学习模型到如今以“炼丹”为主的深度神经网络,代表着模型拟合度与模型可解释性各自的发展趋势。至此,深刻体会并成功解释NN为何能取得更优的效果成为各行各业的新目标,而可解释性机器学习便应运而生。 

感想:愈发感受到自己基础的薄弱了,今天的任务涉及的算法较多且深,短时间内恐怕是只能当个调包侠了,如果有和我一样基础薄弱的朋友可以一同前往子豪大佬的B站视频学习:

同济子豪兄的个人空间-同济子豪兄个人主页-哔哩哔哩视频

可解释性分析-论文集

大白话讲解卷积神经网络工作原理_哔哩哔哩_bilibili

CAM可解释性分析-算法讲解_哔哩哔哩_bilibili 

 torch-cam工具包

torch-cam的安装方式如下:

# 下载安装 torch-cam
git clone https://github.com/frgfm/torch-cam.git
pip install -e torch-cam/.

安装之后在jupyterlab中充气内核,导入torchcam检查是否安装成功。

torchcam有两种使用方式,一种是直接通过命令行使用,另一种是在代码中导入。

命令行

python torch-cam/scripts/cam_example.py --help

--help参数查看torch-cam的基本用法。

usage: cam_example.py [-h] [--arch ARCH] [--img IMG] [--class-idx CLASS_IDX]
                      [--device DEVICE] [--savefig SAVEFIG] [--method METHOD]
                      [--alpha ALPHA] [--rows ROWS] [--noblock]

Saliency Map comparison

optional arguments:
  -h, --help            show this help message and exit
  --arch ARCH           Name of the architecture (default: resnet18)
  --img IMG             The image to extract CAM from (default:
                        https://www.woopets.fr/assets/races/000/066/big-
                        portrait/border-collie.jpg)
  --class-idx CLASS_IDX
                        Index of the class to inspect (default: 232)
  --device DEVICE       Default device to perform computation on (default:
                        None)
  --savefig SAVEFIG     Path to save figure (default: None)
  --method METHOD       CAM method to use (default: None)
  --alpha ALPHA         Transparency of the heatmap (default: 0.5)
  --rows ROWS           Number of rows for the layout (default: 1)
  --noblock             Disables blocking visualization (default: False)

下面是通过命令行的使用方式

python torch-cam/scripts/cam_example.py \
        --img test_img/border-collie.jpg \
        --savefig output/B1_border_collie.jpg \
        --arch resnet18 \
        --class-idx 232 \
        --rows 2

如图参数如同字面意思好理解。运行torchcam中的CAM范例文件,读如一张图片(img),处理后保存到(savefig)路径下arch就是architecture,在此处表示针对哪种模型,class-idx表示针对的类别,rows表示行数。

“CAM会对网络最后的特征图进行加权求和,就可以得到一个注意力的机制(就是卷积神经网络更关注图片的什么地方)。”通过这个有助于我们人类研究人工智能(机器学习、深度学习)的思考方式,从而着手处理。

比较常见的例子是区分雪豹和阿拉伯豹的任务中,分类器的注意力实际上集中在背景上,通过生活的背景区分豹子而不是豹子本身。子豪大佬的视频中也提到了给熊猫图片加上噪声之后,被识别为长臂猿的例子。这些例子表明,我们有需要通过了解人工智能的思维方式来对人工智能算法进行进一步的优化。

代码导包

下面的步骤使得我们得到了一个CAM的矩阵。这里用的是SmoothGradCAMpp的算法,它的速度较慢,效果较好。而速度相对较快的是CAM和Grad_CAM算法。

from torchcam.methods import SmoothGradCAMpp 
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM

cam_extractor = SmoothGradCAMpp(model)

activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()

得到了矩阵之后若将他进行可视化,得到的是一个7*7的图像,得到的矩阵是多少更选择的模型有关,这里是7*7是由于resnet18这个模型的某一层(好像可解释性学习在去年12月份专门作为系列任务出现,那时我没有参加,现在想在一个任务中完成还是太勉强了,还是继续学习吧,悲)

 然后我们通过torchcam库自带的一个函数就可以把它投射到原图上。

rom torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)

alpha的值仍然是透明度,在这里直观地表现为原图的亮暗程度。

pytorch-grad-cam工具包

pip install grad-cam torchcam
git clone https://github.com/jacobgil/pytorch-grad-cam.git

还是通过import pytorch_grad_cam来验证安装成功。

Grad-CAM热力图

grad-CAM:利用梯度作为特征图的权重。

CAM最重要的是两部分,一是特征图,一是特征图所对应的权重。grad-CAM就是利用梯度的方法找权重。

这样,Grad-CAM就不像CAM算法那样需要网络上有一个全局平均池化层,具有更大的适用范围。

LayerCAM

由于Grad-CAM存在的一些缺点,如“浅层权重方差变化大,不能使用梯度均值为特征图每个像素赋予同等重要度”,衍生出了一系列算法,比如LayerCAM。

LayerCAM:基于元素的,每一张特征图中每一个元素都有一个对应的权重。

此外还有 Grad-CAMpp,ScoreCAM等衍生算法。

基于Guided Grad-CAM的高分辨率细粒度可解释性分析

对单张图像,进行Guided Grad-CAM可解释性分析,绘制既具有类别判别性(Class-Discriminative),又具有高分辨率的细粒度热力图。

效果展示:

如果去看了源码的话可以知道实际上我们期望的是只显示出狗头,而这个Guided Backpropagation算法还显示出了猫头,所以采用了该图像和Grad-CAM热力图逐元素相乘的方式,结果如下:

这里子豪大佬给出了一个思考题:

Guided Backpropagation确实兼顾了高分辨率和class-discriminative,唯一美中不足的就是可视化效果太不锐利

是否可以进行进一步的图像处理,更适合人眼来看

说实话,我感到这个图像已经很适合人眼来看了,也不太明白可视化效果不太锐利是什么意思。是增加锐度的话可以做个卷积核,但显然不是这个意思。如果是要更突出的话,那么是不是要把狗头着一块(颜色波动明显)用原图覆盖呢。

基于DFF的图像子区域可解释性分析

对单张图像,进行Deep Feature Factorization可解释性分析,展示Concept Discovery概念发现图。参考阅读如下:

GitHub - jacobgil/pytorch-grad-cam: Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.

Deep Feature Factorizations for better model explainability — Advanced AI explainability with pytorch-gradcam

https://arxiv.org/abs/1806.10206

通过获取每个concept对应的类别来做到这张图像。

我是越来越觉得可解释性分析和语义分割有关系了,似乎都是弱监督,但我很菜没学过,我还是表达以下期待吧。

封装函数如下:

def dff_show(img_path='test_img/cat_dog.jpg', n_components=5, top_k=2, hstack=False):
    img, rgb_img_float, input_tensor = get_image_from_path(img_path)
    dff = DeepFeatureFactorization(model=model, 
                                   target_layer=model.layer4, 
                                   computation_on_concepts=classifier)
    concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)
    concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)
    visualization = show_factorization_on_image(rgb_img_float, 
                                                batch_explanations[0],
                                                image_weight=0.3, # 原始图像透明度
                                                concept_labels=concept_label_strings)
    if hstack:
        result = np.hstack((img, visualization))
    else:
        result = visualization
    display(Image.fromarray(result))

captum工具包

直接使用pip安装这个包即可。

遮挡可解释性分析-ImageNet图像分类

遮挡可解释性分析

在输入图像上,用遮挡滑块,滑动遮挡不同区域,探索哪些区域被遮挡后会显著影响模型的分类决策。

提示:因为每次遮挡都需要分别单独预测,因此代码运行可能需要较长时间。

我们可以改变遮挡滑块的大小。

下面代码以小滑块为例。

# 更改遮挡滑块的尺寸
attributions_occ = occlusion.attribute(input_tensor,
                                       strides = (3, 2, 2), # 遮挡滑动移动步长
                                       target=pred_id, # 目标类别
                                       sliding_window_shapes=(3, 4, 4), # 遮挡滑块尺寸
                                       baselines=0)

# 转为 224 x 224 x 3的数据维度
attributions_occ_norm = np.transpose(attributions_occ.detach().cpu().squeeze().numpy(), (1,2,0))

viz.visualize_image_attr_multiple(attributions_occ_norm, # 224 224 3
                                  rc_img_norm,           # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  show_colorbar=True,
                                  outlier_perc=2)
plt.show()

Integrated Gradients可解释性分析

Integrated Gradients 原理

输入图像像素由空白变为输入图像像素的过程中,模型预测为某一特定类别的概率相对于输入图像像素的梯度积分。

noise_tunnel = NoiseTunnel(integrated_gradients)

# 获得输入图像每个像素的 IG 值
attributions_ig_nt = noise_tunnel.attribute(input_tensor, nt_samples=12, nt_type='smoothgrad_sq', target=pred_id)

# 转为 224 x 224 x 3的数据维度
attributions_ig_nt_norm = np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0))

# 设置配色方案
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

viz.visualize_image_attr_multiple(attributions_ig_nt_norm, # 224 224 3
                                  rc_img_norm, # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  cmap=default_cmap,
                                  show_colorbar=True)
plt.show()

GradientShap可解释性分析

QAQ我是在是弄不明白了,真应该在过年那段时间提前做任务的。看看视频吧。

可解释机器学习公开课_哔哩哔哩_bilibili 

Feature Ablation特征消融可解释性分析

根据实例分割标注图,分别除去图像中的不同语义分组区域,观察对模型预测结果的影响。

feature group 特征分组

在实例分割标注图中,每一个类别都被划为一类 feature group。

Feature Ablation 就是分析每个 feature group 存在(或者不存在)的影响。

参考教程:Captum · Model Interpretability for PyTorch

shap工具包

直接pip安装即可。

预备知识

图像分类全流程:构建数据集、训练模型、预测新图、测试集评估、可解释性分析、终端部署

视频教程:构建自己的图像分类数据集【两天搞定AI毕设】_哔哩哔哩_bilibili

代码教程:GitHub - TommyZihao/Train_Custom_Dataset: 标注自己的数据集,训练、评估、测试、部署自己的人工智能算法

UCI心脏病二分类+可解释性分析:【子豪兄Kaggle】玩转UCI心脏病二分类数据集_哔哩哔哩_bilibili

shap工具包相关

shap工具包:GitHub - slundberg/shap: A game theoretic approach to explain the output of any machine learning model.

shap工具包论文:https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf

DataWhale公众号推送【6个机器学习可解释性框架!】:6个机器学习可解释性框架!

 lime工具包

直接pip安装即可。

推荐视频

LIME论文逐句精读:同济子豪兄的个人空间-同济子豪兄个人主页-哔哩哔哩视频

图像分类可解释性分析实战-shap工具包:图像分类可解释性分析实战-shap工具包_哔哩哔哩_bilibili

UCI心脏病二分类+可解释性分析:【子豪兄Kaggle】玩转UCI心脏病二分类数据集_哔哩哔哩_bilibili

公开课

【同济子豪兄“两天搞定AI毕业设计之【图像分类】”公开课】

图像分类全流程:构建数据集、训练模型、预测新图、测试集评估、可解释性分析、终端部署

视频合集:同济子豪兄的个人空间-同济子豪兄个人主页-哔哩哔哩视频

代码:GitHub - TommyZihao/Train_Custom_Dataset: 标注自己的数据集,训练、评估、测试、部署自己的人工智能算法

【同济子豪兄可解释机器学习公开课】

包含人工智能可解释性、显著性分析领域的导论、算法综述、经典论文精读、代码实战、前沿讲座。

课程主页:zihao_course/XAI at main · TommyZihao/zihao_course · GitHub

博客链接

lime工具包:https://github.com/marcotcr/lime

DataWhale公众号推送【6个机器学习可解释性框架!】:6个机器学习可解释性框架!

https://towardsdatascience.com/lime-how-to-interpret-machine-learning-models-with-python-94b0e7e4432e

代码运行云GPU平台:Featurize

给句痛快话,投降就投降!再过半个月我又是一条好汉!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

(202301)pytorch图像分类全流程实战Task6:可解释性分析、显著性分析 的相关文章

  • 使用RabbitMQ实现延时队列

    之前公司是一个类电商公司 会有用户下单后未支付取消订单的场景 解决方案是使用RabbitMQ的死信队列来实现一个延时队列 下单时 将订单丢进消息队列 设置过期时间 订单失效时间 然后到时候检查订单状态 如果未支付则取消订单 1 什么是死信
  • 【LeetCode】345. 反转字符串中的元音字母

    题目 给你一个字符串 s 仅反转字符串中的所有元音字母 并返回结果字符串 元音字母包括 a e i o u 且可能以大小写两种形式出现 示例 1 输入 s hello 输出 holle 示例 2 输入 s leetcode 输出 leotc
  • odoo连接器-odoo数据拉取,Java xml-rpc实现

    背景 odoo数据拉取 创建 更新 参考 官方external api文档 External API Odoo 14 0 文档 术语 ORM odoo数据以对象模型呈现 支持one2many many2one many2many等对象关联关
  • FSDataOutputStream 的深入分析

    对于一般文件 都有满足随机读写的api 而hadoop中的读api很简单用FSDataInputStream类就可以满足一般要求 而hadoop中的写操作却是和普通java操作不一样 在这里插入代码片 Hadoop对于写操作提供了一个类 F

随机推荐

  • 刷脸支付服务商重金之下必有勇夫

    为了吸引消费者使用刷脸支付 而非扫码 支付宝和微信会给消费者提供优惠 比如在店里面使用刷脸会有随机立减 打折活动 而扫码则没有 这只是给消费者的补贴 以利益吸引商家与推广人员加入刷脸支付同样重要 显然 与二维码支付的几乎没有成本不同 刷脸支
  • C++全局变量的声明和定义

    参考 http wrchen blog sohu com 71617539 html 1 编译单元 模块 在VC或VS上编写完代码 点击编译按钮准备生成exe文件时 编译器做了两步工作 第一步 将每个 cpp c 和相应的 h文件编译成ob
  • 算法学习:插值型求积公式

    算法学习 插值型求积公式 牛顿 柯斯特 Newton Cotes 求积公式 定义 牛顿 柯斯特 Newton Cotes 求积公式是插值型求积公式的特殊形式 在插值求积公式 baf x dx baP x dx k 0nAkf xk a b
  • Stall Reservations POJ - 3190

    这道题 是学长给我们布置的学习用的题目 重在给我们讲解了什么是优先队列以及其对应的贪心问题 好了 先送上 中文翻译过的题意 手动 滑稽 Oh those picky N 1 lt N lt 50 000 cows They are so p
  • Armbian5.9.0如何安装docker及部署可视化portainer

    安装 docker 通过 ssh 进去 Armbian 系统后 输入下面代码 按提示输入y 等待安装完成即可 apt install docker io 如何查看 docker 是否安装成果 输入命令 docker 可出现docker帮助内
  • MySQL常用命令用法总结

    原文 http www jb51 net article 22110 htm 一 启动与退出 1 进入MySQL 启动MySQL Command Line Client MySQL的DOS界面 直接输入安装时的密码即可 此时的提示符是 my
  • Win 10系统无法连接蓝牙耳机问题

    问题描述 本人刚入手的笔记本电脑 设置里面也有 蓝牙 的开关 由于处于实验室的环境不能开外放 有线耳机显得很不舒服 于是本人兴奋地拿起平时用的蓝牙耳机 想和电脑连起来 但是打开蓝牙开关之后 遇到了下面的情况 打开添加完设备之后 不仅仅是本人
  • springboot2.0学习笔记 自定义JSON序列化程序和反序列化器

    如果使用jackson序列化和反序列化json数据 则可能需要编写 自己JsonSerializer和JsonDeserializer的类 Spring提供了一个替代方案 JsonComponent创建注释 直接注册spring bean容
  • SecureCRT 64位 破解版v8.1.4

    http www xue51 com soft 1510 html xzdz securecrt 破解版是一款支持SSH1和SSH2的终端仿真程序 这个程序能够在windows系统中登陆UNIX或Linux的服务器主机并且还能进行管理设置
  • 全网最细的SpringBoot3系列教程

    1 开发第个Spring Boot应用 创建POM 因为是3 0 0 M1版本 是程碑版本 不是正式发布版 需要从Spring提的Maven仓库中才能下载到3 0 0 M1版本的依赖包 需要在pom xml件中单独指定仓库地址 如果使的是正
  • 安卓真机调试安装失败Session ‘app‘: Installation did not succeed. The application could not be installed: IN:

    Session app Installation did not succeed The application could not be installed INSTALL FAILED TEST ONLY 解决方案 在gradle pr
  • 自定义Looper/Handler模型 线程wait/notify版本 非poll版本

    循环 public static class Looper final static ThreadLocal
  • PS证件照换底色

    原图 1 本教程采用photoshop CS5制作 其它版本基本通用 先在PS中打开原图 如下图所示 2 右键单击背景图层 在弹出的菜单中选择 复制图层 如下图所示 3 接着会弹出 复制图层 对话框 直接按确定即可 如下图所示 4 单击选中
  • pymysql的使用

    pymysql是从Python连接到MySQL数据库服务器的接口 其官方文档为 https pymysql readthedocs io en latest 安装 pip install pymysql 对于数据库的操作 我们一般是这样的操
  • 正在开发应用于Maxthon、TT等多页面浏览器的页面模式

    经过大量的用户调查 我们发现 有不少朋友使用了Maxthon 腾讯TT 世界之窗等基于IE的多页面浏览器使用WEBCHAT 而这种模式下弹出窗口将变成一个新页面 用起来不方便
  • 如何查看支付宝旗下的天弘基金一共有多少只?分别是什么?

    如何查看支付宝旗下的天弘基金一共有多少只 分别是什么 2020年 股市风格突变 相对股市个股的跌宕起伏 基金的收益可谓一枝独秀 下面我们将对基金进行研究 看看我们可以获取数据能否到什么程度 利用tushare的数据接口就可以获取基金的名称
  • 排序类算法

    文章目录 利用vector进行排序 数字类元素 字符串类元素 利用其他STL容器排序 map set priority queue 利用vector进行排序 数字类元素 每个元素一般包含多个条件 利用lambda编写特定排序条件 用sort
  • 转:Ogre TerrainGroup地形赏析

    转 Ogre TerrainGroup地形赏析 1 1 参考 http www ogre3d org tikiwiki tiki index php page Ogre Terrain System http www ogre3d org
  • VS2017找不到QT头文件

    一 我的电脑右键属性 高级系统设置 环境变量 增加环境变量Qt INCLUDEPATH 值为QT的头文件目录 二 重启VS 发现波纹线不见了 证明设置环境变量后VS能识别到QT头文件了 原理是 vs导入qt项目附加包含目录继承值有Qt IN
  • (202301)pytorch图像分类全流程实战Task6:可解释性分析、显著性分析

    Task6 可解释性分析 显著性分析 对B站up同济子豪兄的图像分类系列的学习 大佬的完整代码在GitHub开源 2022年人工智能依旧飞速发展 从传统机器学习模型到如今以 炼丹 为主的深度神经网络 代表着模型拟合度与模型可解释性各自的发展