CeiT:训练更快的多层特征抽取ViT

2023-11-02

【GiantPandaCV导语】来自商汤和南洋理工的工作,也是使用卷积来增强模型提出low-level特征的能力,增强模型获取局部性的能力,核心贡献是LCA模块,可以用于捕获多层特征表示。

引言

针对先前Transformer架构需要大量额外数据或者额外的监督(Deit),才能获得与卷积神经网络结构相当的性能,为了克服这种缺陷,提出结合CNN来弥补Transformer的缺陷,提出了CeiT:

(1)设计Image-to-Tokens模块来从low-level特征中得到embedding。

(2)将Transformer中的Feed Forward模块替换为Locally-enhanced Feed-Forward(LeFF)模块,增加了相邻token之间的相关性。

(3)使用Layer-wise Class Token Attention(LCA)捕获多层的特征表示。

经过以上修改,可以发现模型效率方面以及泛化能力得到了提升,收敛性也有所改善,如下图所示:

方法

1. Image-to-Tokens

使用卷积+池化来取代原先ViT中7x7的大型patch。

x ′ = I 2   T ( x ) = MaxPool ⁡ ( BN ⁡ ( Conv ⁡ ( x ) ) ) \mathbf{x}^{\prime}=\mathrm{I} 2 \mathrm{~T}(\mathbf{x})=\operatorname{MaxPool}(\operatorname{BN}(\operatorname{Conv}(\mathbf{x}))) x=I2 T(x)=MaxPool(BN(Conv(x)))

2. LeFF

将tokens重新拼成feature map,然后使用深度可分离卷积添加局部性的处理,然后再使用一个Linear层映射至tokens。

x c h , x p h = Split ⁡ ( x t h ) x p l 1 = GELU ⁡ ( BN ⁡ ( Linear ⁡ ( ( x p h ) ) ) x p s = SpatialRestore ⁡ ( x p l 1 ) x p d = GELU ⁡ ( BN ⁡ ( DWConv ⁡ ( x p s ) ) ) x p f = Flatten ⁡ ( x p d ) x p l 2 = GELU ⁡ ( BN ⁡ ( Linear ⁡ 2 ( x p f ) ) ) x t h + 1 = Concat ⁡ ( x c h , x p l 2 ) \begin{aligned} \mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{h} &=\operatorname{Split}\left(\mathbf{x}_{t}^{h}\right) \\ \mathbf{x}_{p}^{l_{1}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear}\left(\left(\mathbf{x}_{p}^{h}\right)\right)\right)\right.\\ \mathbf{x}_{p}^{s} &=\operatorname{SpatialRestore}\left(\mathbf{x}_{p}^{l_{1}}\right) \\ \mathbf{x}_{p}^{d} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{DWConv}\left(\mathbf{x}_{p}^{s}\right)\right)\right) \\ \mathbf{x}_{p}^{f} &=\operatorname{Flatten}\left(\mathbf{x}_{p}^{d}\right) \\ \mathbf{x}_{p}^{l_{2}} &=\operatorname{GELU}\left(\operatorname{BN}\left(\operatorname{Linear} 2\left(\mathbf{x}_{p}^{f}\right)\right)\right) \\ \mathbf{x}_{t}^{h+1} &=\operatorname{Concat}\left(\mathbf{x}_{c}^{h}, \mathbf{x}_{p}^{l_{2}}\right) \end{aligned} xch,xphxpl1xpsxpdxpfxpl2xth+1=Split(xth)=GELU(BN(Linear((xph)))=SpatialRestore(xpl1)=GELU(BN(DWConv(xps)))=Flatten(xpd)=GELU(BN(Linear2(xpf)))=Concat(xch,xpl2)

3. LCA

前两个都比较常规,最后一个比较有特色,经过所有Transformer层以后使用的Layer-wise Class-token Attention,如下图所示:

LCA模块会将所有Transformer Block中得到的class token作为输入,然后再在其基础上使用一个MSA+FFN得到最终的logits输出。作者认为这样可以获取多尺度的表征。

实验

SOTA比较:

I2T消融实验:

LeFF消融实验:

LCA有效性比较:

收敛速度比较:

代码

模块1:I2T Image-to-Token

  # IoT
  self.conv = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
      nn.BatchNorm2d(out_channels),
      nn.MaxPool2d(pool_kernel, stride)    
  )
  
  feature_size = image_size // 4

  assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
  num_patches = (feature_size // patch_size) ** 2
  patch_dim = out_channels * patch_size ** 2
  self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
      nn.Linear(patch_dim, dim),
  )

模块2:LeFF

class LeFF(nn.Module):
    
    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()
        
        scale_dim = dim*scale
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(scale_dim),
                                    nn.GELU(),
                                    Rearrange('b c (h w) -> b c h w', h=14, w=14)
                                    )
        
        self.depth_conv =  nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
                          nn.BatchNorm2d(scale_dim),
                          nn.GELU(),
                          Rearrange('b c h w -> b (h w) c', h=14, w=14)
                          )
        
        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                    Rearrange('b n c -> b c n'),
                                    nn.BatchNorm1d(dim),
                                    nn.GELU(),
                                    Rearrange('b c n -> b n c')
                                    )
        
    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        x = self.down_proj(x)
        return x
        
class TransformerLeFF(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))
    def forward(self, x):
        c = list()
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            c.append(cls_tokens)
            x = leff(x[:, 1:])
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1) 
        return x, torch.stack(c).transpose(0, 1)

模块3:LCA

class LCAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class LCA(nn.Module):
    # I remove Residual connection from here, in paper author didn't explicitly mentioned to use Residual connection, 
    # so I removed it, althougth with Residual connection also this code will work.
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layers.append(nn.ModuleList([
                PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x[:, -1].unsqueeze(1)
            x = x[:, -1].unsqueeze(1) + ff(x)
        return x

参考

https://arxiv.org/abs/2103.11816

https://github.com/rishikksh20/CeiT-pytorch/blob/master/ceit.py

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CeiT:训练更快的多层特征抽取ViT 的相关文章

随机推荐

  • 2019大厂Android高级工程师面试题整理+进阶资料

    金三银四 很多同学心里大概都准备着年后找工作或者跳槽 最近有很多同学都在交流群里求大厂面试题 正好我电脑里面有这方面的整理 于是就发上来分享给大家 这些题目是网友去百度 小米 乐视 美团 58 猎豹 360 新浪 搜狐等一线互联网公司面试被
  • 计算机系统 实验四(课程实验LAB四)

    实验中需要的几个控制语句 u userid 使用这个语句是要确保不同的人使用不同的 ID 做题 并攻击不同的地址 h 用于打印这几个操作的内容 n 用于 Level4 关卡 s 用于提交你的解决方案到服务器中 1 根据makecookie生
  • mysql ERROR 1045 (28000): Access denied for user ‘ODBC‘@‘localhost‘ (using password: YES)

    遇到这个问题搞了很久 自己记下来 方法是百度的 亲测有效 ERROR 1045 28000 Access denied for user ODBC localhost using password NO ERROR 1045 28000 A
  • 按照lockattribute来划分MESH

    网格模型基础 网格模型 一 定义 二 子集和属性缓存 2 0 子集 2 1 属性 2 2 操作 三 邻接信息 四 属性表 五 优化 六 网格的创建与绘制 6 1 创建 6 2 绘制 一 定义 网格模型是一种将物体的顶点数据 纹理 材质等信息
  • Linux内核scripts/Makefile.build文件结构

    1 默认目标 build 2 初始化obj y obj m等变量 3 include include config auto conf 内含CONFIG RING BUFFER y等变量列表 4 include scripts Kbuild
  • vue+element动态设置el-menu导航,刷新页面保持当前菜单选中项及路由

    今天闲来无事整理了一套后台管理系统的侧边栏菜单 实现了页面刷新路由保持不变和菜单也是当前点击的高亮状态 来一起看看吧 首先 菜单数据是动态的 注意的是 id 和 路由的 name保持一致 页面刷新要用到 一级菜单不用name 因为没用到路由
  • Android开机自启动添加

    1 添加需要自启动的可以执行文件 1 可执行C文件 system core init start needInitStartService c 例如 include
  • 基于大数据的python爬虫的菜谱美食食物推荐系统

    众所周知 现阶段我们正处于一个 大数据 时代 从互联网上大量的数据中找到自己想要的信息变得越来困难 搜索引擎的商业化给市场带来了百度和谷歌这样的商业公司 网络爬虫便是搜索引擎的重要组成部分 本课题是基于Python设计的面向下厨房网站的网络
  • edge浏览器打开多个网页卡顿解决办法

    edge有时候打开了十几个页面就大量占据内存了 卡的不行 上网汇总了解决方法 具体参考以下两篇文章 一个是通过edge浏览器自身的设置修改 一个是关闭gpu相关的图形加速插件 按照以下两篇文章的方法基本就不会卡了 1 解决win10系统ed
  • Redis 与 Lua 脚本

    这篇文章 主要是讲 Redis 和 Lua 是如何协同工作的以及 Redis 如何管理 Lua 脚本 Lua 简介 Lua 以可嵌入 轻量 高效 提升静态语言的灵活性 有了 Lua 方便对程序进行改动或拓展 减少编译的次数 在游戏开发中特别
  • 16行 python代码获取音效素材

    人生苦短 我用python 声音素材资源 源码资料电子书 点击此处跳转文末名片获取 所需环境 开发环境 Python 环境 Pycharm 编辑器 模块 requests re 流程讲解 首先我们打开网址后右键选择检查 选择network
  • Visual Studio 自动补全代码

    自动补全两种方式 1 写完下面代码 双击Tab 自动补全 2 写完下面代码 回车 单击Tab 自动补全 可以在vs中自行查看 ctor 自动补全构造函数 prop 自动实现属性 cw Console WriteLine switch 自动补
  • 【边喝caffee边Caffe 】(三) Check failed: registry.count(t ype) == 1 (0 vs. 1) Unknown layer type

    自己建立一个工程 希望调用libcaffe lib 各种配置好 也能成功编译 但是运行就会遇到报错 F0519 14 54 12 494139 14504 layer factory hpp 77 Check failed registry
  • OCSVM 学习笔记

    OCSVM 学习笔记 前言 OCSVM OneClass SVM 算法是一种经典的异常检测算法 基本原理与 SVM 类似 与 SVM 关注的二分类问题不同的是 就像它的名字 OneClass SVM 那样 OCSVM 只有一个分类 这也正是
  • Excel每页都打印表头

    前言 有时候表格打印时 需要每页都打印表头 但是表格默认是只打印第一页的表头 那该如何设置呢 步骤 切换到 页面布局 打印标题 在 顶端标题行 中右侧可以选择你要打印的标题行 点击确定就欧克了 多行标题就选中多行就行 如 第1行到第3行 所
  • 上岸了,不写代码了

    上岸了 目前不搞这些东西了 不出意外的话应该不会再回来更新和回复了 各位 江湖再见
  • BugKu-Web-矛盾

    BugKu Web 矛盾 题目链接 https ctf bugku com challenges detail id 72 html 考点 PHP弱类型比较漏洞 题目源码分析 num GET num 定义一个num变量用get方法接收 if
  • VUE-鼠标移入到目标区域变成小手模样

    这是测试提的一个需求 当鼠标移入点击更多时 鼠标指针变成小手模样 其实这个东西特别简单 只是用的不多平常 我们只需要给目标区域的style样式中加入 cursor pointer 这个鼠标就好了
  • vue文件无法正常build

    如图所示 run serve后控制台没有报错 但是运行到此处直接结束 解决办法 暴力解决 直接删除node module 再输入cnpm install重新安装依赖 重新安装完成后成功运行
  • CeiT:训练更快的多层特征抽取ViT

    GiantPandaCV导语 来自商汤和南洋理工的工作 也是使用卷积来增强模型提出low level特征的能力 增强模型获取局部性的能力 核心贡献是LCA模块 可以用于捕获多层特征表示 引言 针对先前Transformer架构需要大量额外数