Coordinate attention,SE,CBAM

2023-11-17

1、SE

因为普通卷积难以建模信道关系,SE考虑通道的相互依赖关系增强模型对信息通道的敏感性,同时全局平均池化可以帮助模型捕获全局信息。然而SE只考虑了内部通道信息而忽略了位置信息的重要性。

输入X首先经过全局平均池化
在这里插入图片描述
然后经过全连接层来捕获每个通道的重要性,再经过非线性层也就是使用ReLU激活函数来增加非线性因素,再经过全连接层来捕获每个通道的重要性。在这里插入图片描述

最后全连接层的输出用sigmoid归一化加权后和输入X通道乘法。

在这里插入图片描述

2、CA(coordinate  attention)

在这里插入图片描述
主要分为两步,位置信息的嵌入和协调注意力生成。
(1)位置信息嵌入:
全局平均池化通常用于通道注意中,它将全局位置信息压缩到通道信息中,很难保持位置信息。所在我们通过两个维度上的一维平均池化,这种转换允许获得该方向上的长期依赖关系和保持另一方向上的位置信息,有助于网络更加精准地定位感兴趣的对象。
给定输入x,我们使用(H,1)或(1,W)分别沿着水平坐标和垂直坐标对每个通道进行编码。
在这里插入图片描述
在这里插入图片描述
(2)协调注意力产生
沿着空间维度进行concat,然后二维卷积减少通道数,较小模型的复杂度,接着进行正则化BatchNorm和非线性激活
在这里插入图片描述
f沿着空间维度分成两个张量(c/r,1,H)和(c/r,1,w),然后分别经过卷积恢复到和输入x相同的通道数,最后经过sigmoid归一化加权。
在这里插入图片描述
协调注意力Y的输出可以表达成公式9,相当于将gh和gw作为注意力权重来使用。沿着水平方向和垂直方向的注意同时应用于输入张量,这两个注意力图中的每个元素都反映了感兴趣的对象是否存在于相对于的队和列中。
在这里插入图片描述
CA中,两个一维的全局池化操作,使得网络可以获得更大的一个感受野以及编码准确的空间位置信息。CA考虑了不同通道之间关系的重要性同时也考虑了编码空间信息。

3、CBAM

在这里插入图片描述
首先,CBAM使用squeeze通道数到1导致信息损失。然而在CA中使用适当的较少比率r来缩减通道数,避免过多的通道信息的丢失。
其次,CBAM使用7x7卷积来获得局部空间位置信息,而CA是通过两个一维全局池化,使得可以捕获到空间位置之间的长期依赖关系。
CBAM是通过对每个位置的多个通道取最大值和平均值来作为加权系数,这种加权只考虑局部范围的信息。

4、CA代码

大佬写的代码,先记录在这里,方便后期复习回顾方便。

import torch
import torch.nn as nn
#---------------------------------------------------#
#CA模块这个类的定义
#其参数共有三个分别为特征图像的高、宽以及通道数
#在经过CA模块的前后,特征图像的通道数并不会发生变化
#使用池化不会造成数据矩阵深度的改变,只会在高度和宽带上降低,达到降维的目的
#池化并不会改变特征图像的通道数
#---------------------------------------------------#
class CA_Block(nn.Module):
    def __init__(self, channels,reduction=16):
        super(CA_Block, self).__init__()

        self.avg_pool_x = nn.AdaptiveAvgPool2d((None, 1))    # 先后顺序为h,w 为1则为在x轴进行平均池化操作,x轴即为水平方向w,进而使w的值变为1
        self.avg_pool_y = nn.AdaptiveAvgPool2d((1, None))    #在y轴进行平均池化操作,y轴为垂直方向h,进而使h的值变为1
        self.conv_1x1 = nn.Conv2d(in_channels=channels, out_channels=channels // reduction, kernel_size=1, stride=1,
                                  bias=False)                         #图中的r即为reduction,进而使其输出的特征图像的通道数变为原先的1/16

        self.relu = nn.ReLU()    #relu激活函数
        self.bn = nn.BatchNorm2d(channels // reduction)   #二维的正则化操作

        self.F_h = nn.Conv2d(in_channels=channels // reduction, out_channels=channels, kernel_size=1, stride=1,
                             bias=False)              #将垂直方向上的通道数通过卷积来将其复原
        self.F_w = nn.Conv2d(in_channels=channels // reduction, out_channels=channels, kernel_size=1, stride=1,
                             bias=False)              #将水平方向上的通道数通过卷积来将其复原

        self.sigmoid_h = nn.Sigmoid()          #定义的sigmoid方法
        self.sigmoid_w = nn.Sigmoid()

    def forward(self, x):   #定义Tensor: 16,1024,13,13
        h=x.shape[2]  #13
        w=x.shape[3]  #13
        x_h  = self.avg_pool_x(x).permute(0, 1, 3, 2) # 16,1024,13,1 ->16,1024,1,13
        x_w = self.avg_pool_y(x)  #16,1024,1,13
        #现在x_h以及x_w的shape均为16,1024,1,13    两个16,1024,1,13 堆叠->16,1024,1,26 经过conv之后,成为16,64(1024/16),1,26 过BN以及ReLU进行位置信息编码
        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
        print(x_cat_conv_relu.shape) #16,64,1,26     #具体的即为将维度3上的13与13相加,同时通过卷积调整其通道数为64,过BN以及ReLU进行位置信息编码
        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h,w], 3)   #按照维度3以及h和w的值将这个张量分开
        #print(x_cat_conv_split_h.shape)    #对于垂直方向,其输出的shape仍为16,64,1,13
        #print(x_cat_conv_split_w.shape)    #对于水平方向,其输出的shape仍为16,64,1,13

        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))  #对于垂直方向的,为先进行一个转置,之后通过卷积达到所原先的通道数,再过sigmoid进行归一化处理
        #print(s_h.shape)  #为16,1024,13,1  #此为垂直方向  #16,64,1,13->16,64,13,1->conv->16,1024,13,1

        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w)) #对于水平方向,使用卷积来达到原先的通道数,之后进行归一化的处理
        #print(s_w.shape)  #为16,1024,1,13  #此为水平方向  16,64,1,13->conv->16,1024,1,13
        out = x * (s_h.expand_as(x) * s_w.expand_as(x))# 生成attention map之后进行加权
        return out

F=torch.randn(16,1024,13,13)
print('As Begin!!')
print(F.shape)
CA=CA_Block(1024)
F=CA(F)
print('After Change!!')
print(F.shape)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Coordinate attention,SE,CBAM 的相关文章

  • jQuery验证码插件:jquery.idycode.js

    对于任何一个又评论功能的网站来说 验证码都是重中之重 没有验证码的话 用户就可以肆意刷评论 甚至是通过一些工具来操作 会对网络环境产生极大的危害 验证码这个词最早是在2002年由卡内基梅隆大学的路易斯 冯 安 Manuel Blum Nic
  • 标识符和关键字应该如何理解?

    思考 为什么语言中需要关键字和表示符 程序来源于生活 想想我们人类在生产生活过程中的一些语言使用都有其特定的含义 而每个事物或者事物的一些属性功能也都需要给予特定的语言符号来表示 故java语言的发明者们按照人类的方式创造除了一门值得大家学
  • 分布式、微服务概念

    目录 1 目前软件架构大致分类 2 各种架构技术方法 3 什么是微服务 4 微服务架构特点 5 什么是SOA 6 SOA架构特点 7 SOA架构和微服务架构的区别 8 ESB和微服务API网关 9 什么是分布式 10 什么是集群 11 负载
  • R语言使用cumsum函数计算向量数据的累加和(cumulative sum )

    R语言使用cumsum函数计算向量数据的累加和 cumulative sum 目录 R语言使用cumsum函数计算向量数据的累加和 cumulative sum

随机推荐

  • glsl version 300es 关键字

    参考链接 GLSL ES Specification 3 00 变量名 不能要以gl 开头 注释 或 关键字 void float int uint bool void function name float var name 1 uint
  • JS混淆技术探究及解密方法分析

    随着Web技术的快速发展 JavaScript被广泛应用于网页开发 移动应用开发等领域 然而 JavaScript代码很容易被反编译 解密 这给保护网站和应用程序的安全性带来了严重的挑战 为了解决这个问题 JS混淆技术应运而生 JS混淆就是
  • Redux持久化插件-解决刷新页面数据丢失问题

    最近在使用react的时候有用到redux 对数据进行全局的状态管理 但是发现和vuex一样会出现刷新之后数据丢失的问题 于是在github上面查阅了 redux persist 插件 使用redux persist进行持久化数据存储 通常
  • Rstudio更换主题/样式

    github项目地址 https github com gadenbuie rsthemes 安装 在 rstudio 的控制台console中数据 install packages devtools devtools install gi
  • 为什么会说:程序员年龄越大,越容易失业?

    在程序员的世界里 一直有一个传言 互联网公司没有35岁以上的中年人 从华为辞退34岁以上的员工 到腾讯辞退35 高级员工 似乎老程序员面临着 年龄危机 要时刻警惕在职场被 踢出 的危险 而国内其他很多职业 比如教师 医生 公务员 都在稳步发
  • python TCP通信雷达实时解析数据

    雷达解析程序 coding cp936 import socket import re class jiema def yushe3 self receve r receve av receve v receve h while True
  • 05-分布式计算框架

    目录 一 MapReduce 1 简介 2 原理 2 1 基本概念 2 2 程序执行过程 2 3 作业运行模式 二 Spark 1 简介 1 1 背景 1 2 概念 1 3 特点 2 原理 2 1 编程模型 2 2 运行模式 2 3 运行过
  • Java往字符串数组中追加一个数据

    public class Test public static void main String args 原字符串数组 String arr 原字符串数据1 原字符串数据2 执行数据添加 arr insert arr 需要追加的字符串数据
  • 三菱无机房电梯故障代码查询_【学习】 三菱J、A、K型扶梯介绍(上)

    导 读 抱前段时间我们一直在学习三菱直梯的介绍 今天就为大家分享三菱三种扶梯的介绍 分别为J型 K型和A型 三种扶梯的扶手驱动系统分别是 A型J型为直线式扶手驱动 K型自动扶梯采用摩擦轮式扶手驱动系统 当扶梯的驱动方式不相同时 我们维修的方
  • 解决docker启动mysql无法输入中文以及中文不显示或乱码问题

    前言 我在使用MySQL时 遇到了两个问题 一是在插入中文数据时 无法输入中文 二是在select的时候 查出来的中文数据是空的 因为插入时为空 然后我就使用Navicat连接数据库添加了中文数据 再到docker中查询 就发现了乱码问题
  • LeetCode64. 最小路径和

    题目大意 求出从网络左上角到右下角的一条代价最小的路径和 题目分析 使用动态规划 求出左上角到网络中每个点的代价最小路径和 假设当前要求的是point i j 点 那么它的值就应该是从左上角到它上面那个点point i 1 j 的路径和 与
  • 【技术应用】Qt Creator使用体会与小技巧

    Qt Creator是Qt官方的IDE 这个IDE为Qt编程人员提供了一个完整的开发环境 当然了 这个IDE是用Qt写的 也是免费的 这个IDE真正的编译部分使用了MinGW gcc compiler 也就是说 这个IDE主要的作用是协助开
  • 教务管理系统(免费源码获取)

    项目介绍 本系统使用springboot mybatis plus shiro lombok等技术 使用json传递数据 使用加盐加密对数据进行保存 前端页面使用vue搭建并打包放在static文件夹中 使用token保存当前用户 当用户登
  • chrome浏览器network报错:ERR_CERT_AUTHORITY_INVALID

    转载请注明作者 独孤尚良dugushangliang 出处 https blog csdn net dugushangliang article details 85275319 在访问局域网的某网址时 提示 您的连接不是私密连接 错误代码
  • 算法在ros中应用_烟火检测算法——中伟视界人工智能算法AI在智慧工地、石油中的应用_腾讯新闻...

    烟火检测算法功能说明及实现原理等 一 软件概述 视频智能分析基于目前先进的深度学习算法 通过大量的项目现场素材训练模型 通过本站大量采集的工作服素材 高精度的识别人 安全帽 工作服等识别 本项目主要两方面的算法 一是识别类的 二是行为分析
  • WPF中Datagrid其中一列使用图片显示

    实现效果 实现遇到的问题 当时想要实现如图所示 合格率 所示的效果 我的第一个想法就是使用wpf的转换器 可是接下来问题来了 我这个是通过数值来判断是否合格 什么控件可以做到既可以绑定图片类型的 又可以绑定数值类型的 还有此时的当值绑定肯定
  • 段、页、页框、页表、页表项

    段 页 页框 页表 页表项 分页式虚拟内存 页 页框 页表 页表项 段页式虚拟内存 分段 分页 段 段表 段表项 页 页框 页表 页表项 分页式虚拟内存 页 页框 页表 页表项 页 进程中的块 进程被分成许多大小相同的块 页号 页框 内存中
  • TS2769: Property 'xxx' does not exist on type 'IntrinsicAttributes & IntrinsicClassAttribute...

    用TypeScript开发React项目 在父子组件间传值时发生错误提示 class Page extends React Component render return div div
  • vue组件利用css var(--变量)实现动态修改伪类属性(::before、::after)

    如图所示 1 我们可以利用此属性实现vue组件动态传值 修改例如 before after等 伪类的背景色 背景图等属性值 因为vue利用无法直接在css中使用data里的变量 利用var 变量名 以及style中定义变量 其实此步是模仿
  • Coordinate attention,SE,CBAM

    1 SE 因为普通卷积难以建模信道关系 SE考虑通道的相互依赖关系增强模型对信息通道的敏感性 同时全局平均池化可以帮助模型捕获全局信息 然而SE只考虑了内部通道信息而忽略了位置信息的重要性 输入X首先经过全局平均池化 然后经过全连接层来捕获