注意力机制——CAM、SAM、CBAM、SE

2023-11-16

  CAM、SAM、CBAM详见:CBAM——即插即用的注意力模块(附代码)

目录

1.什么是注意力机制?

2.通道注意力机制——SE

(1)Squeeze

(2)Excitation

(3)SE Block

3.CAM

4.SAM

5.CBAM

6.代码

参考


1.什么是注意力机制?

从数学角度看,注意力机制即提供一种权重模式进行运算。

神经网络中,注意力机制即利用一些网络层计算得到特征图对应的权重值,对特征图进行”注意力机制“。

2.通道注意力机制——SE

论文地址论文

该论文于2018年发表于CVPR,是较早的将注意力机制引入卷积神经网络,并且该机制是一种即插即用的模块,可嵌入任意主流的卷积神经网络中,为卷积神经网络模型设计提供新思路——即插即用模块设计。

摘要核心

  • 背景介绍:卷积神经网络的核心是卷积操作,其通过局部感受野的方式融合空间和通道维度的特征;针对空间维度的特征提取方法已被广泛研究。
  • 本文内容:本文针对通道维度进行研究,探索通道之间的关系,并提出SE Block,它可自适应的调整通道维度上的特征。
  • 研究成果:SE Block可堆叠构成SENet,SENet在多个数据集上表现良好;SENet不仅可以大幅提升精度,同时仅需要增加少量的参数。
  • 比赛成绩:ILSVRC 2017分类冠军,top-5 error低至2.251%,相对于2016冠军下降了~25%
  • 代码开源

SE Block模型图如下所示:由两部分组成Squeeze和Excitation

(1)Squeeze

Squeeze(Global Information Embedding):全局信息低维嵌入

Squeeze操作:采用全局池化,即压缩H和W至1*1,利用1个像素来表示一个通道,实现低维嵌入。压缩后的特征本质是一个向量,无空间维度,只有通道维度。

Squeeze计算公式:

​ 

相对应模型的实现位置

(2)Excitation

Excitation(Adaptative Recalibration):适应变换

Excitation部分是用2个全连接来实现 ,第一个全连接把C个通道压缩成了C/r个通道来降低计算量(后面跟了RELU),第二个全连接再恢复回C个通道(后面跟了Sigmoid),r是指压缩的比例。作者尝试了r在各种取值下的性能 ,最后得出结论r=16时整体性能和计算量最平衡。

Excitation公式:

为什么要加全连接层呢?这是为了利用通道间的相关性来训练出真正的scale。一次mini-batch个样本的squeeze输出并不代表通道真实要调整的scale值,真实的scale要基于全部数据集来训练得出,而不是基于单个batch,所以后面要加个全连接层来进行训练。可以拿SE Block和下面3种错误的结构比较来进一步理解:
图2最上方的结构,squeeze的输出直接scale到输入上,没有了全连接层,某个通道的调整值完全基于单个通道GAP的结果,事实上只有GAP的分支是完全没有反向计算、没有训练的过程的,就无法基于全部数据集来训练得出通道增强、减弱的规律。
图2中间是经典的卷积结构,有人会说卷积训练出的权值就含有了scale的成分在里面,也利用了通道间的相关性,为啥还要多个SE Block?那是因为这种卷积有空间的成分在里面,为了排除空间上的干扰就得先用GAP压缩成一个点后再作卷积,压缩后因为没有了Height、Width的成分,这种卷积就是全连接了。
图2最下面的结构,SE模块和传统的卷积间采用并联而不是串联的方式,这时SE利用的是Ftr输入X的相关性来计算scale,X和U的相关性是不同的,把根据X的相关性计算出的scale应用到U上明显不合适。

相对应模型的实现位置

(3)SE Block

分开看完之后,再整合起来看就是如下图这样的操作过程。

  • Squeeze:压缩特征图至向量形式
  • Excitation:两个全连接对特征向量进行映射变换
  • Scale:将得到的权重向量于通道的乘法

SE Block的嵌入方式:只“重构”特征图,不改变原来结构。

3.CAM

4.SAM

5.CBAM

6.代码

空间注意力模块

import torch
from torch import nn

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

if __name__ == '__main__':
    SA = SpatialAttention(7)
    data_in = torch.randn(8,32,300,300)
    data_out = SA(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 1, 300, 300])

通道注意力模块

import torch
from torch import nn

class ChannelAttention(nn.Module):
	def __init__(self, in_planes, ratio=16):
		super(ChannelAttention, self).__init__()
		self.avg_pool = nn.AdaptiveAvgPool2d(1)
		self.max_pool = nn.AdaptiveMaxPool2d(1)

		self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
		self.relu1 = nn.ReLU()
		self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
		max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
		out = avg_out + max_out
		return self.sigmoid(out)


if __name__ == '__main__':
    CA = ChannelAttention(32)
    data_in = torch.randn(8,32,300,300)
    data_out = CA(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 32, 1, 1])

CBAM注意力机制

import torch
from torch import nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)
        
    def forward(self, x):
        out = x * self.ca(x)
        result = out * self.sa(out)
        return result


if __name__ == '__main__':
    print('testing ChannelAttention'.center(100,'-'))
    torch.manual_seed(seed=20200910)
    CA = ChannelAttention(32)
    data_in = torch.randn(8,32,300,300)
    data_out = CA(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 32, 1, 1])




if __name__ == '__main__':
    print('testing SpatialAttention'.center(100,'-'))
    torch.manual_seed(seed=20200910)
    SA = SpatialAttention(7)
    data_in = torch.randn(8,32,300,300)
    data_out = SA(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 1, 300, 300])



if __name__ == '__main__':
    print('testing CBAM'.center(100,'-'))
    torch.manual_seed(seed=20200910)
    cbam = CBAM(32, 16, 7)
    data_in = torch.randn(8,32,300,300)
    data_out = cbam(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 1, 300, 300])

SE注意力机制

from torch import nn
import torch

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
        # return x * y


if __name__ == '__main__':
    torch.manual_seed(seed=20200910)
    data_in = torch.randn(8,32,300,300)
    SE = SELayer(32) 
    data_out = SE(data_in)
    print(data_in.shape)  # torch.Size([8, 32, 300, 300])
    print(data_out.shape)  # torch.Size([8, 32, 300, 300])
    
    
    
    
    
    
    

参考

注意力机制代码

SE模型详解

SENet

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

注意力机制——CAM、SAM、CBAM、SE 的相关文章

  • Spark 请求最大计数

    我是 Spark 的初学者 我尝试请求允许我检索最常访问的网页 我的要求如下 mostPopularWebPageDF logDF groupBy webPage agg functions count webPage alias cntW
  • 从数据框中按索引删除行

    我有一个数组wrong indexes train其中包含我想从数据框中删除的索引列表 0 63 151 469 1008 要删除这些索引 我正在尝试这样做 df train drop wrong indexes train 但是 代码失败
  • Python中Decimal类型的澄清

    每个人都知道 或者至少 每个程序员都应该知道 http docs oracle com cd E19957 01 806 3568 ncg goldberg html 即使用float类型可能会导致精度错误 然而 在某些情况下 精确的解决方
  • Python Popen 与 psexec 挂起 - 不良结果

    我对 subprocess Popen 和我认为是管道的问题有疑问 我有以下代码块 从 cli 运行时 100 都不会出现问题 p subprocess Popen psexec serverName get cmd c ver echo
  • python 中的代表

    我实现了这个简短的示例来尝试演示一个简单的委托模式 我的问题是 这看起来我已经理解了委托吗 class Handler def init self parent None self parent parent def Handle self
  • pydev 调试器:严重警告:此版本的 python 似乎编译不正确(内部生成的文件名不是绝对的)[重复]

    这个问题在这里已经有答案了 通过运行 from sklearn datasets import fetch california housing import pandas as pd pd set option precision 4 m
  • 如何使用 Plotly 中的直方图将所有离群值分入一个分箱?

    所以问题是 我可以在 Plotly 中绘制直方图 其中所有大于某个阈值的值都将被分组到一个箱中吗 所需的输出 但使用标准情节Histogram类我只能得到这个输出 import pandas as pd from plotly import
  • Pandas 中允许重复列

    我将一个大的 CSV 包含股票财务数据 文件分割成更小的块 CSV 文件的格式不同 像 Excel 数据透视表之类的东西 第一列的前几行包含一些标题 公司名称 ID 等在以下列中重复 因为一家公司有多个属性 而不是一家公司只有一栏 在前几行
  • 为什么Python的curses中escape键有延迟?

    In the Python curses module I have observed that there is a roughly 1 second delay between pressing the esc key and getc
  • python suds SOAP 请求中的名称空间前缀错误

    我使用 python suds 来实现客户端 并且在发送的 SOAP 标头中得到了错误的命名空间前缀 用于定义由element ref 在 wsdl 中 wsdl 正在引用数据类型 xsd 文件 请参见下文 问题出在函数上GetRecord
  • 使用 OLS 回归预测未来值(Python、StatsModels、Pandas)

    我目前正在尝试在 Python 中实现 MLR 但不确定如何将我找到的系数应用于未来值 import pandas as pd import statsmodels formula api as sm import statsmodels
  • 如何通过在 Python 3.x 上按键来启动和中断循环

    我有这段代码 当按下 P 键时会中断循环 但除非我按下非 P 键 否则循环不会工作 def main openGame while True purchase imageGrab if a sum gt 1200 fleaButton ti
  • TensorFlow的./configure在哪里以及如何启用GPU支持?

    在我的 Ubuntu 上安装 TensorFlow 时 我想将 GPU 与 CUDA 结合使用 但我却停在了这一步官方教程 http www tensorflow org get started os setup md 这到底是哪里 con
  • 使用鼻子获取设置中当前测试的名称

    我目前正在使用鼻子编写一些功能测试 我正在测试的库操作目录结构 为了获得可重现的结果 我存储了一个测试目录结构的模板 并在执行测试之前创建该模板的副本 我在测试中执行此操作 setup功能 这确保了我在测试开始时始终具有明确定义的状态 现在
  • Pandas 根据 diff 列形成簇

    我正在尝试使用 Pandas 根据表示时间 以秒为单位 的列中的差异来消除数据框中的一些接近重复项 例如 import pandas as pd numpy as np df pd DataFrame 1200 1201 1233 1555
  • 创建嵌套字典单行

    您好 我有三个列表 我想使用一行创建一个三级嵌套字典 i e l1 a b l2 1 2 3 l3 d e 我想创建以下嵌套字典 nd a 1 d 0 e 0 2 d 0 e 0 3 d 0 e 0 b a 1 d 0 e 0 2 d 0
  • mac osx 10.8 上的初学者 python

    我正在学习编程 并且一直在使用 Ruby 和 ROR 但我觉得我更喜欢 Python 语言来学习编程 虽然我看到了 Ruby 和 Rails 的优点 但我觉得我需要一种更容易学习编程概念的语言 因此是 Python 但是 我似乎找不到适用于
  • 使用yield 进行字典理解

    作为一个人为的例子 myset set a b c d mydict item yield join item s for item in myset and list mydict gives as cs bs ds a None b N
  • 如何为每个屏幕添加自己的 .py 和 .kv 文件?

    我想为每个屏幕都有一个单独的 py 和 kv 文件 应通过 main py main kv 中的 ScreenManager 选择屏幕 设计应从文件 screen X kv 加载 类等应从文件 screen X py 加载 Screens
  • 您可以使用关键字参数而不提供默认值吗?

    我习惯于在 Python 中使用这样的函数 方法定义 def my function arg1 None arg2 default do stuff here 如果我不供应arg1 or arg2 那么默认值None or default

随机推荐

  • springboot 打印请求路径到 日志 控制台

    文章目录 application properties 添加 logging level org springframework web servlet mvc method annotation RequestMappingHandler
  • 安装双系统后,将windows设置为默认启动选项的方法

    原先的电脑只有windows系统 后来加装了ubuntu系统 但由于大部分时间仍然需要使用windows 但是默认启动项为ubuntu 难免会带来一些不便 将windows设为默认第一启动项的方法很简单 打开终端 查看grub的配置文件 s
  • VC++实用宏定义

    前言 在日常的编程工作中 常常定义一些实用的宏方便调用 该文章将收集一些常用的宏供大家参考 欢迎大家讨论和添加 指针释放 最常用的就是指针的安全释放 对应new的释放 ifndef ReleasePtr define ReleasePtr
  • File Processing by Python

    Go through all the file in destination path import os import sys def GetFileList dir fileList newDir dir if os path isfi
  • 【计算机网络】TCP协议

    实验目的 应用所学知识 1 熟悉 TCP 的协议格式 2 理解 TCP 对序列号和确认号的使用 3 理解 TCP 的流量控制算法和拥塞控制算法 实验步骤与结果 1 任务一 将Alice txt上传到服务器 使用wireshark捕获数据包
  • Windows平台的SDK、DDK与WDK

    尽管Windows平台的SDK DDK与WDK都包含了WinDBG工具包 但是用户获取WinDBG工具包的最主要方式还是从微软网站自由下载 因为这样获得的版本最新 最近尝试去了解WINDOWS下的驱动开发 现在总结一下最近看到的资料 1 首
  • 下采样与上采样

    一 下采样 概念 下采样 subsampled 又称为降采样 downsampled 可以通俗地理解为缩小图像 减少矩阵的采样点数 方法 1 最常用隔位取值 每行每列每隔k个点取一个点 2 合并区域 每 row k col k 窗口内所有像
  • python selenium 键盘操作 常用

    键盘事件 前面的 send keys 方法用来模拟键盘输入 keys 类提供了键盘上几乎所有按键的方法 组合键也是可以的 常用的键盘操作如下 send keys Keys BACK SPACE 删除键 BackSpace send keys
  • 三十、纯虚函数、抽象类、多态、简单工厂模式

    一 纯虚函数 虚函数是多态是实现多态的前提 如果我们需要在基类中定义共同的结构 那么接口就需要定义成虚函数 但是很多情况下基类的接口是无法实现的 比如形状类Shape 定义一个Draw方法 很明显这个方法没法实现 因为我们可以画出圆 正方形
  • 乾坤微服务子项目图片资源加载失败

    一 背景 子项目单独运行时正常 放在乾坤上 img 加载图片时失败 二 分析原因 假设乾坤项目域名为 http www aaa com 子项目域名为 http www bbb com 项目实际运行时 图片的 html 写法为 img src
  • python:正向最大匹配法分词(以藏文为例)

    前段时间研究了如何用分词工具进行分词 但是分词中涉及的一些算法 不太了解 所以 准备这段时间专攻分词算法原理 大家有补充 或者建议 欢迎留言 1 最大匹配法 Maximum Matching 最大匹配法是指以词典为依据 取词典中最长词长度作
  • 某游戏大厂测开笔试题分享

    测开笔试题 某厂笔试题 执行时限1000ms 一个典型的电话拨号盘如下 1 2 3 4 5 6 7 8 9 0 手指在两个按键之间的移动距离被定义为这两个键的x y坐标差的绝对值之和 比如 6到自身的距离是0 到3 5 9键的距离是1 到2
  • 证件照如何换底色,分享三种证件照换底色的方法!

    在我们的日常生活中 不同场景需要使用不同颜色的证件照 如果我们需要更换证件照的背景颜色 通常情况下人们会选择去照相馆重新拍摄一组照片 但这样费时费力 而且在遇到紧急情况时可能来不及 本文将介绍三种非常实用的方法 希望能对您有所帮助 方法一
  • Langchain使用介绍之-文档加载

    Lanchain提供了加载多种文档的能力 Lanchain初了能加载txt csv等格式文档外 还支持加载网页 音频 pdf等 本篇博客将介绍如何通过Langchain完成PDF文档 音频文档 网页文档的加载 加载PDF文档 通过使用Lan
  • ChatGPT 中文调教指南。各种场景使用指南。学习怎么让它听你的话

    ChatGPT是由OpenAI训练的一款大型语言模型 能够生成类人文本 您只需要给出提示或提出问题 它就可以生成你想要的东西 在此页面中 您将找到可与 ChatGPT 一起使用的各种提示 正经指南 写小说 写一本拥有出人意料结局的推理小说
  • 1.1关于数据挖掘

    一 数据挖掘是什么 从技术层面讲 数据挖掘指从大量数据中提取潜在有用的信息和知识的过程 从商业层面讲 数据挖掘是一种对大量业务数据进行抽取 转换 分析和建模处理 并从中提取辅助商业决策的关键数据的商业信息处理技术 二 数据挖掘与传统传统数据
  • 2021-5-13 爬虫之Xpath的下载与安装,简单教学!

    5 13学习日记之Xpath Xpath的安装 怎么安装Xpath 问题一 Xpath的安装 XPath 是一门在 XML 文档中查找信息的语言 XPath 可用来在 XML 文档中对元素和属性进行遍历 简单来说 在进行网页信息爬取时 Xp
  • 抖音新版本抓包(绕过sslpinning证书校验)

    目录 前言 方案 frida 替换so Xposed 前言 当我们想要分析较新版本的接口时 会发现一个有趣的现象 无论是用Charles还是Fiddler 都会出现抓不到包的情况 如下图 这是因为使用SSL Pinning证书锁定技术 是一
  • mysql性能优化

    1 表字段要选择合适的属性 邮政编码设置char 6 就可以了 文本字段如省份或者性别用enum enum被当做数值型数据来处理 比文本类型快 2 建立索引 3 优化查询语句 查询条件里最好用in替代on 条件列表值如果连续 用betwee
  • 注意力机制——CAM、SAM、CBAM、SE

    CAM SAM CBAM详见 CBAM 即插即用的注意力模块 附代码 目录 1 什么是注意力机制 2 通道注意力机制 SE 1 Squeeze 2 Excitation 3 SE Block 3 CAM 4 SAM 5 CBAM 6 代码