Swin-Unet:Swin Transformer在医学分割上的首次尝试

2023-10-31

Swin-Unet:Swin Transformer在医学分割上的首次尝试

前言

最近小编主要在搞一些医学图像分割的工作,也跑了一下Swin-Unet,之前看到也看到过这篇Swin-Unet(其实五月份就看到了hhhh),决定搬运过来。实际上从这篇论文可以看到目前医学分割或者检测引入transformer,更常见的做法还是直接嵌入到医学图像常用的网络结构中,比如Unet系列等,没有对transformer block做更多的创新,这主要是由于医学图像数据集太小导致对于transformer本身进行创新难以通过医学图像数据集进行实验验证。后续小编将持续更新医学图像分割相关的论文解读系列~~本篇文章应该是今年5月份左右挂到Arxiv上的,工作的创新主要基于今年3月微软的Swin Transformer工作(ps:Swin Transformer刚刚获得ICCV2021最佳论文奖,所以Swin-Unet真的是站在了巨人的肩膀上,hhhhhhh)

一、Related Works

作者首先对最近医学分割领域的相关工作进行了总结,主要有以下三种:

  1. CNN-based methods
    以Unet为backbone的一系列变种,比如Unet++,Unet3+,Att-Unet

  2. Transformer-based methods
    ViT/DeiT,DerT

  3. CNN+Transformer
    两者结合的理由也非常好理解,CNN注重local dependency提取,transformer注重global和long-range dependency提取。举个例子:TransUnet中证明了hybrid encoder优于CNN-based and transformer-based

二、Architecture Overview

整个网络框架如下图所示,可以发现整体就是一个Unet结构,只不过encoder和decoder部分换成了Swin Transformer block,细节部分我们将分为以下几点具体讲解:

  • Swin Transformer
  • Patch Merging && Patch Expanding
  • Comparison with Unet
    在这里插入图片描述

Swin Transformer

第一部分先介绍一下Swin Transformer,先上Swin transformer block的结构图看一波~
在这里插入图片描述
典型的transformer encoder的结构,主要关注点应该是W-MSA(window based MSA)和SW-MSA(shifted window based MSA),这两个组件也是Swin Transformer论文的创新点,下面通过以下计算复杂度的介绍来回顾下作者提出用这个东西来代替单纯的MSA的初衷:

一个MSA的计算复杂度为:
在这里插入图片描述

我们首先看一下如何得到当前的公式,对于一张图像,我们将其分为 h × w h\times w h×w个patch,同时设每个patch经过embedding之后的feature dimension为 C C C,其中 C C C一般为 d m o d e l d_{model} dmodel, d d d为heads的数量。

一起来回顾下MSA的计算:

第一步:计算Q,K,V矩阵,计算量为 3 h w C 2 / d 3hwC^2/d 3hwC2/d
Q = X W Q Q=XW^Q Q=XWQ K = X W K K=XW^K K=XWK V = X W V V=XW^V V=XWV
(以Q的计算为例,其中X维度为 h w × C hw\times C hw×C, W Q W^Q WQ维度为 C × C / d C\times C/d C×C/d , 以此类推)

第二步:计算 Q K T V QK^TV QKTV,计算量为 2 ( h w ) 2 C / d 2(hw)^2C/d 2(hw)2C/d
(其中 Q K T QK^T QKT的计算量为 ( h w ) 2 C / d (hw)^2C/d (hw)2C/d

第三步:一个head的计算量为 3 h w C 2 / d + 2 ( h w ) 2 C / d 3hwC^2/d+2(hw)^2C/d 3hwC2/d+2(hw)2C/d,那么d个head的计算量为 3 h w C 2 + 2 ( h w ) 2 C 3hwC^2+2(hw)^2C 3hwC2+2(hw)2C

第四步:最后将d个head进行融合,和矩阵 W o W^o Wo相乘,计算量为 h w C 2 hwC^2 hwC2,因此总计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C

可以看到MSA的计算复杂度是 O ( n p 2 ) O({n_p}^2) O(np2),其中 n p n_p np为patch的数量,很显然计算量太大,对于大尺度的图片很不友好,显存占用会比较夸张。那么我们来看下一个W-MSA的计算复杂度(顾名思义就是把图像分为几个window,一个window中假设有 M × M M\times M M×M个patch,那么我们只对这个window中所有patch计算attention,提取局部的依赖关系),W-MSA的计算复杂度为:
在这里插入图片描述
所以比较好理解,对于一个window来说,计算量只需要把 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C中的 h h h, w w w分别替换为 M M M,那么一个window的计算量为 4 M 2 C 2 + 2 M 4 C 4M^2C^2+2M^4C 4M2C2+2M4C,共有 h w M 2 \frac{hw}{M^2} M2hw个window,就可以得到上述式子。

如果仔细阅读的话可以发现,这样做实际上是有问题的,没错,各个window和window内的patch之间就没有interaction了,为了解决这个问题,我们非常容易想到shifted window,即让window移动一下不就可以了吗,如下图:
在这里插入图片描述
如上图中,可以发现在经过shift= 1 2 \frac{1}{2} 21window_size之后出现了window_size大小不一致的问题,如果只简单的添加padding,计算量还是增加了,(因为窗口数量由 2 × 2 2\times 2 2×2变成 3 × 3 3\times 3 3×3)。因此作者又进行了一个cycle shift操作,这样操作完之后继续按照之前的窗口大小进行划分,并对每一个窗口的patch进行self-attention进行计算。以现在的第一个window(index=5)为例,包含的这四个patch分别来自于shift之前的四个窗口,而每个patch又和之前窗口中的patch进行过交互,所以再次计算第一个window中的patch之间的attention就相当于完成了window间的交互。然而,这样计算还是会存在一些问题,(比如左上角窗口中所有的patch的index都是5,但是右下角的窗口中包含了来自index分别为1,3,7,9的四个patch, 而这些位置并不相邻,计算意义不大)因此作者引入了attention mask(以最后一个window为例):
在这里插入图片描述
添加类似于上图的mask之后就可以避免不同的index的patch之间进行计算。
代码如下(示例):

H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

简单理解来说,就是生成一张attention mask,每个位置的值为0或者-100,取决于index是否相同,再与我们计算出来的 Q K T QK^T QKT矩阵相加,这样index不同位置的attention值就降到了很低,那么softmax值就会比较低。
事实上,在window partition过程中,当整个feature map不能被完全整分为windows时,此时我们除了增加padding之外,也可以尝试这种方法生成attention mask,这样做更加准确而且不改变feature map的大小。

以上就是swin transformer具体的一个block的讲解。

Patch Merging

在Swin Transformer的论文中作者已经提出了Patch Merging的概念,我们先看下Swin-Transformer的整体结构:
在这里插入图片描述
从stage中出现Patch Merging,Patch Merging中进行的操作是:

首先进行 ,具体操作通过将每个patch周围的patch的feature进行concat,此时dimension为输入的4倍,通过一个linear projection将其转为输入的2倍。
代码如下(示例):

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

上述的Patch Merging和Patch Expanding做法类似pixel unshuffle/shuffle操作。

Relative Position Embedding

在论文中作者提到了在原本计算attention的基础上添加相对位置的编码,如下公式所示:
在这里插入图片描述
至于为什么引入相对位置编码代替绝对位置编码,可以分别从理论+实验上进行证明,简单来说,在计算过程中,相对位置信息会“消失”,这是ViT在提出时没有注意到的问题,作者也做了实验,证明本文提出的relative position 优于ViT中absolute position以及rel+abs组合的方法:
在这里插入图片描述
相对位置编码索引计算过程如下:
在这里插入图片描述
代码如下(示例):

self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

这里链接中的博主讲的比较清楚,每一步都有具体的讲解。(大概过程是生成了一张可以索引的relative position table)
https://zhuanlan.zhihu.com/p/384514268

Patch Expanding

与merging操作相反,首先经过linear projection将维度拓展2倍,接着进行rearrange operation(操作即为merging的逆过程),为了证明此种方法有效,作者和传统的上采样方法进行了比较:
在这里插入图片描述
代码如下(示例):

class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
        x = x.view(B,-1,C//4)
        x= self.norm(x)

        return x

Comparison with Unet

很显然,整篇论文的思路就是把UNet中CNN换成swin-transformer结构,左边为下采样+通道扩张,右边为上采样+通道压缩。

三、Experiments

在这里插入图片描述
在这里插入图片描述

四、Views on related works about transformer(七月份写的)

最近大概读了些最新的transformer的论文,主要是关于医学分割方向的,简单地在此谈谈下一步可创新的点:

1.引入MLP

目前三篇论文《Do you even need attention》,《external attention》,《remlp》都在暗示self-attention是否可以直接被MLP取代,《Do you even need attention》中发现将ViT中的self-attention替换为patch dimension的MLP效果也非常好,同时也做了将feature dimension的FFN替换为self-attention但是效果就很差,作者认为viT之所以表现不错的原因可能取决于patch embedding以及训练过程;《external attention》中认为计算self-attention时 的计算没有必要,因为一个位置的特征只与周围近距离的几个点的值有关,同时为了挖掘样本之间的关系,将self-attention长距离建模拓展到样本的层面上,提出了使用两个外部的记忆单元(代码中用了两个全连接层),事实上,对这两个全连接层能否很好地像self-attention一样对单个样本中patch之间进行建模不能确定,或者说是否和self-attention相结合效果会更好?

2.结合医学图像特点,做的更细致

虽然引入了transformer,但可以看出,也仅仅是使用了,还没有很好地结合医学图像做更多地细化工作,后续双transformer,多尺度的transformer融合分割等可以搞起来了…最近也在看些半监督的论文,加入半监督的transformer说不定也可以发起来。

plus:前些天也读了一篇《CAT:Cross attention in vision transformer》,和swin-transformer有点异曲同工的感觉,分成了IPSA和CPSA,主要也是考虑到了transformer忽略了单个patch中的结构相关性以及最重要的可能是模仿mobile net(作者也提到了)减少计算量,实验结果和swin-T相比不相上下。

3.CNN的那一套改进的可以在transformer上再过一遍了

刚看到可变形的transformer出来了…

后续小编将继续进行更新~~~~~~~~~~~~~~~~~~~~~~~~~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Swin-Unet:Swin Transformer在医学分割上的首次尝试 的相关文章

随机推荐

  • 服务器管理系统是什么

    服务器管理系统是什么 服务器管理系统 是在操作系统下对操作系统的服务器软件及其相关软件进行二次设置的管理软件 是运营商管理域名 服务器 企业邮局 数据库等服务器主机类产品的一个网站平台 以达到快捷实现域名 服务器主机 企业邮局 数据库等产品
  • C#连接数据库SQlServer+Form窗格实现简单动态增删查改操作

    C 连接数据库SQlServer Form窗格实现简单动态增删查改操作 一 数据库连接 数据库连接部分学习自该博主原创博客 点击即可跳转 再附上该博主的博客链接 https blog csdn net kiss soul article d
  • 深度学习去运动模糊----《DeblurGAN》

    前言 现实生活中 大多数图片是模糊不清的 试想一下 追剧时视频不清晰 看着都很捉急 何况现实中好端端的一幅美景 美女也可以 被抓拍得不忍直视 瞬间暴躁 拍照时手抖 或者画面中的物体运动都会让画面模糊 女友辛辛苦苦摆好的各种Pose也将淹没在
  • 海湾主机汉字注释表打字出_海湾报警主机(JB-QG-GST5000)操作手册

    报警主机正面示意图 报警主机内部结构图 控制器 模块总线 通讯总线 联动电源输出端子示意图 A1 B1 An Bn RS 485有极性通讯线端子 接火灾报警显示盘 GND 24V LD D02电源盘对外输出端子 保护地 此端子与机壳相连 接
  • NLP技术中的Tokenization

    作者 Gam Waiciu 单位 QTrade AI研发中心 研究方向 自然语言处理 前言 今天我们来聊一聊 NLP 技术中的 Tokenization 之所以想要聊这个话题 是因为 一方面在 NLP 技术中 Tokenization 是非
  • 网络知识:光纤收发器TX、RX介绍以及两者的区别

    当我们远距离传输时 通常会使用光纤来传输 因为光纤的传输距离很远 一般来说单模光纤的传输距离在10千米以上 而多模光纤的传输距离最高也能达到2千米 而在光纤网络中 我们常常会使用到光纤收发器 那么光纤收发器怎么连 我们一起来了解下 一 光纤
  • 自媒体月入过万的运营攻略,轻松上手

    很多自媒体新手羡慕大V月入过万 同是做自媒体运营 为什么自己不能实现营收过万呢 给大家分享一套月入过万的运营攻略 适合新手们去操作 收藏起来直接套用到运营哦 1 账号定位 清晰的定位是影响后期变现的关键因素 选一个后期容易变现的领域能帮自己
  • ajax net::err_connection_refused,javascript - How to handle net::ERR_CONNECTION_REFUSED in jquery aj...

    It appears that when jqXHR readyState i e the readyState field of the first parameter to the ajax fail method is 0 that
  • 调用织梦搜索功能

    织梦默认的搜索框
  • 使用C对TOML文件的解析

    使用C对TOML文件的解析 toml书写语法 解析toml文件 测试输出内容如下 TOML是前GitHub CEO Tom Preston Werner 于2013年创建的语言 其目标是成为一个小规模的易于使用的语义化配置文件格式 TOML
  • HJT212协议

    HJ T212是由国家环保行业制定的数据传输标准协议 目前广泛使用的是HJ T212 2005通信协议 该协议在2005年制定 并于2006年2月1日正式实施 HJ T212标准不规定数据采集传输仪与监控仪器仪表的通讯方式 可以采用RS23
  • Mali GPU OpenGL ES 应用性能优化--基本概念

    1 基本概念 1 1 Mali GPU家族 Mali GPU家族都包含以下通用的硬件 基于分块的延迟渲染 Mali GPU把framebuffer分成许多块 16 x 16像素 然后一块一块地进行渲染 基于分块的渲染是有效的 因为像素值使用
  • matlab中if elseif语句,Matlab if…elseif…elseif…else…end语句

    if语句后面可以有一个 或多个 可选elseif 和一个else语句 这对于测试各种条件非常有用 当使用if elseif else语句时 请记住几点 if可以有零个或一个else 它必须在elseif之后 if可以有零到多个elseif
  • (python编程)k-shell的实现

    一 k shell 算法 改错 他发的代码报错 def kshell graph importance dict ks 1 while graph nodes temp node degrees dict for i in graph de
  • python之标准库使用

    目录 一 标准库 二 字符串操作 三 字符串类型 四 时间操作 五 文件基本方法及操作 文件基本方法 文件操作 一 标准库 Python 标准库非常庞大 所提供的组件涉及范围十分广泛 正如以下内容目录所显示的 这个库包含了 Python中的
  • Activiti7 监听器【十四】

    Activiti 7系列文章目录 文章代码下载 Activiti7 工作流设计器 一 Activiti7 创建表 二 Activiti7 表结构介绍 三 Activiti7 设计器创建流程 四 Activiti7 部署流程 五 Activi
  • maven打包出错:Failed to execute goal org.sp

    Failed to execute goal org springframework boot spring boot maven plugin 2 2 13 RELEASE repackage default on project bla
  • 数学建模——数据分析方法

    一 常见数据分析软件 Excel office三件套之一 R语言 Eviews origin 图形分析工具 SPSS 统计分析与数据挖掘 MATLAB 墙裂推荐 python 墙裂推荐 SAS 二 统计性描述 均值 mean x
  • 第一次参加蓝桥杯的心得

    随着我的4道题的答案提交后 蓝桥杯第十届比赛落下帷幕 这其中我也是参赛者 对于这次比赛 虽然我是一位小白 但是我也有不少的感悟 因为这一次也是我从小到大参加的第一次大型竞赛 所以我做了以下的总结 这次的比赛是在长沙理工大学比赛 所以对于我来
  • Swin-Unet:Swin Transformer在医学分割上的首次尝试

    Swin Unet Swin Transformer在医学分割上的首次尝试 前言 最近小编主要在搞一些医学图像分割的工作 也跑了一下Swin Unet 之前看到也看到过这篇Swin Unet 其实五月份就看到了hhhh 决定搬运过来 实际上