当可变形注意力机制引入Vision Transformer

2023-10-27

【GiantPandaCV导语】通过在Transformer基础上引入Deformable CNN中的可变性能力,在降低模型参数量的同时提升获取大感受野的能力,文内附代码解读。

引言

Transformer由于其更大的感受野能够让其拥有更强的模型表征能力,性能上超越了很多CNN的模型。

然而单纯增大感受野也会带来其他问题,比如说ViT中大量使用密集的注意力,会导致需要额外的内存和计算代价,特征很容易被无关的部分所影响。

而PVT或者Swin Transformer中使用的sparse attention是数据不可知的,会影响模型对长距离依赖的建模能力。

由此引入主角:Deformabel Attention Transformer的两个特点:

  • data-dependent: key和value对的位置上是依赖于数据的。
  • 结合Deformable 方式能够有效降低计算代价,提升计算效率。

下图展示了motivation:

图中比较了几种方法的感受野,其中红色星星和蓝色星星表示的是不同的query。而实线包裹起来的目标则是对应的query参与处理的区域。

(a) ViT对所有的query都一样,由于使用的是全局的注意力,所以感受野覆盖全图。

(b) Swin Transformer中则使用了基于window划分的注意力。不同query处理的位置是在一个window内部完成的。

© DCN使用的是3x3卷积核基础上增加一个偏移量,9个位置都学习到偏差。

(d) DAT是本文提出的方法,由于结合ViT和DCN,所有query的响应区域是相同的,但同时这些区域也学习了偏移量。

方法

先回忆一下Deformable Convolution:

简单来讲是使用了额外的一个分支回归offset,然后将其加载到坐标之上得到合适的目标。

在回忆一下ViT中的Multi-head Self-attention:

q = x W q , k = x W k , v = x W v , z ( m ) = σ ( q ( m ) k ( m ) ⊤ / d ) v ( m ) , m = 1 , … , M , z =  Concat  ( z ( 1 ) , … , z ( M ) ) W o , z l ′ = MHSA ⁡ ( LN ⁡ ( z l − 1 ) ) + z l − 1 , z l = MLP ⁡ ( LN ⁡ ( z l ′ ) ) + z l ′ , \begin{aligned} q&=x W_{q}, k=x W_{k}, v=x W_{v}, \\ z^{(m)}&=\sigma\left(q^{(m)} k^{(m) \top} / \sqrt{d}\right) v^{(m)}, m=1, \ldots, M, \\ z&=\text { Concat }\left(z^{(1)}, \ldots, z^{(M)}\right) W_{o}, \\ z_{l}^{\prime} &=\operatorname{MHSA}\left(\operatorname{LN}\left(z_{l-1}\right)\right)+z_{l-1}, \\ z_{l} &=\operatorname{MLP}\left(\operatorname{LN}\left(z_{l}^{\prime}\right)\right)+z_{l}^{\prime}, \end{aligned} qz(m)zzlzl=xWq,k=xWk,v=xWv,=σ(q(m)k(m)/d )v(m),m=1,,M,= Concat (z(1),,z(M))Wo,=MHSA(LN(zl1))+zl1,=MLP(LN(zl))+zl,

有了以上铺垫,下图就是本文最核心的模块Deformable Attention。

  • 左边这部分使用一组均匀分布在feature map上的参照点
  • 然后通过offset network学习偏置的值,将offset施加于参照点中。
  • 在得到参照点以后使用bilinear pooling操作将很小一部分特征图抠出来,作为k和v的输入
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
  • 之后将得到的Q,K,V执行普通的self-attention, 并在其基础上增加relative position bias offsets。

其中offset network构建很简单, 代码和图示如下:

  self.conv_offset = nn.Sequential(
      nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk//2, groups=self.n_group_channels),
      LayerNormProxy(self.n_group_channels),
      nn.GELU(),
      nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
  )

最终网络结构为:

具体参数如下:

实验

实验配置:300epoch,batch size 1024, lr=1e-3,数据增强大部分follow DEIT

  • 分类结果:

目标检测数据集结果:

语义分割:

  • 消融实验:

  • 可视化结果:COCO

这个可视化结果有点意思,如果是分布在背景上的点大部分变动不是很大,即offset不是很明显,但是目标附近的点会存在一定的集中趋势(ps:这种趋势没有Deformable Conv中的可视化结果明显)

代码

  • 生成Q
  B, C, H, W = x.size()
  dtype, device = x.dtype, x.device
  
  q = self.proj_q(x)
  • offset network前向传播得到offset
  q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
  offset = self.conv_offset(q_off) # B * g 2 Hg Wg
  Hk, Wk = offset.size(2), offset.size(3)
  n_sample = Hk * Wk
  • 在参照点基础上使用offset
offset = einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
    
if self.no_off:
    offset = offset.fill(0.0)
    
if self.offset_range_factor >= 0:
    pos = offset + reference
else:
    pos = (offset + reference).tanh()
  • 使用bilinear pooling的方式将对应feature map抠出来,等待作为k,v的输入。
x_sampled = F.grid_sample(
    input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
    grid=pos[..., (1, 0)], # y, x -> x, y
    mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
    
x_sampled = x_sampled.reshape(B, C, 1, n_sample)

q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
  • 在positional encodding部分引入相对位置的偏置:
  rpe_table = self.rpe_table
  rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
  
  q_grid = self._get_ref_points(H, W, B, dtype, device)
  
  displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
  
  attn_bias = F.grid_sample(
      input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
      grid=displacement[..., (1, 0)],
      mode='bilinear', align_corners=True
  ) # B * g, h_g, HW, Ns
  
  attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
  
  attn = attn + attn_bias

参考

https://github.com/LeapLabTHU/DAT

https://arxiv.org/pdf/2201.00520.pdf

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

当可变形注意力机制引入Vision Transformer 的相关文章

随机推荐

  • 2023华为OD机试真题-机房布局(JAVA、Python、C++)

    题目描述 小明正在规划一个大型数据中心机房 为了使得机柜上的机器都能正常满负荷工作 需要确保在每个机柜边上至少要有一个电箱 为了简化题目 假设这个机房是一整排 M表示机柜 I表示间隔 请你返回这整排机柜 至少需要多少个电箱 如果无解请返回
  • 你真的搞懂Class,class了么?

    博客主页 傻根根呀 欢迎点赞 收藏 留言 欢迎讨论 本文由 傻根根呀 原创 首发于 CSDN 由于博主是在学小白一枚 难免会有错误 有任何问题欢迎评论区留言指出 感激不尽 个人主页 精品专栏 不定时更新 JavaSE MySQL LeetC
  • 有时间学习下ue4开源项目shootergame和虚幻竞技场

    无意中 看到有人推荐这个 特此留个坑 以后学习下 另外 官方文档上有大例子推荐 可以学下 https docs unrealengine com 4 27 zh CN Basics InstallingUnrealEngine Custom
  • 算术基本定理及其应用

    算术基本定理 又称为正整数的唯一分解定理 即 每个大于1的自然数均可写为质数的积 而且这些素因子按大小排列之后 写法仅有一种方式 例如 6936 23 3 172 1200 24 3 52 6936 2 3
  • Reliable Cloud Infrastructure: Design and Process学习笔记

    最后更新2022 03 16 忘记更新对应的学习笔记 补上 这一科有9节 加上0章简介 简介 google cloud的好多功能有点相似 这科内容是介绍应该选什么产品 怎么选择 怎么规划 怎么设计等等 首先 你要有个软件产品的设计思想 包括
  • 西米支付:微信服务商支付的介绍

    服务商申请条件 1 微信支付服务商面向企业 政府机关 事业单位 社会组织类型主体开放申请 2 申请资料准备 1 业务联系人信息 包含联系人姓名 联系手机 联系邮箱 若联系人非法定代表人 还需提交有效证件照片 2 主体身份信息 营业执照 登记
  • (纯c)数据结构之------>链表(详解)

    目录 一 链表的定义 1 链表的结构 2 为啥要存在链表及链表的优势 二 无头单向链表的常用接口 1 头插 尾插 2 头删 尾删 3 销毁链表 打印链表 4 在pos位置后插入一个值 5 消除pos位置后的值 6 查找链表中的值并且返回它的
  • 【Nginx】解决在Nginx+Vue部署多个前端项目,二级目录不能访问、访问空白的问题

    一 前言 需求 设置访问 www ai com 访问时打开前端代码 tmp zhsf 设置访问 www ai com case search 时 访问时打开另一个前端代码 tmp template 二 实现过程 1 根目录访问 部署使用ng
  • uni-cloud云函数管理公共模块依赖

    1 右键函数文件夹 选中依赖模块 更新依赖 2 完成后
  • Keras中的fit函数训练集,验证集和测试集

    Keras中的fit函数训练集 验证集和测试集 1 Keras fit函数history对象包含两个重要属性 epoch 训练的轮数 history 它是一个字典 包含val loss val acc loss acc四个key 2 关于训
  • 第十三届蓝桥杯大赛软件赛省赛 Python 大学 C 组

    试题 A 排列字母 本题总分 5 分 问题描述 小蓝要把一个字符串中的字母按其在字母表中的顺序排列 例如 LANQIAO 排列后为 AAILNOQ 又如 GOODGOODSTUDYDAYDAYUP 排列后为 AADDDDDGGOOOOPST
  • 拥抱ChatGPT,开启结对咨询模式!

    ChatGPT刮起了一阵旋风 ChatGPT到底能做什么 做到什么程度 真的会让咨询顾问失业吗 带着这样的疑问 我费尽周折 注册了ChatGPT账号 我先从一个大众化的话题开启了与ChatGPT的对话 如何提高软件开发的质量 如果是我回答这
  • 网页文字复制的几种方法

    1 开启网页阅读模式 这种方法适用于Microsoft Edge浏览器中 它有网页阅读功能可以使用 在网址的最前面加上 read 就会进入网页阅读界面 然后选中文字就可以直接进行复制了 2 直接拖拽 一种简单直接的方法 不用进行任何其他操作
  • 如何在服务器上跑python程序

    购买服务器 首先你需要一个服务器 阿里云云翼计划有一个9 9云服务器ECS服务 你怎么买我不管 反正你最后给我搞到一个云服务器 购买的配置界面 由于阿里云现在限量购买 所以这里只是截个图说明而已 主要说明一点公共镜像选择ubuntu14 0
  • 【软件测试】理论知识基础第一章

    前言 骗取自己的救赎 直到和染尘斑驳的玫瑰一起坠入深渊 软件测试 理论知识基础第一章 一 认识软件测试 1 什么是软件测试 二 常见的测试分类 1 阶段划分 2 代码可见度划分 3 扩展 总结 三 模型 1 质量模型 2 W模型 四 软件测
  • Webservice接口的生成及调用

    最近项目上要对接一个Webservice形式的接口 因为以前一直没有对接过这种类型的 所以这次专门查了一些资料学习下 一 Webservice的简单介绍 WebService是一种跨编程语言和跨操作系统平台的远程调用技术 它通过标准通信协议
  • AAA协议tacacs认证简单实验

    实验名称 AAA的tacacs验证 实验目的 在AAA认证服务器上认证客户端telnet登陆路由器 实验拓扑图 主要实验步骤 Router上的配置 Router gt en Router conf t Router config inter
  • 内存超频时序怎么调_超频技术之内存“时序”重要参数设置解说

    超频技术之内存 时序 重要参数设置解说 来源 华强电子网 作者 华仔 浏览 432 时间 2017 05 10 21 48 标签 摘要 相信大多数超频帖子里都会提到内存时序调整 也就是我们经常看到的5 5 5 15 1T 4 5 4 12
  • python爬虫requests源码链家_Python 爬虫 链家二手房(自行输入城市爬取)

    因同事想在沈阳买房 对比分析沈阳各区的房价 让我帮忙爬取一下链家网相关数据 然后打 算记下笔记 用于总结学到的东西 用到的东西 一 爬虫需要会什么 学习东西 首先你要知道它是干嘛的 爬虫 顾名思义就是爬取你所看到的网页内容 小说 新闻 信息
  • 当可变形注意力机制引入Vision Transformer

    GiantPandaCV导语 通过在Transformer基础上引入Deformable CNN中的可变性能力 在降低模型参数量的同时提升获取大感受野的能力 文内附代码解读 引言 Transformer由于其更大的感受野能够让其拥有更强的模