注意力模型CBAM

2023-11-17

论文:CBAM: Convolutional Block Attention Module 

 

Convolutional Block Attention Module (CBAM) 表示卷积模块的注意力机制模块。是一种结合了空间(spatial)和通道(channel)的注意力机制模块。相比于senet只关注通道(channel)的注意力机制可以取得更好的效果。

 

基于传统VGG结构的CBAM模块。需要在每个卷积层后面加该模块。

基于shortcut结构的CBAM模块。例如resnet50,该模块在每个resnet的block后面加该模块。

 

Channel attention module:

 

将输入的featuremap,分别经过基于width和height的global max pooling 和global average pooling,然后分别经过MLP。将MLP输出的特征进行基于elementwise的加和操作,再经过sigmoid激活操作,生成最终的channel attention featuremap。将该channel attention featuremap和input featuremap做elementwise乘法操作,生成Spatial attention模块需要的输入特征。

其中,seigema为sigmoid操作,r表示减少率,其中W0后面需要接RELU激活。

 

Spatial attention module:

 

将Channel attention模块输出的特征图作为本模块的输入特征图。首先做一个基于channel的global max pooling 和global average pooling,然后将这2个结果基于channel 做concat操作。然后经过一个卷积操作,降维为1个channel。再经过sigmoid生成spatial attention feature。最后将该feature和该模块的输入feature做乘法,得到最终生成的特征。

其中,seigema为sigmoid操作,7*7表示卷积核的大小,7*7的卷积核比3*3的卷积核效果更好。

 

The code:

def cbam_module(inputs,reduction_ratio=0.5,name=""):
    with tf.variable_scope("cbam_"+name, reuse=tf.AUTO_REUSE):
        batch_size,hidden_num=inputs.get_shape().as_list()[0],inputs.get_shape().as_list()[3]

        maxpool_channel=tf.reduce_max(tf.reduce_max(inputs,axis=1,keepdims=True),axis=2,keepdims=True)
        avgpool_channel=tf.reduce_mean(tf.reduce_mean(inputs,axis=1,keepdims=True),axis=2,keepdims=True)
        
        maxpool_channel = tf.layers.Flatten()(maxpool_channel)
        avgpool_channel = tf.layers.Flatten()(avgpool_channel)
        
        mlp_1_max=tf.layers.dense(inputs=maxpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=None,activation=tf.nn.relu)
        mlp_2_max=tf.layers.dense(inputs=mlp_1_max,units=hidden_num,name="mlp_2",reuse=None)
        mlp_2_max=tf.reshape(mlp_2_max,[batch_size,1,1,hidden_num])

        mlp_1_avg=tf.layers.dense(inputs=avgpool_channel,units=int(hidden_num*reduction_ratio),name="mlp_1",reuse=True,activation=tf.nn.relu)
        mlp_2_avg=tf.layers.dense(inputs=mlp_1_avg,units=hidden_num,name="mlp_2",reuse=True)
        mlp_2_avg=tf.reshape(mlp_2_avg,[batch_size,1,1,hidden_num])

        channel_attention=tf.nn.sigmoid(mlp_2_max+mlp_2_avg)
        channel_refined_feature=inputs*channel_attention

        maxpool_spatial=tf.reduce_max(inputs,axis=3,keepdims=True)
        avgpool_spatial=tf.reduce_mean(inputs,axis=3,keepdims=True)
        max_avg_pool_spatial=tf.concat([maxpool_spatial,avgpool_spatial],axis=3)
        conv_layer=tf.layers.conv2d(inputs=max_avg_pool_spatial, filters=1, kernel_size=(7, 7), padding="same", activation=None)
        spatial_attention=tf.nn.sigmoid(conv_layer)

        refined_feature=channel_refined_feature*spatial_attention

    return refined_feature

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

注意力模型CBAM 的相关文章

随机推荐

  • 使用tp5内cache缓存,存储手机短信验证码

    设置手机短信验证码缓存方法 设置手机短信验证码缓存 User JW Email jw 333 163 com Date param data cache public function setRegSmsCache data cache C
  • Gitee API的使用|如何批量删除Gitee下的所有仓库

    前言 那么这里博主先安利一些干货满满的专栏了 首先是博主的高质量博客的汇总 这个专栏里面的博客 都是博主最最用心写的一部分 干货满满 希望对大家有帮助 高质量博客汇总https blog csdn net yu cblog category
  • python语音播报

    python3 pip install pyttsx3 python2 pip install pyttsx 文本转语音 import pyttsx3 import time str Come on Catherine engine pyt
  • java 强密码验证策略工具类

    java 强密码验证策略工具类 package com neusoft caeid common utils import java util regex Matcher import java util regex Pattern aut
  • ChatGPT论文考试满绩,高等教育该如何应对人工智能挑战?

    近日 ChatGPT引发热议 一方面 ChatGPT表现亮眼 有大学生利用ChatGPT在开卷课堂上取得满绩的优异成绩 另一方面 部分院校 学术期刊却对ChatGPT在高等教育领域的推进保持谨慎态度 甚至有高校明确禁止这项工具技术的使用 那
  • 算法题:Rod Cutting

    算法题 Rod Cutting 一 题目 二 代码 三 结果 一 题目 二 代码 lengths 1 1 3 4 lengths 5 4 4 2 2 8 def rodOffcut lengths resut resut append le
  • Android自定义控件--如何在XML文件中使用自定义属性

    前言 你好 我是Cici 这几天在做一个小项目的时候 用到了自定义控件 为了方便在XML中进行配置 于是需要用到自定义属性 特此记下用法 方便复习的同时也希望对大家有所帮助 一 为什么需要自定义控件 Android本身提供了很多控件 比如T
  • 1024 视频拼接

    题目描述 你将会获得一系列视频片段 这些片段来自于一项持续时长为 T 秒的体育赛事 这些片段可能有所重叠 也可能长度不一 视频片段 clips i 都用区间进行表示 开始于 clips i 0 并于 clips i 1 结束 我们甚至可以对
  • pyinstaller打包Transformers 报错No such file or directory

    问题描述 Traceback most recent call last File transformers utils import utils py line 1086 in get module File importlib init
  • Go开发者路线图2019,请收下这份指南

    Go是Google开发的一种静态 强类型 编译型 并发型 并具有垃圾回收功能的类C编程语言 2009以开源项目的形式发布 2012年发布1 0稳定版本 距今已经十年了 其性能类似于Java和C 但速度极快 适合搭载于web服务器 用于高性能
  • LeetCode1652. 拆炸弹

    题目描述 1652 拆炸弹 力扣 LeetCode 题目描述看的不是很清楚 直接看用例 这道题是简单题 取模 防止数组访问越界 C语言代码如下 int decrypt int code int codeSize int k int retu
  • 数据分桶

    数据分桶是一种数据预处理技术 用于减少次要观察误差的影响 是一种将多个连续值分组为较少数量的 桶 的方法 例如 例如我们有一组关于人年龄的数据 如下图所示 现在我们希望将他们的年龄分组到更少的间隔中 可以通过设置一些条件来实现 分桶的数据不
  • (Java)leetcode-945 Minimum Increment to Make Array Unique(使数组唯一的最小增量)

    题目描述 给定整数数组 A 每次 move 操作将会选择任意 A i 并将其递增 1 返回使 A 中的每个值都是唯一的最少操作次数 示例 1 输入 1 2 2 输出 1 解释 经过一次 move 操作 数组将变为 1 2 3 示例 2 输入
  • 在ubuntu上搭建文件服务器

    首先需要在ubuntu上下载好文件资源 一共是三个资源 在下载资源之前建议将git和nginx安装好 在本教程中将会用到 ngnix http nginx org download nginx 1 12 2 tar gz 利用winscp上
  • osg学习(五十一)Warning: detected OpenGL error ‘invalid operation‘ at after RenderBin::draw(..)

    原因是什么 这个错误只出现一次 并且是在第一帧时出现 Warning detected OpenGL error invalid operation after applying attribute Viewport 04292398 应该
  • 华为OD笔试题:工作安排 --- 100分 (思路+python代码)

    题目 小明每周上班都会拿到自己的工作清单 工作清单内包含n项工作 每项工作都有对应的耗时时长 单位h 和报酬 工作的总报酬为所有已完成工作的报酬之和 那么请你帮小明安排一下工作 保证小明在指定的工作时间内工作收入最大化 输入描述 输入的第一
  • 每天进步一点点——Linux中的线程局部存储(二)

    转载 http blog csdn net cywosp article details 26876231 在Linux中还有一种更为高效的线程局部存储方法 就是使用关键字 thread来定义变量 thread是GCC内置的线程局部存储设施
  • TreeMap 的特点

    TreeMap基于红黑树实现 增删改查的平均和最差时间复杂度均为O 最大特点时Key有序 key必须实现Comparable接口或者提供Comparator比较器 所以key不允许为null HashMap 依靠hashCode和equal
  • web移动端适配方案以及不同单位之间的区别

    web移动端适配方案 第一种 rem实现原理 rem是一个倍数单位 它是基于html的font size的倍数 只要我们在不同的设备上设置一个合适的初始值 当设备发生变化font size就会自动等比适配大小 从而在不同的设备上表现统一 如
  • 注意力模型CBAM

    论文 CBAM Convolutional Block Attention Module Convolutional Block Attention Module CBAM 表示卷积模块的注意力机制模块 是一种结合了空间 spatial 和