深度学习之注意力机制详解(Attention)

2023-11-05

深度学习之注意力机制详解

前言

深度学习attention机制是对人类视觉注意力机制的仿生,本质上是一种资源分配机制。生理原理就是人类视觉注意力能够以高分辨率接收于图片上的某个区域,并且以低分辨率感知其周边区域,并且视点能够随着时间而改变。换而言之,就是人眼通过快速扫描全局图像,找到需要关注的目标区域,然后对这个区域分配更多注意,目的在于获取更多细节信息和抑制其他无用信息。提高 representation 的高效性。
图1  注意力可视化
在神经网络中,attention机制可以它认为是一种资源分配的机制,可以理解为对于原本平均分配的资源根据attention对象的重要程度重新分配资源,重要的单位就多分一点,不重要或者不好的单位就少分一点,在深度神经网络的结构设计中,attention所要分配的资源基本上就是权重了。
在这里插入图片描述
图 2 红色代表高注意力,蓝色代表低注意力
视觉注意力分为几种,核心思想是基于原有的数据找到其之间的关联性,然后突出其某些重要特征,例如处理检测任务时,我们可以让注意力集中在建筑物物体上,从而提高识别效率,有通道注意力,像素注意力,多阶注意力等

一、自注意力机制(self-Attention)

**那么如何实现注意力机制呢?**方差描述的是单个随机变量与其均值之间的偏差,而协方差描述的是两个随机变量之间的相似性。如果两个随机变量的分布相似,它们的协方差很大。否则,它们的协方差很小。如果我们将feature map中的每个像素作为一个随机变量,计算所有像素之间的配对协方差,我们可以根据每个预测像素在图像中与其他像素之间的相似性来增强或减弱每个预测像素的值在训练和预测时使用相似的像素,忽略不相似的像素。这种机制叫做自注意力
具体实施细节
Self-Attention可以理解将队列和一组值与输入对应,即形成querry,key,value向output的映射,output可以看作是value的加权求和,加权值则是由Self-Attention来得出的。在self-attention中,有3个不同的向量,它们分别是Query向量,Key向量和Value向量,长度相同。它们是通过3个不同的权值矩阵由input X乘以三个不同的权值矩阵得到,其中三个矩阵的尺寸也是相同的,过程图如下:
在这里插入图片描述
首先输入高度为H、宽度为w的特征图x,然后将X reshape为三个一维向量A、B和c,将A和B相乘得到大小为HWxHW的协方差矩阵。最后,我们用协方差矩阵和C相乘,得到D并对它reshape,得到输出特征图Y,并从输入X进行残差连接。这里D中的每一项都是输入X的加权和,权重是像素和彼此之间的协方差。
上述为大致流程,具体实现过程如下:
在这里插入图片描述

  1. 将输入特征图经过三个不同的1x1卷积,得到q(query),k(key),v(value)三个向量;
  2. 将q,k,v三个向量reshape为二维矩阵;
  3. 将q向量与k的转置矩阵相乘,经过softmax得到权重系数A;
  4. 为了梯度的稳定,Transformer使用了score归一化,即除以根号c;
  5. 将矩阵A乘以v,得到加权的每个输入向量的评分y
  6. 将y向量reshape为三维特征向量;
  7. 相加之后得到最终的输出结果z。
    其中
    在这里插入图片描述
    Z输出特征图
    F(X)为残差映射
    A为权重矩阵

二、代码

class BAM(nn.Module):
    """ Basic self-attention module
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in //8
        self.activation = activation
        self.ds = ds  #
        self.pool = nn.AvgPool2d(self.ds)
        #print('ds: ',ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)  #

    def forward(self, input):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """
        #1,256,8,8
        x = self.pool(input)
        #print("1",x.shape)

        #1,256,8,8
        m_batchsize, C, width, height = x.size()


        #将1,256,8,8 变为 1,64,32
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X C X (N)/(ds*ds)
        #print("2",proj_query.shape)

        # 将1,256,8,8 变为 1,64,64
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C x (*W*H)/(ds*ds)
        #print(proj_key.shape)

        #1,64,32x1,32,64----1,64,64
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        #print(energy.shape)

        #计算A
        energy = (self.key_channel**-.5) * energy
        #print("3",energy.shape)

        #经过softmax得到相似矩阵A
        attention = self.softmax(energy)  # BX (N) X (N)/(ds*ds)/(ds*ds)
        #print(attention.shape)

        #reshape V  256x64
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N
        #print(proj_value.shape)

        #256x64 x  64x64  得到 256 64
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))

        #reshape为256 x8x8
        out = out.view(m_batchsize, C, width, height)

        #经过双线性插值恢复为原图大小
        out = F.interpolate(out, [width*self.ds,height*self.ds])
        out = out + input
        return out

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习之注意力机制详解(Attention) 的相关文章

随机推荐

  • IDEA如何手动配置插件

    打开idea 点击File进入setting设置 点击plugins 点击设置按钮 这样你就会看到这样一个界面 别担心这是你的插件路径 这是idea插件下载官网https plugins jetbrains com 在这里可以随意下载插件
  • 关于Maven报错的一些解决办法(别处贴的)

    1 警告 The tag handler class for s form org apache struts2 views jsp ui FormTag was not found on the Java Build Path 这个问题终
  • 原生JavaScript实现ajax异步请求代码

    jQuery封装了JavaScript的一些常用方法 而jQuery中的 ajax get post 是比较常用的方法 也是大家最熟悉 最常用的 但是在面试时 通常面试官 会要求你手写原生ajax异步请求的代码 此时即便你的jquery学的
  • Solr删除文档数据

    使用控制台删除solr的无用数据 目前我使用了两种方式 001 登录你的solr地址 我的地址为 http localhost 8983 solr 如下图所示 上图箭头处选择你的my core 我的mycore为damsearch 002
  • [Python图像处理] 二.OpenCV+Numpy库读取与修改像素

    该系列文章是讲解Python OpenCV图像处理知识 前期主要讲解图像入门 OpenCV基础用法 中期讲解图像处理的各种算法 包括图像锐化算子 图像增强技术 图像分割等 后期结合深度学习研究图像识别 图像分类应用 希望文章对您有所帮助 如
  • 恐龙酷跑(python)

    恐龙酷跑小游戏 摘要 一 引言 二 系统结构 三 实现代码 四 运行结果 五 总结和展望 摘要 论述了Python语言中Pygame库的框架结构和一些常用的该库API 使用Python库进行2D游戏开发时需要注意的事项 以及进行2D游戏开发
  • 【Docker】Docker安装telnet

    文章目录 1 概述 1 概述 在使用docker容器时 有时候里边没有安装telnet 敲vim命令时提示说 telnet command not found 这个时候就需要安装vim 可是当你敲apt get install telnet
  • error LNK2019: 无法解析的外部符号 Netbios,该符号在函数 “unsigned char * __cdecl getMACAddress(unsigned char * cons

    我已经正确的加了库 头文件也能找到了 但是还是出现这个问题 说明还是库有问题 原因是我加入的是dcmtk库 是通信有关的 所以还需要在头文件位置加上如下的代码 pragma comment lib netapi32 lib
  • 元数据编辑器--(坑集锦)

    概述 Angular中的输入输出是通过注解 Input和 Output来标识 它位于组件控制器的属性上方 输入输出针对的对象是父子组件 我借鉴的博客地址 https segmentfault com a 1190000007890167 1
  • 人像抠图学习笔记

    目录 人脸分割BiseNetV2 u2net 人脸分割BiseNetV2 宣传的 BiSeNet V2出来了 72 6 的mIOU 156FPS的速度 让分割飞起来 模型30多m TensorFlow平台的 cpu版时间80ms 人脸抠图
  • 两个排序后数组中是否存在相同数字

    因为两个数组都是排好序的 所以只要一次遍历就行了 首先设两个下标 分别初始化为两个数组的起始地址 依次向前推进 推进的规则是比较两个数组中的数字 小的那个数组的下标向前推进一步 直到任何一个数组的下标到达数组末尾时 如果这时还没碰到相同的数
  • Linux下的find指令

    一 概述 因为Linux下面一切皆文件 经常需要搜索某些文件来编写 所以对于linux来说find是一条很重要的命令 linux下面的find指令用于在目录结构中搜索文件 并执行指定的操作 它提供了相当多的查找条件 功能很强大 在不指定查找
  • 第23章组织通用管理

    组织通用管理是项目管理的关键前提和基础 它为项目管理提供思想路线和基本原则与方法 项目管理则是通用管理方法在特定场景下的具体表现 在把项目管理方法运用于实际工作的时候总会表现其通用的方法 反过来说 通用的方法又必定会支配和制约着人们对项目管
  • 理解 Linux 网络栈:Linux 网络协议栈简单总结

    1 Linux 网络路径 1 1 发送端 1 1 1 应用层 1 Socket 应用层的各种网络应用程序基本上都是通过 Linux Socket 编程接口来和内核空间的网络协议栈通信的 Linux Socket 是从 BSD Socket
  • Window10文件在另一个程序中打开无法删除

    1 打开任务管理 点详细信息 2 打开性能 gt 3 打开下方的 资源监视器 4 句柄中输入文件名 5 鼠标右键结束进程 就可以删除文件啦
  • matlab判断cell为空,问与答1:在VBA代码中如何判断单元格是否为空?

    问 如下图所示的工作表 我希望使用VBA代码将空行的背景色设置为灰色 以便于查看 即将上半部分的工作表变为下半部分的样式 我需要判断某行的单元格为空 然后将该行相应的单元格背景色设置为灰色 如何判断单元格是否为空 答 先看看实现所需效果的代
  • 普通人可以做的七个小众副业,让你告别死工资

    现在有什么副业又简单又可以赚得一定的收入呢 当然是有的 下面分享七个适合普通人操作的七个小众副业 1 手工制品 现在手工制品越来越贵 可以做的种类也很多 比如粘土 针织 滴胶 奶油 手机壳 发夹之类的 又是兴趣 又能赚钱 一举两得 可以在一
  • OPF 难解case

    14bus case 当前线路功率极限为Slmax 调整为0 89 Slmax OPF收敛 调整为0 93 1 0001 Slmax OPF不收敛 调整为1 1 Slmax OPF 收敛 其实整个计算过程中 line flow 是not a
  • 数据结构中的顺序表和链表

    目录 1 顺序表 1 1 存储结构 1 2 顺序表特点 1 3 顺序表应用场景 2 链表 2 1 存储结构 最近在复习数据结构中的线性表 下面总结一下顺序表和链表的区别 1 顺序表 线性表的顺序存储称为顺序表 顺序表使得逻辑地址连续的元素在
  • 深度学习之注意力机制详解(Attention)

    深度学习之注意力机制详解 前言 一 自注意力机制 self Attention 二 代码 前言 深度学习attention机制是对人类视觉注意力机制的仿生 本质上是一种资源分配机制 生理原理就是人类视觉注意力能够以高分辨率接收于图片上的某个