神经网络学习小记录64——Pytorch 图像处理中注意力机制的解析与代码详解

2023-11-08

神经网络学习小记录64——Pytorch 图像处理中注意力机制的解析与代码详解

学习前言

注意力机制是一个非常有效的trick,注意力机制的实现方式有许多,我们一起来学习一下。
在这里插入图片描述

什么是注意力机制

注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力

注意力机制的核心重点就是让网络关注到它更需要关注的地方

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

一般而言,注意力机制可以分为通道注意力机制,空间注意力机制,以及二者的结合。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-pytorch

复制该路径到地址栏跳转。

注意力机制的实现方式

在深度学习中,常见的注意力机制的实现方式有SENet,CBAM,ECA等等。

1、SENet的实现

SENet是通道注意力机制的典型实现。
2017年提出的SENet是最后一届ImageNet竞赛的冠军,其实现示意图如下所示,对于输入进来的特征层,我们关注其每一个通道的权重,对于SENet而言,其重点是获得输入进来的特征层,每一个通道的权值。利用SENet,我们可以让网络关注它最需要关注的通道。

其具体实现方式就是:
1、对输入进来的特征层进行全局平均池化
2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同
3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。
4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。
在这里插入图片描述
实现代码如下:

import torch
import torch.nn as nn
import math

class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

2、CBAM的实现

CBAM将通道注意力机制和空间注意力机制进行一个结合,相比于SENet只关注通道的注意力机制可以取得更好的效果。其实现示意图如下所示,CBAM会对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理
在这里插入图片描述
下图是通道注意力机制和空间注意力机制的具体实现方式:
图像的上半部分为通道注意力机制通道注意力机制的实现可以分为两个部分,我们会对输入进来的单个特征层,分别进行全局平均池化全局最大池化。之后对平均池化最大池化的结果,利用共享的全连接层进行处理,我们会对处理后的两个结果进行相加,然后取一个sigmoid,此时我们获得了输入特征层每一个通道的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。

图像的下半部分为空间注意力机制,我们会对输入进来的特征层,在每一个特征点的通道上取最大值和平均值。之后将这两个结果进行一个堆叠,利用一次通道数为1的卷积调整通道数,然后取一个sigmoid,此时我们获得了输入特征层每一个特征点的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。
在这里插入图片描述
实现代码如下:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # 利用1x1卷积代替全连接
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x

3、ECA的实现

ECANet是也是通道注意力机制的一种实现形式。ECANet可以看作是SENet的改进版。
ECANet的作者认为SENet对通道注意力机制的预测带来了副作用捕获所有通道的依赖关系是低效并且是不必要的
在ECANet的论文中,作者认为卷积具有良好的跨通道信息获取能力

ECA模块的思想是非常简单的,它去除了原来SE模块中的全连接层,直接在全局平均池化之后的特征上通过一个1D卷积进行学习。

既然使用到了1D卷积,那么1D卷积的卷积核大小的选择就变得非常重要了,了解过卷积原理的同学很快就可以明白,1D卷积的卷积核大小会影响注意力机制每个权重的计算要考虑的通道数量。用更专业的名词就是跨通道交互的覆盖率

如下图所示,左图是常规的SE模块,右图是ECA模块。ECA模块用1D卷积替换两次全连接。
在这里插入图片描述
实现代码如下:

class eca_block(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention_block = [se_block, cbam_block, eca_block]

#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi=0):
        super(YoloBody, self).__init__()
        self.phi            = phi
        self.backbone       = darknet53_tiny(None)

        self.conv_for_P5    = BasicConv(512,256,1)
        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)

        self.upsample       = Upsample(256,128)
        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)

        if 1 <= self.phi and self.phi <= 3:
            self.feat1_att      = attention_block[self.phi - 1](256)
            self.feat2_att      = attention_block[self.phi - 1](512)
            self.upsample_att   = attention_block[self.phi - 1](128)

    def forward(self, x):
        #---------------------------------------------------#
        #   生成CSPdarknet53_tiny的主干模型
        #   feat1的shape为26,26,256
        #   feat2的shape为13,13,512
        #---------------------------------------------------#
        feat1, feat2 = self.backbone(x)
        if 1 <= self.phi and self.phi <= 3:
            feat1 = self.feat1_att(feat1)
            feat2 = self.feat2_att(feat2)

        # 13,13,512 -> 13,13,256
        P5 = self.conv_for_P5(feat2)
        # 13,13,256 -> 13,13,512 -> 13,13,255
        out0 = self.yolo_headP5(P5) 

        # 13,13,256 -> 13,13,128 -> 26,26,128
        P5_Upsample = self.upsample(P5)
        # 26,26,256 + 26,26,128 -> 26,26,384
        if 1 <= self.phi and self.phi <= 3:
            P5_Upsample = self.upsample_att(P5_Upsample)
        P4 = torch.cat([P5_Upsample,feat1],axis=1)

        # 26,26,384 -> 26,26,256 -> 26,26,255
        out1 = self.yolo_headP4(P4)
        
        return out0, out1
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络学习小记录64——Pytorch 图像处理中注意力机制的解析与代码详解 的相关文章

随机推荐

  • java的类学习

    先看下面的代码 span style font size 18px public class Static public int a public String SS public Static a 35 public void Test
  • 大专毕业,从6个月开发转入测试岗位的一些感悟——写在测试岗位3年之际

    时光飞逝 我从前端开发岗位转入测试岗位已经三年了 这期间从迷茫到熟悉 到强化 到熟练 到总结 感受还是很深的 三年前的某一个晚上 我正准备下班回家 我们的项目经理把我叫到办公司和我谈话 谈了很多 具体说什么不记得了 大体意思就是说测试组缺人
  • Linux基础知识点总结

    作者 小刘在C站 个人主页 小刘主页 每天分享云计算网络运维课堂笔记 努力不一定有回报 但一定会有收获加油 一起努力 共赴美好人生 夕阳下 是最美的绽放 树高千尺 落叶归根人生不易 人间真情 目录 前言 Linux 安装系统 服务管理
  • 关于uniapp将H5网页编译为微信小程序样式错乱

    在控制台看了下出现警告Some selectors are not allowed in component wxss including tag name selectors ID selectors and attribute sele
  • jsonpath - 使用 JSONPath 解析 JSON 完整内容详解

    目录 1 操作符 2 函数 3 过滤器运算符 4 Java操作示例 5 阅读文档 何时返回 谓词 6 调整配置 7 Java操作示例源码 json Java 输出 示例2 Java 输出 过滤器示例 Java 输出 JsonPath是一种简
  • python 实现信息熵、条件熵、信息增益、基尼系数

    在这里插入代码片注 该代码为慕课网课程中老师讲解 python import pandas as pd import numpy as np import math 计算信息熵 def getEntropy s 找到各个不同取值出现的次数
  • 相机系统综述 —— ISP

    转 http kernel meizu com camera isp intro html ISP Image Signal Processor 即图像信号处理器 用于处理图像信号传感器输出的图像信号 它在相机系统中占有核心主导的地位 是构
  • 网络 — MB/s、Mb/s、Mbps、Mbit/s、Kbps

    MB s 兆字节每秒 Mb s 兆比特每秒 Mbps 兆比特每秒 Mbit s 兆比特每秒 Kbps 千比特每秒 1Byte 字节 8 bit 比特 1B 8b 1MB 百万字节也称兆字节 8 Mb 1Mb 0 125MB 1Kb 1024
  • java yyyy-mm-dd 日期格式_Java中的日期时间格式化

    原标题 Java中的日期时间格式化 1 Java日期时间格式化的概念 我们在日常的开发过程中常常会碰到关于日期时间的计算与存储问题 比如我们要把一个当前时间类型转换成字符串类型 我们会直接使用Util包下的Date数据类型 java uti
  • unity Screen.width, Screen.height

    如果事从编译器调用这个函数 获取的值不正确 获取的是editorwindow的大小
  • 国产自主研发,完全可控 IDE!

    最近 互联网上逐渐有些热闹 日本福岛核废水排海计划 中国自主研发 IDE 作为一名开发者 自然好奇国产自主研发的 IDE 不禁夸赞吾国威武 某方面领域越来越强 该产品名为 CEC IDE 是由数字广东公司联合麒麟软件打造国内首款适配国产操作
  • 乐高叉车wedo教案_24乐高教育wedo编程摩天轮教案

    1 人小组 时长 1 5 活动目标 巩固对三角形结构的稳定性的认识 认识重力的方向是垂直向下 活动准备 9886 套装 摩天轮图片 活动过程 备注 联 系 20 一 互动问大家去游乐园座过摩天轮没有 二 看视频了解摩天轮能座在上面旋转 很高
  • dotnet java_我所理解的JAVA和 DotNet

    Java 从实用性来讲 Java 可以说是第一种 网页 语言 尽管像 Perl 等语言会突然发现它们处理字符串的能力在恢复价值和发送 HTML 到网页浏览器上是天生的 但是 Java 是最早发现自己是根植于浏览器中 最初是在一个有趣但却非常
  • Python基础知识点总结

    https www cnblogs com wu chao p 8421708 html Python中pass语句的作用是什么 pass语句不会执行任何操作 一般作为占位符或者创建占位程序 Python是如何进行类型转换的 Python提
  • python下的pyecharts应用4----绘制cpu折线图

    要求 1 截止到运行一刻 2 每秒钟监测 3 绘制折线图 设计 获取cpu的代码如下 1 获取系统cpu占有率的信息 import psutil import time 隔1s绘制cpu的占有率 gt 持久化的保存 如何将时间和对应的cpu
  • html5 页面可以上下滚动条,h5页面上下左右滑动

    var startX 0 startY 0 operate 0 backDom addEventListener touchstart function evt evt preventDefault var touch evt touche
  • js身份证号校验

    if card console log 请输入身份证号 身份证号不能为空 return false if isCardNo card false console log 您输入的身份证号码不正确 return false 检查省份 if c
  • 西门子PLC的常见的通讯方式

    1 PPI通信 T PPI协议是S7 200cpu最基本的通信方式 S7 200cpu的默认通信方式可通过原端口通信 西门子PLC是一种专为工业环境应用而设计的数字操作电子系统 可编程存储器 存储逻辑操作 顺序控制 定时 计数 算术操作等指
  • Ubuntu常用命令汇集

    ubuntu常用命令汇集 文章目录 一 文件组织结构 二 常用命令 三 权限 一 文件组织结构 为根目录 为系统最基本的目录 home下有用户名的文件夹 该文件夹就是 为主目录 为日常使用的目录 命令在终端中输入 需要注意当前所在的文件夹
  • 神经网络学习小记录64——Pytorch 图像处理中注意力机制的解析与代码详解

    神经网络学习小记录64 Pytorch 图像处理中注意力机制的解析与代码详解 学习前言 什么是注意力机制 代码下载 注意力机制的实现方式 1 SENet的实现 2 CBAM的实现 3 ECA的实现 注意力机制的应用 学习前言 注意力机制是一