Background:
Class Activation Mapping (CAM) visualizes the features learned by a deep network: from the feature responses it localizes the image regions that drive a prediction, offering one route toward deep-learning interpretability. CAM renders the strength of the local responses as a heatmap; locations with stronger responses carry more discriminative features.
Paper link: Learning Deep Features for Discriminative Localization
How CAM works:
Define the class score $S_c = \sum_k w_k^c \sum_{x,y} f_k(x,y)$, where $f_k(x,y)$ is the output of channel $k$ of the last convolutional layer at spatial location $(x,y)$, and $w_k^c$ is the classifier weight connecting channel $k$ to class $c$. The CAM for class $c$ is then $M_c(x,y) = \sum_k w_k^c f_k(x,y)$, so that $S_c = \sum_{x,y} M_c(x,y)$: the map simply redistributes the class score over spatial locations.
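The definition above can be sketched directly in NumPy. This is a minimal toy example, not tied to any particular network: `f` stands in for the last conv layer's output (K channels of size HxW) and `w` for the fc-layer weight row of one class; both names and shapes are assumptions for illustration.

```python
import numpy as np

def cam(f: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Class activation map M_c(x, y) = sum_k w_k^c * f_k(x, y).

    f: (K, H, W) feature maps from the last conv layer.
    w: (K,) fc weights for the class of interest.
    """
    m = np.tensordot(w, f, axes=([0], [0]))  # weighted sum over channels -> (H, W)
    m = np.maximum(m, 0)                     # keep positive evidence only
    return m / (m.max() + 1e-8)              # normalize to [0, 1] for display

# Toy usage: 4 channels, 7x7 feature map
rng = np.random.default_rng(0)
f = rng.random((4, 7, 7))
w = rng.random(4)
print(cam(f, w).shape)  # (7, 7)
```

Note that vanilla CAM needs no gradients: it only requires the architecture to end in global average pooling followed by a single fc layer, which is exactly why Grad-CAM (below) was introduced for arbitrary architectures.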
Related methods: Grad-CAM: https://arxiv.org/pdf/1610.02391.pdf, Grad-CAM++: https://arxiv.org/pdf/1710.11063.pdf
Feature visualization code based on ResNet50 (a Grad-CAM-style implementation, since it weights the features by pooled gradients):
import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

def draw_cam(model, img_path, save_path, transform=None, vis_heatmap=False):
    img = Image.open(img_path).convert('RGB')
    if transform is not None:
        img = transform(img)
    img = img.unsqueeze(0)
    model.eval()
    # Run the ResNet backbone by hand so the last conv features stay accessible
    x = model.conv1(img)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    features = x                              # 1x2048x7x7
    output = model.avgpool(x)                 # 1x2048x1x1
    output = output.view(output.size(0), -1)  # 1x2048
    output = model.fc(output)                 # 1x1000

    # Capture the gradient flowing back into the feature map
    def extract(g):
        global feature_grad
        feature_grad = g

    pred = torch.argmax(output).item()
    pred_class = output[:, pred]
    features.register_hook(extract)
    pred_class.backward()

    grads = feature_grad
    # Global-average-pool the gradients to get one weight per channel
    pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
    pooled_grads = pooled_grads[0]
    features = features[0]
    for i in range(features.shape[0]):
        features[i, ...] *= pooled_grads[i, ...]

    heatmap = features.detach().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)          # ReLU: keep positive evidence only
    heatmap /= (np.max(heatmap) + 1e-8)
    if vis_heatmap:
        plt.matshow(heatmap)
        plt.savefig('./heatmap.png')
        plt.show()

    # Overlay the heatmap on the original image
    img = cv2.imread(img_path)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = np.clip(heatmap * 0.4 + img, 0, 255).astype(np.uint8)
    cv2.imwrite(save_path, superimposed_img)

if __name__ == '__main__':
    model = models.resnet50(pretrained=True)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    draw_cam(model, './1.jpg', './cam_1.png', transform=transform, vis_heatmap=True)
Results:
Project repository: sourceCode