[Deep Learning] A Detailed Explanation of Swin Transformer (SwinT)

2023-11-18


Contents

Abstract

1. Introduction

2. Method

2.1 Overall Architecture

2.1.1 Architecture

2.1.2 Swin Transformer Block

2.2 Shifted-Window-Based Self-Attention

2.2.1 Self-Attention in Non-Overlapping Local Windows

2.2.2 Shifted Window Partitioning in Successive Blocks

2.2.3 Efficient Batch Computation for the Shifted Configuration

2.2.4 Relative Position Bias

2.3 Architecture Variants

3. Source Code

3.1 Swin Transformer

3.2 Patch Embedding

3.3 Patch Merging

3.4 Window Partition

3.5 Window Reverse

3.6 MLP

3.7 Window Attention (W-MSA Module) ☆

3.8 Swin Transformer Block ☆

3.8.1 Shift Window Attention

3.8.2 Attention Mask

3.9 Basic Layer



Abstract

        This paper presents a new vision Transformer, called Swin Transformer, that can serve as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of image pixels compared with words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with Shifted Windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while still allowing cross-window connections. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val). Its performance surpasses the previous state of the art by a large margin of +2.7 box AP and +2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones. The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.


1. Introduction

        Modeling in computer vision has long been dominated by CNNs. Beginning with AlexNet and its revolutionary performance on the image classification challenge, CNN architectures have grown increasingly powerful through larger scale, wider connections, and more sophisticated forms of convolution. With CNNs serving as backbone networks for a variety of vision tasks, these architectural advances have driven performance improvements and broadly lifted the entire field. In NLP, on the other hand, network architectures have taken a different path: the prevalent architecture today is the Transformer. Designed for sequence modeling and transduction tasks, the Transformer is notable for its use of attention to model long-range dependencies in the data. Its tremendous success in the language domain has led researchers to investigate its adaptation to computer vision, where it has recently demonstrated promising results on certain tasks, specifically image classification and joint vision-language modeling. This paper seeks to expand the applicability of the Transformer such that it can serve as a general-purpose backbone for computer vision, as it does in NLP and as CNNs do in vision. We observe that the significant challenges in transferring its high performance in the language domain to the visual domain can be explained by differences between the two modalities.

        One of these differences involves scale. Unlike word tokens, which serve as the basic elements of processing in language Transformers, visual elements can vary substantially in scale, a problem that receives attention in tasks such as object detection. In existing Transformer-based models, tokens are all of a fixed scale, a property unsuited to these vision applications. Another difference is the much higher resolution of pixels in images compared with words in passages of text. Many vision tasks, such as semantic segmentation, require dense prediction at the pixel level, and this would be intractable for a Transformer on high-resolution images, since the computational complexity of its self-attention is quadratic in image size.

Figure 1

        To overcome these issues, Swin Transformer constructs hierarchical feature maps and has computational complexity linear in image size. As illustrated in Figure 1(a), Swin Transformer builds a hierarchical representation by starting from small-sized patches (outlined in gray) and gradually merging neighboring patches in deeper Transformer layers. With these hierarchical feature maps, the Swin Transformer model can conveniently leverage advanced techniques for dense prediction such as feature pyramid networks (FPN) or U-Net. The linear computational complexity is achieved by computing self-attention locally within non-overlapping windows that partition the image (outlined in red), rather than over all patches of the whole image. The number of patches in each window is fixed, so the complexity is linear in image size. These merits make Swin Transformer suitable as a general-purpose backbone for various vision tasks, in contrast to previous Transformer-based architectures, which produce feature maps of a single resolution and have quadratic complexity.

Figure 2

        A key design element of Swin Transformer is its shift of the window partition between consecutive self-attention layers, as illustrated in Figure 2. The shifted windows bridge the windows of the preceding layer, providing connections between them that significantly enhance modeling power (see Table 4). This strategy is also efficient with regard to real-world latency: all query patches within a local window share the same key set, which facilitates memory access in hardware. In contrast, earlier sliding-window-based self-attention approaches suffer in latency on general hardware because different query pixels have different key sets. Our experiments show that the proposed shifted-window approach has much lower latency than the sliding-window method, yet is similar in modeling power (see Tables 5 and 6). The shifted-window approach also proves beneficial for all-MLP architectures.

        The proposed Swin Transformer achieves strong performance on the recognition tasks of image classification, object detection, and semantic segmentation. It significantly outperforms ViT/DeiT and ResNe(X)t models on the three tasks at similar latency. We believe that a unified architecture across CV and NLP could benefit both fields, since it would facilitate joint modeling of visual and textual signals and allow modeling knowledge from both domains to be shared more deeply. We hope that Swin Transformer's strong performance on various vision problems can drive this belief deeper in the community and encourage unified modeling of vision and language signals.


2. Method

2.1 Overall Architecture

2.1.1 Architecture 

Figure 3

        Figure 3 presents an overview of the Swin Transformer architecture (the tiny version, Swin-T). It first splits an input $H \times W \times 3$ RGB image into non-overlapping, equally sized patches with a patch splitting module (Patch Partition), as in ViT. Each $P^2 \times 3$ patch is treated as a patch token, yielding $N$ tokens in total (i.e., the effective input sequence length of the Transformer), so the input becomes an $N \times (P^2 \times 3)$ sequence.

        More concretely, with a patch size of $4 \times 4$ (i.e., $P = 4$) and $C = 3$ input channels, each flattened patch has feature dimension $P \times P \times C = 4 \times 4 \times 3 = 48$, and there are $N = \frac{H}{4} \times \frac{W}{4} = \frac{HW}{16}$ patch tokens. In other words, each $H \times W \times 3$ image is processed into $\frac{H}{4} \times \frac{W}{4}$ image patches, each flattened into a 48-dimensional token vector (akin to ViT's flattened patches); overall, a flattened 2D patch sequence of size $N \times (P^2 \times 3) = (\frac{H}{4} \times \frac{W}{4}) \times 48$.

        A linear embedding layer (Linear Embedding, i.e., a fully connected layer) then projects this $(\frac{H}{4} \times \frac{W}{4}) \times 48$ tensor to an arbitrary dimension $C$, producing a linear embedding of size $(\frac{H}{4} \times \frac{W}{4}) \times C$.

        These patch tokens (now linear embeddings) are then fed into several Swin Transformer blocks with modified self-attention. The first Swin Transformer block keeps the number of tokens fixed at $\frac{H}{4} \times \frac{W}{4}$, and together with the linear embedding layer it is designated as Stage 1 (the first dashed box in Figure 3).

        To produce a hierarchical representation, the number of tokens is progressively reduced by patch merging layers (Patch Merging) as the network gets deeper. The first patch merging layer concatenates the features of each group of $2 \times 2$ neighboring patches, so the number of patch tokens drops to $\frac{1}{4}$ of the original, i.e., $\frac{H}{8} \times \frac{W}{8}$, while the token dimension grows $4\times$ to $4C$. A linear layer is then applied to the $4C$-dimensional concatenated features, reducing the output dimension to $2C$. Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at $\frac{H}{8} \times \frac{W}{8}$. This first patch merging layer and these feature-transformation blocks are designated as Stage 2 (the second dashed box in Figure 3). The procedure is repeated twice more, as Stage 3 and Stage 4 (the third and fourth dashed boxes in Figure 3), with output resolutions (patch token grids) of $\frac{H}{16} \times \frac{W}{16}$ and $\frac{H}{32} \times \frac{W}{32}$, respectively. Every Stage changes the tensor's dimensions, forming a hierarchical representation. As a result, the architecture can conveniently replace the backbone networks in existing methods for various vision tasks.
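        To make the hierarchy concrete, here is a quick sanity check (a sketch assuming a 224 × 224 input and embed dim C = 96, the Swin-T setting) that walks the token grid and channel dimension through the four Stages:

# A quick sanity check of the per-stage token grids and channel dims.
# Assumed, illustrative values: 224 x 224 input, C = 96 (Swin-T).
H = W = 224
C = 96
h, w = H // 4, W // 4  # after Patch Partition + Linear Embedding: 56 x 56 tokens of dim C
for stage in range(1, 5):
    print(f"Stage {stage}: {h} x {w} tokens, dim {C}")
    if stage < 4:
        h, w, C = h // 2, w // 2, 2 * C  # each Patch Merging halves the grid and doubles the dim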

2.1.2 Swin Transformer Block

W-MSA: regular-window MSA    -    SW-MSA: shifted-window MSA

        Compared with the standard Transformer block (e.g., in ViT), the Swin Transformer block replaces the standard multi-head self-attention module (MSA) with a shifted-window-based multi-head self-attention module (W-MSA / SW-MSA) and keeps the other parts unchanged (described in Section 3.2 of the paper). As shown in Figure 3(b) and the figure above, a Swin Transformer block consists of a shifted-window-based MSA module, followed by a 2-layer MLP with a GELU non-linearity in between. A LayerNorm (LN) layer is applied before each MSA module and each MLP, and a residual connection is applied after each module.
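        The data flow of one block can be sketched as follows (a pseudocode-level sketch with placeholder callables; the actual SwinTransformerBlock appears in Section 3.8):

# Sketch of one Swin Transformer block; norm1 / attn / norm2 / mlp / drop_path are placeholders.
def swin_block(x, norm1, attn, norm2, mlp, drop_path):
    x = x + drop_path(attn(norm1(x)))  # pre-LN -> (S)W-MSA -> residual connection
    x = x + drop_path(mlp(norm2(x)))   # pre-LN -> 2-layer MLP (GELU inside) -> residual
    return x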


2.2 Shifted-Window-Based Self-Attention

        The standard Transformer architecture and its adaptations for image classification all perform global self-attention, computing the relationship between each token and all other tokens (the attention map). Global self-attention has quadratic complexity with respect to the number of tokens, $O(N^2 D)$ ($N$ is the token count / sequence length, $D$ the token vector length / embedding dimension), making it unsuitable for many compute-heavy vision problems that require a large number of tokens, such as dense prediction or high-resolution image representation.

         On computing $O(\text{MSA})$ (i.e., $O(\text{MHA})$): when $N \gg D$, $O(\text{MHA}) = O(N^2 D)$, which is the $O(\text{MSA}) = O(N^2 D)$ stated above.

2.2.1 Self-Attention in Non-Overlapping Local Windows

        For efficient modeling, we propose to compute self-attention within non-overlapping local windows in place of global self-attention. The windows are arranged to evenly partition the image in a non-overlapping manner. Suppose each local window contains $M \times M$ patch tokens and the feature map contains $h \times w$ patch tokens in total with embedding dimension $C$. The computational complexities of a global MSA module and of a window-based W-MSA module are then, respectively:

$$\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C$$

$$\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC$$
        Here, MSA has quadratic complexity in the patch token count $h \times w$ (there are $hw$ patch tokens, each attending globally to all $hw$ tokens), while W-MSA is linear when $M$ is fixed (set to 7 by default): there are $hw$ patch tokens, each attending to only $M^2$ tokens within its own local window. Global self-attention computation is generally unaffordable for a large $hw$, whereas window-based self-attention (W-MSA) scales well.
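        To make the gap concrete, here is a quick evaluation of the two formulas above for Stage-1-like numbers (h = w = 56, C = 96, M = 7; illustrative values):

# Evaluate the two complexity formulas for h = w = 56, C = 96, M = 7.
h = w = 56
C = 96
M = 7
msa = 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C      # ~2.0e9: dominated by the (hw)^2 term
w_msa = 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C  # ~1.5e8: the quadratic term shrinks to M^2 * hw
print(f"MSA / W-MSA ratio: {msa / w_msa:.1f}x")      # ~13.8x here, and growing with hw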

2.2.2 Shifted Window Partitioning in Successive Blocks

        Although the window-based self-attention module (W-MSA) reduces the computational complexity from quadratic to linear, the lack of communication across windows limits its modeling power. To introduce cross-window connections while maintaining the computational efficiency of non-overlapping windows, we propose a shifted window partitioning approach that alternates between two partitioning configurations in consecutive Swin Transformer blocks.

Figure 2

        As illustrated in Figure 2, the first module uses a regular window partitioning strategy: starting from the top-left pixel, the $8 \times 8$ feature map is evenly partitioned into $2 \times 2$ windows of size $4 \times 4$ (the local window size is $M = 4$, as marked by the red boxes). The next module then adopts a windowing configuration shifted from that of the preceding layer, cyclically shifting the regularly partitioned windows toward the top-left by $(\left\lfloor \frac{M}{2} \right\rfloor, \left\lfloor \frac{M}{2} \right\rfloor)$ pixels, as the change of the red box positions in the figure shows.
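        In code, the cyclic shift and its inverse are plain torch.roll calls, which is how the official SwinTransformerBlock implements them; a runnable sketch at the figure's toy size:

import torch

B, H, W, C = 1, 8, 8, 96  # toy setting matching the figure (8 x 8 map, M = 4)
window_size = 4
shift_size = window_size // 2

x = torch.randn(B, H, W, C)
# cyclic shift toward the top-left by (floor(M/2), floor(M/2))
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# ... windowed attention runs on shifted_x here ...
# the reverse cyclic shift restores the original layout
restored = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(x, restored)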

        With the shifted window partitioning approach, the computation of two consecutive Swin Transformer blocks (as in the figure above) can be expressed as:

$$\hat{\mathbf{z}}^{l} = \text{W-MSA}\left(\text{LN}\left(\mathbf{z}^{l-1}\right)\right) + \mathbf{z}^{l-1}$$

$$\mathbf{z}^{l} = \text{MLP}\left(\text{LN}\left(\hat{\mathbf{z}}^{l}\right)\right) + \hat{\mathbf{z}}^{l}$$

$$\hat{\mathbf{z}}^{l+1} = \text{SW-MSA}\left(\text{LN}\left(\mathbf{z}^{l}\right)\right) + \mathbf{z}^{l}$$

$$\mathbf{z}^{l+1} = \text{MLP}\left(\text{LN}\left(\hat{\mathbf{z}}^{l+1}\right)\right) + \hat{\mathbf{z}}^{l+1}$$

        where $\hat{\mathbf{z}}^l$ and $\mathbf{z}^l$ denote the output features of the (S)W-MSA module and the MLP module of block $l$, respectively (see Figure 3(b)).

        The shifted window partitioning approach introduces connections between neighboring non-overlapping windows of the preceding layer, and is found effective for image classification, object detection, and semantic segmentation, as shown in Table 4.

2.2.3 Efficient Batch Computation for the Shifted Configuration

         An issue with shifted window partitioning is that it produces more windows, from $\left\lceil \frac{h}{M} \right\rceil \times \left\lceil \frac{w}{M} \right\rceil$ to $(\left\lceil \frac{h}{M} \right\rceil + 1) \times (\left\lceil \frac{w}{M} \right\rceil + 1)$, and that some of them are smaller than $M \times M$. A naive solution is to pad the smaller windows to $M \times M$ and mask out the padded values when computing attention. When the number of windows in the regular partition is small, e.g., $2 \times 2$, the extra computation from this naive approach is considerable ($2 \times 2 \rightarrow 3 \times 3$ is 2.25× larger).

        Here we propose a more efficient batch computation approach that cyclically shifts toward the top-left, as illustrated in Figure 4. After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window. With the cyclic shift, the number of batched windows remains the same as that of the regular partition (e.g., 4 windows under the regular partition stay 4 windows after the cyclic shift toward the top-left, as pieces A, B, C, D in the figure above show). The approach is therefore efficient; its low latency is shown in Table 5.

        After the cyclic shift, a single window may contain content from several different original windows. Hence a masked MSA mechanism is employed to limit self-attention computation to within each sub-window. Afterwards, a reverse cyclic shift maps each window's self-attention result back. For example, an illustration with 9 sub-windows is shown below:

        Partitioning by sub-window directly yields the self-attention result for sub-window 5, but naive computation would mix the self-attention of sub-windows 5 / 6 / 4, and similarly that of sub-windows 5 / 8 / 2, and of sub-windows 9 / 7 / 3 / 1, along the vertical or horizontal directions. The masked MSA mechanism therefore first computes self-attention as usual, then applies a mask that zeroes out the unwanted attention entries, restricting self-attention to within each sub-window.

         For example, sub-windows 6 / 4 together consist of 4 patches forming a square region, as shown below, so a 4×4 attention map should be computed.

        To prevent the attention computations of different sub-windows from mixing, the appropriate attention map should look as follows:

         Accordingly, the appropriate mask is:

         As another example, sub-windows 9 / 7 / 3 / 1 together consist of 4 patches forming a square region, as shown below:

         Likewise, the appropriate mask is:
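        Putting the two examples together: the official implementation builds these masks by labeling every position of the shifted feature map with its sub-window id, window-partitioning the label map, and assigning -100 (effectively -inf before Softmax) to every pair of positions whose labels differ. A condensed sketch, reusing the window_partition function from Section 3.4:

import torch

H = W = 8
window_size = 4
shift_size = window_size // 2

# label each position with its sub-window id (the 9-region layout discussed above)
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)           # (nW, M, M, 1)
mask_windows = mask_windows.view(-1, window_size * window_size)  # (nW, M*M)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# different sub-window -> -100 (suppressed by Softmax); same sub-window -> 0
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)  # torch.Size([4, 16, 16])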

2.2.4 Relative Position Bias

        When computing self-attention, we include a relative position bias $B \in \mathbb{R}^{M^2 \times M^2}$ for each head in computing similarity:

$$\text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V$$

        where $Q, K, V \in \mathbb{R}^{M^2 \times d}$ are the Query, Key, and Value matrices, $d$ is the Query/Key dimension, and $M^2$ is the number of patches in a (local) window. Since the relative position along each axis lies in the range $[-M+1, M-1]$, we parameterize a smaller-sized bias matrix $\hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)}$, and the values in $B$ are taken from $\hat{B}$.

        As the experiments in Table 4 show, using this relative position bias performs significantly better than using no position bias or using absolute position embedding. Further adding absolute position embedding to the input slightly degrades performance, so it is not adopted in our implementation.

        Moreover, the relative position bias learned in pre-training can also be used to initialize a fine-tuned model with a different window size, via bicubic interpolation.
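        A sketch of that interpolation under assumed sizes (pretrained M = 7 resized for fine-tuning at M' = 12 with 3 heads; the checkpoint-loading code in the repo follows the same reshape-interpolate-flatten recipe with extra bookkeeping):

import torch
import torch.nn.functional as F

M, M_new, num_heads = 7, 12, 3                    # assumed, illustrative sizes
table = torch.randn((2 * M - 1) ** 2, num_heads)  # pretrained bias table, ((2M-1)^2, nH)

# view as a (2M-1) x (2M-1) grid per head, bicubic-resize, flatten back
grid = table.permute(1, 0).reshape(1, num_heads, 2 * M - 1, 2 * M - 1)
grid = F.interpolate(grid, size=(2 * M_new - 1, 2 * M_new - 1), mode='bicubic')
table_new = grid.reshape(num_heads, -1).permute(1, 0)  # ((2M'-1)^2, nH)
print(table_new.shape)  # torch.Size([529, 3])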


2.3 Architecture Variants

        Our base model, Swin-B, has a model size and computational complexity similar to ViT-B/DeiT-B. We also introduce Swin-T, Swin-S, and Swin-L, versions of about 0.25×, 0.5×, and 2× the model size and computational complexity of Swin-B, respectively. Note that the complexity of Swin-T and Swin-S is similar to that of ResNet-50 (DeiT-S) and ResNet-101, respectively. The window size is set to $M = 7$ by default for every architecture. For all experiments, the Query dimension of each head is $d = 32$, and the expansion ratio of each MLP is $\alpha = 4$. The layer numbers of each Stage in these architectures are:

- Swin-T: $C = 96$, layer numbers $\{2, 2, 6, 2\}$
- Swin-S: $C = 96$, layer numbers $\{2, 2, 18, 2\}$
- Swin-B: $C = 128$, layer numbers $\{2, 2, 18, 2\}$
- Swin-L: $C = 192$, layer numbers $\{2, 2, 18, 2\}$

        where $C$ is the channel number of the hidden layers in Stage 1. The model size, theoretical computational complexity (FLOPs), and throughput of the model variants for ImageNet image classification are listed in Table 1.

Table 1: ImageNet image classification performance comparison
Table 7: Detailed architecture specifications

3. Source Code

        Codehttps://github.com/microsoft/Swin-Transformer

3.1 Swin Transformer

        To set the stage, here is the overall Swin Transformer architecture.

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
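        A quick smoke test of the class above (a sketch assuming the modules from Sections 3.2-3.9 are defined, plus trunc_normal_ / to_2tuple from timm.models.layers):

import torch

model = SwinTransformer(img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])  # Swin-T config
x = torch.randn(1, 3, 224, 224)
logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])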

3.2 Patch Embedding

        Before feeding an image into the Swin Transformer blocks, it must be divided into patch tokens and projected into embedding vectors. Concretely, the input image is split into patches of size patch_size * patch_size, which are then projected and embedded. The projection can be implemented with a 2D convolution whose stride and kernel_size both equal patch_size and whose output channel count equals embed_dim. Finally, the result is flattened and its dimensions permuted.

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # patch projection (embedding)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        '''
            # with default parameters       # input  (B, C, H, W) = (B, 3, 224, 224)
            x = self.proj(x)              # output (B, 96, 224/4, 224/4) = (B, 96, 56, 56)
            x = torch.flatten(x, 2)       # flatten the H, W dims: (B, 96, 56*56)
            x = torch.transpose(x, 1, 2)  # move C last: (B, 56*56, 96)
        '''
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # shape = (B, P_h*P_w, C)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

3.3 Patch Merging

        Before each Stage, Patch Merging downsamples the resolution and adjusts the channel count (the concatenation yields 4C channels, which a linear layer reduces to 2C), forming the hierarchical design while lowering the computational cost (similar in spirit to an inverse PixelShuffle). Schematic and implementation:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        # reshape
        x = x.view(B, H, W, C)

        # sample rows and columns at stride 2, giving a 1/2 downsampling of the resolution
        x0 = x[:, 0::2, 0::2, :]              # shape = (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]              # shape = (B, H/2, W/2, C)
        x2 = x[:, 0::2, 1::2, :]              # shape = (B, H/2, W/2, C)
        x3 = x[:, 1::2, 1::2, :]              # shape = (B, H/2, W/2, C)

        # concatenate the 4 subsampled maps, so channels become 4*C
        x = torch.cat([x0, x1, x2, x3], -1)   # shape = (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)              # shape = (B, H*W/4, 4*C)

        # FC reduces channels from 4*C to 2*C
        x = self.norm(x)
        x = self.reduction(x)                 # shape = (B, H*W/4, 2*C)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
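        A quick shape check at Stage-1 size (56 × 56 tokens, C = 96): the token count is quartered and the channels are doubled.

import torch

pm = PatchMerging(input_resolution=(56, 56), dim=96)
x = torch.randn(1, 56 * 56, 96)
print(pm(x).shape)  # torch.Size([1, 784, 192]): 28*28 tokens, 2*C channels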

3.4 Window Partition

        Reshapes an input tensor of shape $(B, H, W, C)$ into a window tensor of shape $(B \times \frac{H}{M} \times \frac{W}{M}, M, M, C)$, where $M$ is the window size. This yields $N = B \times \frac{H}{M} \times \frac{W}{M}$ windows of shape $(M, M, C)$. The function is used by Window Attention.

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """

    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

3.5 Window Reverse

        The inverse of window partitioning: reshapes a window tensor of shape $(B \times \frac{H}{M} \times \frac{W}{M}, M, M, C)$ back into a tensor of shape $(B, H, W, C)$. The function is used by Window Attention.

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """

    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
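        window_partition and window_reverse are exact inverses, which is easy to verify:

import torch

x = torch.randn(2, 8, 8, 96)                 # (B, H, W, C)
windows = window_partition(x, 4)             # (2 * 2*2, 4, 4, 96) = (8, 4, 4, 96)
restored = window_reverse(windows, 4, 8, 8)  # back to (2, 8, 8, 96)
assert torch.equal(x, restored)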

3.6 MLP

        Two fully connected layers with a GELU activation and Dropout.

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

3.7 Window Attention (W-MSA Module) ☆

        On the one hand, computing self-attention within local windows, rather than over the global image, reduces the computational complexity from quadratic to linear.

        On the other hand, adding the relative position bias $B$ when computing attention from Query and Key improves performance.

        More concretely, Query and Key are first multiplied to obtain the attention map, of shape (numWindows*B, num_heads, window_size*window_size, window_size*window_size). For this attention map, taking different pixels as the origin yields different positions/coordinates for each pixel.

        Since each non-overlapping local window contains $N = M \times M$ patch tokens, take window_size = M = 2 as an example; the relative position encodings with the top-left pixel and the top-right pixel as origins are shown below (the axes follow matrix coordinates):

         Next, use torch.arange to generate evenly spaced row and column indices, and torch.meshgrid to build the grid coordinate indices.

        Still with window_size = M = 2, generate the grid coordinates:

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
  (tensor([[0, 0],
           [1, 1]]), 
   tensor([[0, 1],
           [0, 1]]))
"""

         Stack them and flatten into 2D coordinate vectors:

coords = torch.stack(coords)  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""

         Insert new dimensions at positions 1 and 2 respectively, and subtract via broadcasting to obtain a tensor of shape (2, wh*ww, wh*ww):

relative_coords_first = coords_flatten[:, :, None]  # 2, wh*ww, 1
relative_coords_second = coords_flatten[:, None, :]  # 2, 1, wh*ww
relative_coords = relative_coords_first - relative_coords_second  # 2, wh*ww, wh*ww

        Since the indices obtained by subtraction start from negative values, add an offset so that they start from 0:

relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1

        Next, the coordinates need to be flattened into 1D offsets.

        For two different coordinates $(x, y)$ in row 0 such as (1, 2) and (2, 1), summing the coordinates into a 1D offset $x+y$ yields equal offsets relative to the origin for both (1+2 = 2+1 = 3):

The raw offsets x+y of row 0 are 2, 3, 3, 4; different positions share the same offset, reducing the relative distinguishability.

         To avoid such erroneous collisions of equal offsets, the coordinates (more precisely, the x coordinates) also undergo a multiplicative transform (offset multiply) to increase distinguishability:

relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1)  # multiply each x coordinate by (2 * 2 - 1) = 3
Applying the multiplicative transform to the x coordinates yields (x', y); recomputing the offsets x'+y now gives distinct offsets for the different coordinate positions.

         Then sum over the last dimension (x+y), flattening into a 1D coordinate (the relative position index), and register it as a non-learnable buffer relative_position_index, whose role is to look up the corresponding learnable relative position bias from the final relative position index.

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

         The complete code:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int):                         Number of input channels.
        window_size (tuple[int]):          The height and width of the window.
        num_heads (int):                   Number of attention heads.
        qkv_bias (bool, optional):         If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional):       Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional):       Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # coords along the window height
        coords_w = torch.arange(self.window_size[1])  # coords along the window width
        # coordinate grid of the local window
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # relative positions
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Query, Key, Value
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # scale Query
        q = q * self.scale  
        # Query * Key
        attn = (q @ k.transpose(-2, -1))  # @ is batched matrix multiplication

        # relative position bias B
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        # Attention Map = Softmax(Q * K / √d + B) 
        attn = attn + relative_position_bias.unsqueeze(0)
        # mask the local-window attention map, then apply Softmax
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)  # final attention map
        else:
            attn = self.softmax(attn)  # final attention map

        attn = self.attn_drop(attn)
        # Attention Map * V
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # @ is batched matrix multiplication

        # linear projection FC
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def extra_repr(self) -> str:
        # used for the module's printed repr
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        ### calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1)) 
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)  
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
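        A usage sketch (assuming trunc_normal_ is imported as in the repo): a 56 × 56 map with M = 7 yields 8 × 8 = 64 windows of 49 tokens each, so with batch size B = 2 the module sees 128 windows.

import torch

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)  # head_dim = 32
x = torch.randn(64 * 2, 49, 96)  # (num_windows*B, M*M, C)
out = attn(x)                    # mask=None: plain (non-shifted) W-MSA
print(out.shape)                 # torch.Size([128, 49, 96])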

        A nice diagram of the overall flow:

Relative position index matrix: each column represents one coordinate's relative position as "seen" by every coordinate

        The same code, annotated with the main shape changes:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (wh, ww); e.g., wh = ww = 4 in the demo below
        self.num_heads = num_heads  # number of MHA heads
        head_dim = dim // num_heads  # dim is split evenly across the heads
        self.scale = qk_scale or head_dim ** -0.5  # attention scale: custom qk_scale or 1/sqrt(head_dim)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*wh-1 * 2*ww-1, num_heads)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])  # wh
        coords_w = torch.arange(self.window_size[1])  # ww
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, wh, ww)
        coords_flatten = torch.flatten(coords, 1)  # (2, wh*ww)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, wh*ww, wh*ww)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (wh*ww, wh*ww, 2)
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # (wh*ww, wh*ww)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # in this demo N = wh*ww = 4*4 = 16
        # num_windows = (H*W) // (wh*ww)
        # C is the embedding dim at this stage (not the 3 input channels)

        # (num_windows*B, N, C) = (num_windows*B, wh*ww, C) 
        B_, N, C = x.shape  

        # (num_windows*B, N, C, num_heads, C//num_heads) -> (C, num_windows*B, num_heads, wh*ww, C//num_heads)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 

        # (num_windows*B, num_heads, wh*ww, C//num_heads)  
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)   

        # (num_windows*B, num_heads, wh*ww, C//num_heads)
        q = q * self.scale  
        
        # (num_windows*B, num_heads, wh*ww, C//num_heads) * (num_windows*B, num_heads, C//num_heads, wh*ww) = (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = (q @ k.transpose(-2, -1))  

        # (wh*ww, wh*ww, num_heads)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  

        # (num_heads, wh*ww, wh*ww)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  

        # (num_heads, wh*ww, wh*ww) -> (1, num_heads, wh*ww, wh*ww) -> (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = attn + relative_position_bias.unsqueeze(0)  #

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        # (num_windows*B, num_heads, wh*ww, wh*ww)
        attn = self.attn_drop(attn)  

        # (num_windows*B, num_heads, wh*ww, wh*ww) * (num_windows*B, num_heads, wh*ww, C//num_heads) = (num_windows*B, num_heads, wh*ww, C//num_heads)
        # (num_windows*B, num_heads, wh*ww, C//num_heads) -> (num_windows*B, wh*ww, num_heads, C//num_heads) -> (num_windows*B, wh*ww, C) = (N*B, wh*ww, C)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


        A demonstration of the main shape changes:

import torch
import torch.nn as nn


# take a 4×4 window size as the example
window_size = (4, 4)

coords_h = torch.arange(window_size[0])  # wh
coords_w = torch.arange(window_size[1])  # ww
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, wh, ww)
coords, coords.shape
(tensor([[[0, 0, 0, 0],
          [1, 1, 1, 1],
          [2, 2, 2, 2],
          [3, 3, 3, 3]],
 
         [[0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3],
          [0, 1, 2, 3]]]),
 torch.Size([2, 4, 4]))
coords_flatten = torch.flatten(coords, 1)  # (2, wh*ww)
coords_flatten, coords_flatten.shape
(tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
         [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
 torch.Size([2, 16]))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, wh*ww, wh*ww)
relative_coords, relative_coords.shape
(tensor([[[ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
          [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
          [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
          [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2, -3, -3, -3, -3],
          [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
          [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
          [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
          [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
          [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
          [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
          [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
          [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
          [ 3,  3,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
          [ 3,  3,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
          [ 3,  3,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
          [ 3,  3,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0]],
 
         [[ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3],
          [ 1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2],
          [ 2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1],
          [ 3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0],
          [ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3],
          [ 1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2],
          [ 2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1],
          [ 3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0],
          [ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3],
          [ 1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2],
          [ 2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1],
          [ 3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0],
          [ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3],
          [ 1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2],
          [ 2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1],
          [ 3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0]]]),
 torch.Size([2, 16, 16]))
# display the row/column offsets as (x, y) pairs
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (wh*ww, wh*ww, 2)
relative_coords, relative_coords.shape
(tensor([[[ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [ 0, -3],
          [-1,  0],
          [-1, -1],
          [-1, -2],
          [-1, -3],
          [-2,  0],
          [-2, -1],
          [-2, -2],
          [-2, -3],
          [-3,  0],
          [-3, -1],
          [-3, -2],
          [-3, -3]],
 
         [[ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [-1,  1],
          [-1,  0],
          [-1, -1],
          [-1, -2],
          [-2,  1],
          [-2,  0],
          [-2, -1],
          [-2, -2],
          [-3,  1],
          [-3,  0],
          [-3, -1],
          [-3, -2]],
 
         [[ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [-1,  2],
          [-1,  1],
          [-1,  0],
          [-1, -1],
          [-2,  2],
          [-2,  1],
          [-2,  0],
          [-2, -1],
          [-3,  2],
          [-3,  1],
          [-3,  0],
          [-3, -1]],
 
         [[ 0,  3],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [-1,  3],
          [-1,  2],
          [-1,  1],
          [-1,  0],
          [-2,  3],
          [-2,  2],
          [-2,  1],
          [-2,  0],
          [-3,  3],
          [-3,  2],
          [-3,  1],
          [-3,  0]],
 
         [[ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 1, -3],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [ 0, -3],
          [-1,  0],
          [-1, -1],
          [-1, -2],
          [-1, -3],
          [-2,  0],
          [-2, -1],
          [-2, -2],
          [-2, -3]],
 
         [[ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [-1,  1],
          [-1,  0],
          [-1, -1],
          [-1, -2],
          [-2,  1],
          [-2,  0],
          [-2, -1],
          [-2, -2]],
 
         [[ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [-1,  2],
          [-1,  1],
          [-1,  0],
          [-1, -1],
          [-2,  2],
          [-2,  1],
          [-2,  0],
          [-2, -1]],
 
         [[ 1,  3],
          [ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 0,  3],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [-1,  3],
          [-1,  2],
          [-1,  1],
          [-1,  0],
          [-2,  3],
          [-2,  2],
          [-2,  1],
          [-2,  0]],
 
         [[ 2,  0],
          [ 2, -1],
          [ 2, -2],
          [ 2, -3],
          [ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 1, -3],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [ 0, -3],
          [-1,  0],
          [-1, -1],
          [-1, -2],
          [-1, -3]],
 
         [[ 2,  1],
          [ 2,  0],
          [ 2, -1],
          [ 2, -2],
          [ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [-1,  1],
          [-1,  0],
          [-1, -1],
          [-1, -2]],
 
         [[ 2,  2],
          [ 2,  1],
          [ 2,  0],
          [ 2, -1],
          [ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [-1,  2],
          [-1,  1],
          [-1,  0],
          [-1, -1]],
 
         [[ 2,  3],
          [ 2,  2],
          [ 2,  1],
          [ 2,  0],
          [ 1,  3],
          [ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 0,  3],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [-1,  3],
          [-1,  2],
          [-1,  1],
          [-1,  0]],
 
         [[ 3,  0],
          [ 3, -1],
          [ 3, -2],
          [ 3, -3],
          [ 2,  0],
          [ 2, -1],
          [ 2, -2],
          [ 2, -3],
          [ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 1, -3],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2],
          [ 0, -3]],
 
         [[ 3,  1],
          [ 3,  0],
          [ 3, -1],
          [ 3, -2],
          [ 2,  1],
          [ 2,  0],
          [ 2, -1],
          [ 2, -2],
          [ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 1, -2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1],
          [ 0, -2]],
 
         [[ 3,  2],
          [ 3,  1],
          [ 3,  0],
          [ 3, -1],
          [ 2,  2],
          [ 2,  1],
          [ 2,  0],
          [ 2, -1],
          [ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 1, -1],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0],
          [ 0, -1]],
 
         [[ 3,  3],
          [ 3,  2],
          [ 3,  1],
          [ 3,  0],
          [ 2,  3],
          [ 2,  2],
          [ 2,  1],
          [ 2,  0],
          [ 1,  3],
          [ 1,  2],
          [ 1,  1],
          [ 1,  0],
          [ 0,  3],
          [ 0,  2],
          [ 0,  1],
          [ 0,  0]]]),
 torch.Size([16, 16, 2]))
# additive offset for the x coordinates (+= 3)
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords
tensor([[[ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 3, -3],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 2, -3],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 1, -3],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [ 0, -3]],

        [[ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2]],

        [[ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1]],

        [[ 3,  3],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 2,  3],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 1,  3],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  3],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0]],

        [[ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 4, -3],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 3, -3],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 2, -3],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 1, -3]],

        [[ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 1, -2]],

        [[ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 1, -1]],

        [[ 4,  3],
         [ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 3,  3],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 2,  3],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 1,  3],
         [ 1,  2],
         [ 1,  1],
         [ 1,  0]],

        [[ 5,  0],
         [ 5, -1],
         [ 5, -2],
         [ 5, -3],
         [ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 4, -3],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 3, -3],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2],
         [ 2, -3]],

        [[ 5,  1],
         [ 5,  0],
         [ 5, -1],
         [ 5, -2],
         [ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1],
         [ 2, -2]],

        [[ 5,  2],
         [ 5,  1],
         [ 5,  0],
         [ 5, -1],
         [ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0],
         [ 2, -1]],

        [[ 5,  3],
         [ 5,  2],
         [ 5,  1],
         [ 5,  0],
         [ 4,  3],
         [ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 3,  3],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 2,  3],
         [ 2,  2],
         [ 2,  1],
         [ 2,  0]],

        [[ 6,  0],
         [ 6, -1],
         [ 6, -2],
         [ 6, -3],
         [ 5,  0],
         [ 5, -1],
         [ 5, -2],
         [ 5, -3],
         [ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 4, -3],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2],
         [ 3, -3]],

        [[ 6,  1],
         [ 6,  0],
         [ 6, -1],
         [ 6, -2],
         [ 5,  1],
         [ 5,  0],
         [ 5, -1],
         [ 5, -2],
         [ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 4, -2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1],
         [ 3, -2]],

        [[ 6,  2],
         [ 6,  1],
         [ 6,  0],
         [ 6, -1],
         [ 5,  2],
         [ 5,  1],
         [ 5,  0],
         [ 5, -1],
         [ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 4, -1],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0],
         [ 3, -1]],

        [[ 6,  3],
         [ 6,  2],
         [ 6,  1],
         [ 6,  0],
         [ 5,  3],
         [ 5,  2],
         [ 5,  1],
         [ 5,  0],
         [ 4,  3],
         [ 4,  2],
         [ 4,  1],
         [ 4,  0],
         [ 3,  3],
         [ 3,  2],
         [ 3,  1],
         [ 3,  0]]])
# additive offset for the y coordinates (+= 3)
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords
tensor([[[3, 3],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 3],
         [2, 2],
         [2, 1],
         [2, 0],
         [1, 3],
         [1, 2],
         [1, 1],
         [1, 0],
         [0, 3],
         [0, 2],
         [0, 1],
         [0, 0]],

        [[3, 4],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 4],
         [2, 3],
         [2, 2],
         [2, 1],
         [1, 4],
         [1, 3],
         [1, 2],
         [1, 1],
         [0, 4],
         [0, 3],
         [0, 2],
         [0, 1]],

        [[3, 5],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 5],
         [2, 4],
         [2, 3],
         [2, 2],
         [1, 5],
         [1, 4],
         [1, 3],
         [1, 2],
         [0, 5],
         [0, 4],
         [0, 3],
         [0, 2]],

        [[3, 6],
         [3, 5],
         [3, 4],
         [3, 3],
         [2, 6],
         [2, 5],
         [2, 4],
         [2, 3],
         [1, 6],
         [1, 5],
         [1, 4],
         [1, 3],
         [0, 6],
         [0, 5],
         [0, 4],
         [0, 3]],

        [[4, 3],
         [4, 2],
         [4, 1],
         [4, 0],
         [3, 3],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 3],
         [2, 2],
         [2, 1],
         [2, 0],
         [1, 3],
         [1, 2],
         [1, 1],
         [1, 0]],

        [[4, 4],
         [4, 3],
         [4, 2],
         [4, 1],
         [3, 4],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 4],
         [2, 3],
         [2, 2],
         [2, 1],
         [1, 4],
         [1, 3],
         [1, 2],
         [1, 1]],

        [[4, 5],
         [4, 4],
         [4, 3],
         [4, 2],
         [3, 5],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 5],
         [2, 4],
         [2, 3],
         [2, 2],
         [1, 5],
         [1, 4],
         [1, 3],
         [1, 2]],

        [[4, 6],
         [4, 5],
         [4, 4],
         [4, 3],
         [3, 6],
         [3, 5],
         [3, 4],
         [3, 3],
         [2, 6],
         [2, 5],
         [2, 4],
         [2, 3],
         [1, 6],
         [1, 5],
         [1, 4],
         [1, 3]],

        [[5, 3],
         [5, 2],
         [5, 1],
         [5, 0],
         [4, 3],
         [4, 2],
         [4, 1],
         [4, 0],
         [3, 3],
         [3, 2],
         [3, 1],
         [3, 0],
         [2, 3],
         [2, 2],
         [2, 1],
         [2, 0]],

        [[5, 4],
         [5, 3],
         [5, 2],
         [5, 1],
         [4, 4],
         [4, 3],
         [4, 2],
         [4, 1],
         [3, 4],
         [3, 3],
         [3, 2],
         [3, 1],
         [2, 4],
         [2, 3],
         [2, 2],
         [2, 1]],

        [[5, 5],
         [5, 4],
         [5, 3],
         [5, 2],
         [4, 5],
         [4, 4],
         [4, 3],
         [4, 2],
         [3, 5],
         [3, 4],
         [3, 3],
         [3, 2],
         [2, 5],
         [2, 4],
         [2, 3],
         [2, 2]],

        [[5, 6],
         [5, 5],
         [5, 4],
         [5, 3],
         [4, 6],
         [4, 5],
         [4, 4],
         [4, 3],
         [3, 6],
         [3, 5],
         [3, 4],
         [3, 3],
         [2, 6],
         [2, 5],
         [2, 4],
         [2, 3]],

        [[6, 3],
         [6, 2],
         [6, 1],
         [6, 0],
         [5, 3],
         [5, 2],
         [5, 1],
         [5, 0],
         [4, 3],
         [4, 2],
         [4, 1],
         [4, 0],
         [3, 3],
         [3, 2],
         [3, 1],
         [3, 0]],

        [[6, 4],
         [6, 3],
         [6, 2],
         [6, 1],
         [5, 4],
         [5, 3],
         [5, 2],
         [5, 1],
         [4, 4],
         [4, 3],
         [4, 2],
         [4, 1],
         [3, 4],
         [3, 3],
         [3, 2],
         [3, 1]],

        [[6, 5],
         [6, 4],
         [6, 3],
         [6, 2],
         [5, 5],
         [5, 4],
         [5, 3],
         [5, 2],
         [4, 5],
         [4, 4],
         [4, 3],
         [4, 2],
         [3, 5],
         [3, 4],
         [3, 3],
         [3, 2]],

        [[6, 6],
         [6, 5],
         [6, 4],
         [6, 3],
         [5, 6],
         [5, 5],
         [5, 4],
         [5, 3],
         [4, 6],
         [4, 5],
         [4, 4],
         [4, 3],
         [3, 6],
         [3, 5],
         [3, 4],
         [3, 3]]])
# multiplicative transform for the x coordinates (*= 7)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_coords
tensor([[[21,  3],
         [21,  2],
         [21,  1],
         [21,  0],
         [14,  3],
         [14,  2],
         [14,  1],
         [14,  0],
         [ 7,  3],
         [ 7,  2],
         [ 7,  1],
         [ 7,  0],
         [ 0,  3],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0]],

        [[21,  4],
         [21,  3],
         [21,  2],
         [21,  1],
         [14,  4],
         [14,  3],
         [14,  2],
         [14,  1],
         [ 7,  4],
         [ 7,  3],
         [ 7,  2],
         [ 7,  1],
         [ 0,  4],
         [ 0,  3],
         [ 0,  2],
         [ 0,  1]],

        [[21,  5],
         [21,  4],
         [21,  3],
         [21,  2],
         [14,  5],
         [14,  4],
         [14,  3],
         [14,  2],
         [ 7,  5],
         [ 7,  4],
         [ 7,  3],
         [ 7,  2],
         [ 0,  5],
         [ 0,  4],
         [ 0,  3],
         [ 0,  2]],

        [[21,  6],
         [21,  5],
         [21,  4],
         [21,  3],
         [14,  6],
         [14,  5],
         [14,  4],
         [14,  3],
         [ 7,  6],
         [ 7,  5],
         [ 7,  4],
         [ 7,  3],
         [ 0,  6],
         [ 0,  5],
         [ 0,  4],
         [ 0,  3]],

        ...

        [[42,  6],
         [42,  5],
         [42,  4],
         [42,  3],
         [35,  6],
         [35,  5],
         [35,  4],
         [35,  3],
         [28,  6],
         [28,  5],
         [28,  4],
         [28,  3],
         [21,  6],
         [21,  5],
         [21,  4],
         [21,  3]]])
# collapse the 2D relative coordinates into 1D offsets (x + y)
relative_position_index = relative_coords.sum(-1)  # (wh*ww, wh*ww)
relative_position_index, relative_position_index.shape

# note how the offset values spread perpendicular to the main diagonal
# the 16 columns correspond one-to-one to the 4x4 coordinate positions
(tensor([[24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7,  3,  2,  1,  0],
         [25, 24, 23, 22, 18, 17, 16, 15, 11, 10,  9,  8,  4,  3,  2,  1],
         [26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10,  9,  5,  4,  3,  2],
         [27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10,  6,  5,  4,  3],
         [31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7],
         [32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15, 11, 10,  9,  8],
         [33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10,  9],
         [34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],
         [38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],
         [39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],
         [40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],
         [41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17],
         [45, 44, 43, 42, 38, 37, 36, 35, 31, 30, 29, 28, 24, 23, 22, 21],
         [46, 45, 44, 43, 39, 38, 37, 36, 32, 31, 30, 29, 25, 24, 23, 22],
         [47, 46, 45, 44, 40, 39, 38, 37, 33, 32, 31, 30, 26, 25, 24, 23],
         [48, 47, 46, 45, 41, 40, 39, 38, 34, 33, 32, 31, 27, 26, 25, 24]]),
 torch.Size([16, 16]))
# assume the number of MHA heads is 3
num_heads = 3

# one learnable bias row per relative offset per head: (2*wh-1) * (2*ww-1) = 7 * 7 = 49 rows
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
relative_position_bias_table, relative_position_bias_table.shape
(Parameter containing:
 tensor([[0., 0., 0.],
         [0., 0., 0.],
         ...
         [0., 0., 0.]], requires_grad=True),
 torch.Size([49, 3]))
# gather per-pair bias vectors: index the (49, num_heads) table with the (16*16,) flat indices
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
            window_size[0] * window_size[1], window_size[0] * window_size[1], -1)

relative_position_bias, relative_position_bias.shape
(tensor([[[0., 0., 0.],
          [0., 0., 0.],
          ...
          [0., 0., 0.]],
 
         ...
 
         [[0., 0., 0.],
          [0., 0., 0.],
          ...
          [0., 0., 0.]]], grad_fn=<ViewBackward>),
 torch.Size([16, 16, 3]))
# (wh*ww, wh*ww, num_heads) -> (num_heads, wh*ww, wh*ww), matching the attention-map layout
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias, relative_position_bias.shape
(tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          ...
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
 
         ...
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          ...
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
        grad_fn=<CopyBackwards>),
 torch.Size([3, 16, 16]))

        The above essentially walks through how the relative position index and bias are generated.
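
        To see where these tensors are consumed, here is a minimal sketch (toy shapes only, not the full WindowAttention forward) of how the gathered bias is added onto the attention logits and broadcast across windows; q and k are random placeholders:

import torch

# toy shapes matching the walkthrough: 4 windows, 3 heads, 4x4 = 16 tokens per window
num_windows, num_heads, tokens, head_dim = 4, 3, 16, 8
q = torch.rand(num_windows, num_heads, tokens, head_dim)  # placeholder queries
k = torch.rand(num_windows, num_heads, tokens, head_dim)  # placeholder keys
bias = torch.zeros(num_heads, tokens, tokens)  # the permuted relative_position_bias from above

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # (nW, heads, 16, 16) attention logits
attn = attn + bias.unsqueeze(0)  # broadcast one learned bias over all windows
attn = attn.softmax(dim=-1)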


3.8 Swin Transformer Block 

3.8.1 Shift Window Attention

        The basic window attention (W-MSA) is computed independently within each window. To let windows exchange information with one another, Swin Transformer further introduces the shifted window operation.

        On the left is the non-overlapping basic Window Attention; on the right, the window partition is shifted diagonally, giving Shift Window Attention. The shifted windows now contain elements from formerly adjacent windows, but this also increases the number of windows, from 4 to 9.

        In the implementation, Shift Window Attention (SW-MSA) is realized indirectly by cyclically shifting the feature map and applying a mask to the attention. The window count thus stays unchanged while the final result remains equivalent.

        In code, the feature map is shifted via torch.roll, as shown below:

        To perform the reverse cyclic shift, simply set the shifts argument to the corresponding positive values.
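
        A self-contained sketch (a toy 4×4 map with shift_size = 1, values chosen only for illustration) of the cyclic shift and its reverse:

import torch

x = torch.arange(16).view(1, 4, 4, 1)  # (B, H, W, C) toy feature map

# cyclic shift toward the upper left, as in the block's forward pass
shifted_x = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
# reverse cyclic shift: positive shifts move everything back
restored_x = torch.roll(shifted_x, shifts=(1, 1), dims=(1, 2))
assert torch.equal(x, restored_x)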

3.8.2 Attention Mask

        With a properly constructed mask, Shifted Window Attention (SW-MSA) produces results equivalent to the shifted partition while keeping the same number of windows as basic Window Attention (W-MSA).

        First, assign an index to each sub-window produced by the shift, then apply the roll operation (here window_size=2, shift_size=1), as shown below:

        When computing the attention map, we want to keep only the entries whose Query and Key share the same index, and suppress entries whose Query and Key carry different indices, as shown below:

        Note that Value has the same shape as Query (4×1). The (4×4) QK product, and hence the attention map computed this way, still places results in the correct positions when multiplied with Value: (4×4) · (4×1) = (4×1).
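
        A tiny 1D sketch of this keep-same-index / suppress-different-index rule (four toy tokens carrying sub-window indices [0, 0, 1, 1]; the numbers are illustrative only):

import torch

idx = torch.tensor([0., 0., 1., 1.])        # sub-window index of each token
diff = idx.unsqueeze(0) - idx.unsqueeze(1)  # (4, 4); zero iff Query and Key indices match
mask = diff.masked_fill(diff != 0, -100.).masked_fill(diff == 0, 0.)
print(mask)
# tensor([[   0.,    0., -100., -100.],
#         [   0.,    0., -100., -100.],
#         [-100., -100.,    0.,    0.],
#         [-100., -100.,    0.,    0.]])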

        To obtain correct results under the original four windows, a mask must therefore be added to the attention map (the gray patches in the figure above); the relevant code is as follows:

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        With the settings of the figure above, this code produces the following mask:

tensor([[[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],


         [[[   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.],
           [   0., -100.,    0., -100.],
           [-100.,    0., -100.,    0.]]],


         [[[   0.,    0., -100., -100.],
           [   0.,    0., -100., -100.],
           [-100., -100.,    0.,    0.],
           [-100., -100.,    0.,    0.]]],


         [[[   0., -100., -100., -100.],
           [-100.,    0., -100., -100.],
           [-100., -100.,    0., -100.],
           [-100., -100., -100.,    0.]]]]])

        The forward pass of the Window Attention module contains the following snippet:

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)

        It adds the mask, whose blocked entries are set to -100, directly onto the attention map; after reshaping, Softmax then effectively ignores those entries.
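
        Why -100 suffices: subtracting 100 from a logit drives its Softmax weight to essentially zero. A quick numeric check:

import torch

scores = torch.tensor([1.0, 2.0, 1.0 - 100.0, 2.0 - 100.0])  # last two logits masked
print(torch.softmax(scores, dim=-1))
# ≈ tensor([0.2689, 0.7311, 0.0000, 0.0000]): the masked weights are numerically zero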

        Finally, the Swin Transformer Block diagram and its code:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        ##################### attention mask for cyclic-shift local-window self-attention #####################
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
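
        As a quick smoke test (a sketch that assumes WindowAttention, Mlp, window_partition and window_reverse from the sections above, plus timm's DropPath and to_2tuple, are in scope):

import torch

# a shifted block at Swin-T's Stage-1 resolution: dim=96, 56x56 tokens, 7x7 windows
block = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
                             window_size=7, shift_size=3)
x = torch.rand(2, 56 * 56, 96)  # (B, L, C) with L = H * W
out = block(x)
print(out.shape)  # torch.Size([2, 3136, 96]): shape is preserved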

        A worked example that simulates the process step by step:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size - usually wh = ww = w = 4 by default

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    # (B, H, W, C) -> (B, H//wh, wh, W//ww, ww, C) -> (B, H//wh, W//ww, wh, ww, C) -> ((B*H*W)//(wh*ww), wh, ww, C)
    B, H, W, C = x.shape
    #print(x.shape)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    #print(x.shape)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # ((B*H*W)//(wh*ww), wh, ww, C) = (N, wh, ww, C)


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size - usually wh = ww = w = 4 by default
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # ((B*H*W)//(wh*ww), w, w, C) -> (B, H//wh, W//ww, wh, ww, C) -> (B, H//wh, wh, W//ww, ww, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
batch_size = 1
num_channel = 3
input_resolution = (8, 8)
image = torch.rand(batch_size, num_channel, input_resolution[0], input_resolution[1])
image.shape
torch.Size([1, 3, 8, 8])
# every local window size is some combination of the two lengths window_size and shift_size
window_size = 4
shift_size = window_size // 2  # 2

# calculate attention mask for SW-MSA
H, W = input_resolution
img_mask = torch.zeros((1, H, W, 1))  # (1, H, W, 1)

# local window index range 
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))

# label each local window with an index idx (the cnt in the class code)
idx = 0
for h in h_slices:
    for w in w_slices:
        print(h, w, idx)  
        img_mask[:, h, w, :] = idx
        idx += 1
        
print(f"num local windows: {idx}")
img_mask.shape, img_mask[0, ..., 0]
slice(0, -4, None) slice(0, -4, None) 0
slice(0, -4, None) slice(-4, -2, None) 1
slice(0, -4, None) slice(-2, None, None) 2
slice(-4, -2, None) slice(0, -4, None) 3
slice(-4, -2, None) slice(-4, -2, None) 4
slice(-4, -2, None) slice(-2, None, None) 5
slice(-2, None, None) slice(0, -4, None) 6
slice(-2, None, None) slice(-4, -2, None) 7
slice(-2, None, None) slice(-2, None, None) 8
num local windows: 9

(torch.Size([1, 8, 8, 1]),
 tensor([[0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [6., 6., 6., 6., 7., 7., 8., 8.],
         [6., 6., 6., 6., 7., 7., 8., 8.]]))
# (1, 8, 8, 1) = (1, H, W, 1) = (B, H, W, C)  ->  ((B*H*W)//(wh*ww), wh, ww, C) = (N, wh, ww, C) = (4, 4, 4, 1)
mask_windows = window_partition(img_mask, window_size) 
mask_windows.shape, mask_windows[..., 0]
(torch.Size([4, 4, 4, 1]),
 tensor([[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],
 
         [[1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [1., 1., 2., 2.]],
 
         [[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [6., 6., 6., 6.],
          [6., 6., 6., 6.]],
 
         [[4., 4., 5., 5.],
          [4., 4., 5., 5.],
          [7., 7., 8., 8.],
          [7., 7., 8., 8.]]]))
mask_windows = mask_windows.view(-1, window_size * window_size)
mask_windows.shape, mask_windows
(torch.Size([4, 16]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 6., 6., 6., 6., 6., 6., 6., 6.],
         [4., 4., 5., 5., 4., 4., 5., 5., 7., 7., 8., 8., 7., 7., 8., 8.]]))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# Note: the printouts below were generated from the pre-flatten (4, 4, 4, 1) mask_windows,
# hence shape (4, 4, 4, 4, 1); on the flattened (4, 16) version these lines yield the
# (4, 16, 16) mask that the block actually uses.
attn_mask.shape, attn_mask[..., 0]
(torch.Size([4, 4, 4, 4, 1]),
 tensor([[[[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 3.,  3.,  3.,  3.],
           [ 3.,  3.,  3.,  3.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 3.,  3.,  3.,  3.],
           [ 3.,  3.,  3.,  3.]],
 
          [[-3., -3., -3., -3.],
           [-3., -3., -3., -3.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[-3., -3., -3., -3.],
           [-3., -3., -3., -3.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]]],
 
 
         [[[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 3.,  3.,  3.,  3.],
           [ 3.,  3.,  3.,  3.]],
 
          [[ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.],
           [ 3.,  3.,  3.,  3.],
           [ 3.,  3.,  3.,  3.]],
 
          [[-3., -3., -3., -3.],
           [-3., -3., -3., -3.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]],
 
          [[-3., -3., -3., -3.],
           [-3., -3., -3., -3.],
           [ 0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.]]]]))
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask.shape, attn_mask[..., 0]
(torch.Size([4, 4, 4, 4, 1]),
 tensor([[[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],
 
 
         [[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],
 
 
         [[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [-100., -100., -100., -100.],
           [-100., -100., -100., -100.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [-100., -100., -100., -100.],
           [-100., -100., -100., -100.]],
 
          [[-100., -100., -100., -100.],
           [-100., -100., -100., -100.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[-100., -100., -100., -100.],
           [-100., -100., -100., -100.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]],
 
 
         [[[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [-100., -100., -100., -100.],
           [-100., -100., -100., -100.]],
 
          [[   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.],
           [-100., -100., -100., -100.],
           [-100., -100., -100., -100.]],
 
          [[-100., -100., -100., -100.],
           [-100., -100., -100., -100.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]],
 
          [[-100., -100., -100., -100.],
           [-100., -100., -100., -100.],
           [   0.,    0.,    0.,    0.],
           [   0.,    0.,    0.,    0.]]]]))
x = image.clone().permute(0, 2, 3, 1)  # (B, H, W, C)
B, H, W, C = x.shape
x[0, 0, 0, 0] = 0.  # used as a marker
x[0, -1, -1, -1] = 1.  # used as a marker
x.shape, x[..., 0], x[..., 1], x[..., 2]
(torch.Size([1, 8, 8, 3]),
 tensor([[[0.0000, 0.1060, 0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353],
          [0.3751, 0.5690, 0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598],
          [0.1991, 0.0081, 0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748],
          [0.8627, 0.0345, 0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057],
          [0.4686, 0.3372, 0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709],
          [0.5104, 0.2789, 0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589],
          [0.9013, 0.6980, 0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150],
          [0.6927, 0.0394, 0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178]]]),
 tensor([[[0.6409, 0.0384, 0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691],
          [0.7684, 0.0422, 0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495],
          [0.3684, 0.9283, 0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926],
          [0.1865, 0.5766, 0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666],
          [0.0800, 0.3956, 0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635],
          [0.9967, 0.1205, 0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302],
          [0.5620, 0.9326, 0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879],
          [0.9367, 0.9096, 0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822]]]),
 tensor([[[0.8273, 0.7008, 0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342],
          [0.0062, 0.8495, 0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742],
          [0.4363, 0.1852, 0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830],
          [0.5139, 0.5424, 0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853],
          [0.4835, 0.3571, 0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172],
          [0.4750, 0.7820, 0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646],
          [0.7525, 0.2686, 0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140],
          [0.4540, 0.7455, 0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000]]]))
# the markers 0.0000 and 1.0000 have each been cyclically shifted toward the upper left by shift_size = window_size // 2 = 2 pixels within their own channel
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))  # (B, H, W, C)
shifted_x.shape, shifted_x[..., 0], shifted_x[..., 1], shifted_x[..., 2]
(torch.Size([1, 8, 8, 3]),
 tensor([[[0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748, 0.1991, 0.0081],
          [0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057, 0.8627, 0.0345],
          [0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709, 0.4686, 0.3372],
          [0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589, 0.5104, 0.2789],
          [0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150, 0.9013, 0.6980],
          [0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178, 0.6927, 0.0394],
          [0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353, 0.0000, 0.1060],
          [0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598, 0.3751, 0.5690]]]),
 tensor([[[0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926, 0.3684, 0.9283],
          [0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666, 0.1865, 0.5766],
          [0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635, 0.0800, 0.3956],
          [0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302, 0.9967, 0.1205],
          [0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879, 0.5620, 0.9326],
          [0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822, 0.9367, 0.9096],
          [0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691, 0.6409, 0.0384],
          [0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495, 0.7684, 0.0422]]]),
 tensor([[[0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830, 0.4363, 0.1852],
          [0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853, 0.5139, 0.5424],
          [0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172, 0.4835, 0.3571],
          [0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646, 0.4750, 0.7820],
          [0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140, 0.7525, 0.2686],
          [0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000, 0.4540, 0.7455],
          [0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342, 0.8273, 0.7008],
          [0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742, 0.0062, 0.8495]]]))
x_windows = window_partition(shifted_x, window_size)  # (B*N, wh, ww, C) = (B*(H*W)//(wh*ww), wh, ww, C)
x_windows.shape, x_windows[..., 0], x_windows[..., 1], x_windows[..., 2]
(torch.Size([4, 4, 4, 3]),
 tensor([[[0.2861, 0.4539, 0.1620, 0.8776],
          [0.2435, 0.7224, 0.9310, 0.8621],
          [0.3429, 0.6660, 0.7115, 0.8560],
          [0.8015, 0.8737, 0.6784, 0.6677]],
 
         [[0.5298, 0.3748, 0.1991, 0.0081],
          [0.4113, 0.8057, 0.8627, 0.0345],
          [0.7055, 0.4709, 0.4686, 0.3372],
          [0.8233, 0.6589, 0.5104, 0.2789]],
 
         [[0.1548, 0.9066, 0.0334, 0.1617],
          [0.4180, 0.0387, 0.4488, 0.1339],
          [0.1015, 0.2196, 0.3544, 0.8485],
          [0.9630, 0.1945, 0.8999, 0.4977]],
 
         [[0.2747, 0.6150, 0.9013, 0.6980],
          [0.3340, 0.6178, 0.6927, 0.0394],
          [0.3509, 0.7353, 0.0000, 0.1060],
          [0.8007, 0.1598, 0.3751, 0.5690]]]),
 tensor([[[0.0569, 0.7790, 0.6074, 0.1229],
          [0.2497, 0.2391, 0.4254, 0.7249],
          [0.7351, 0.8919, 0.1177, 0.7949],
          [0.1785, 0.0886, 0.8664, 0.8412]],
 
         [[0.3138, 0.5926, 0.3684, 0.9283],
          [0.3116, 0.1666, 0.1865, 0.5766],
          [0.0028, 0.2635, 0.0800, 0.3956],
          [0.1258, 0.4302, 0.9967, 0.1205]],
 
         [[0.6767, 0.2432, 0.5963, 0.7276],
          [0.8327, 0.1795, 0.0361, 0.7189],
          [0.7775, 0.3550, 0.6754, 0.9210],
          [0.1605, 0.9409, 0.0397, 0.7786]],
 
         [[0.2273, 0.4879, 0.5620, 0.9326],
          [0.9292, 0.3822, 0.9367, 0.9096],
          [0.0923, 0.0691, 0.6409, 0.0384],
          [0.8475, 0.8495, 0.7684, 0.0422]]]),
 tensor([[[0.6110, 0.2923, 0.6231, 0.0668],
          [0.3008, 0.5251, 0.3518, 0.0882],
          [0.2530, 0.8452, 0.2010, 0.3866],
          [0.4262, 0.4426, 0.4161, 0.1199]],
 
         [[0.9430, 0.2830, 0.4363, 0.1852],
          [0.3335, 0.7853, 0.5139, 0.5424],
          [0.5673, 0.1172, 0.4835, 0.3571],
          [0.0477, 0.4646, 0.4750, 0.7820]],
 
         [[0.6829, 0.6753, 0.4345, 0.6609],
          [0.4464, 0.6749, 0.9152, 0.6936],
          [0.2891, 0.1136, 0.4981, 0.2119],
          [0.1382, 0.8667, 0.2436, 0.6408]],
 
         [[0.8414, 0.9140, 0.7525, 0.2686],
          [0.7035, 1.0000, 0.4540, 0.7455],
          [0.8096, 0.6342, 0.8273, 0.7008],
          [0.0238, 0.4742, 0.0062, 0.8495]]]))
x_windows = x_windows.view(-1, window_size * window_size, C)  # (B*N, wh*ww, C) = (B*(H*W)//(wh*ww), wh*ww, C)
x_windows.shape, x_windows[..., 0], x_windows[..., 1], x_windows[..., 2]
(torch.Size([4, 16, 3]),
 tensor([[0.2861, 0.4539, 0.1620, 0.8776, 0.2435, 0.7224, 0.9310, 0.8621, 0.3429,
          0.6660, 0.7115, 0.8560, 0.8015, 0.8737, 0.6784, 0.6677],
         [0.5298, 0.3748, 0.1991, 0.0081, 0.4113, 0.8057, 0.8627, 0.0345, 0.7055,
          0.4709, 0.4686, 0.3372, 0.8233, 0.6589, 0.5104, 0.2789],
         [0.1548, 0.9066, 0.0334, 0.1617, 0.4180, 0.0387, 0.4488, 0.1339, 0.1015,
          0.2196, 0.3544, 0.8485, 0.9630, 0.1945, 0.8999, 0.4977],
         [0.2747, 0.6150, 0.9013, 0.6980, 0.3340, 0.6178, 0.6927, 0.0394, 0.3509,
          0.7353, 0.0000, 0.1060, 0.8007, 0.1598, 0.3751, 0.5690]]),
 tensor([[0.0569, 0.7790, 0.6074, 0.1229, 0.2497, 0.2391, 0.4254, 0.7249, 0.7351,
          0.8919, 0.1177, 0.7949, 0.1785, 0.0886, 0.8664, 0.8412],
         [0.3138, 0.5926, 0.3684, 0.9283, 0.3116, 0.1666, 0.1865, 0.5766, 0.0028,
          0.2635, 0.0800, 0.3956, 0.1258, 0.4302, 0.9967, 0.1205],
         [0.6767, 0.2432, 0.5963, 0.7276, 0.8327, 0.1795, 0.0361, 0.7189, 0.7775,
          0.3550, 0.6754, 0.9210, 0.1605, 0.9409, 0.0397, 0.7786],
         [0.2273, 0.4879, 0.5620, 0.9326, 0.9292, 0.3822, 0.9367, 0.9096, 0.0923,
          0.0691, 0.6409, 0.0384, 0.8475, 0.8495, 0.7684, 0.0422]]),
 tensor([[0.6110, 0.2923, 0.6231, 0.0668, 0.3008, 0.5251, 0.3518, 0.0882, 0.2530,
          0.8452, 0.2010, 0.3866, 0.4262, 0.4426, 0.4161, 0.1199],
         [0.9430, 0.2830, 0.4363, 0.1852, 0.3335, 0.7853, 0.5139, 0.5424, 0.5673,
          0.1172, 0.4835, 0.3571, 0.0477, 0.4646, 0.4750, 0.7820],
         [0.6829, 0.6753, 0.4345, 0.6609, 0.4464, 0.6749, 0.9152, 0.6936, 0.2891,
          0.1136, 0.4981, 0.2119, 0.1382, 0.8667, 0.2436, 0.6408],
         [0.8414, 0.9140, 0.7525, 0.2686, 0.7035, 1.0000, 0.4540, 0.7455, 0.8096,
          0.6342, 0.8273, 0.7008, 0.0238, 0.4742, 0.0062, 0.8495]]))
# W-MSA/SW-MSA
#attn_windows = self.attn(x_windows, mask=self.attn_mask)  # the original op: (B*N, wh*ww, C) = (B*(H*W)//(wh*ww), wh*ww, C)
attn_windows = x_windows.clone()  # stand-in, for demonstration only

# merge windows
attn_windows = attn_windows.view(-1, window_size, window_size, C)
attn_windows.shape
torch.Size([4, 4, 4, 3])
# reverse cyclic shift
shifted_x = window_reverse(attn_windows, window_size, H, W)
shifted_x.shape, shifted_x[..., 0], shifted_x[..., 1], shifted_x[..., 2]
(torch.Size([1, 8, 8, 3]),
 tensor([[[0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748, 0.1991, 0.0081],
          [0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057, 0.8627, 0.0345],
          [0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709, 0.4686, 0.3372],
          [0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589, 0.5104, 0.2789],
          [0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150, 0.9013, 0.6980],
          [0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178, 0.6927, 0.0394],
          [0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353, 0.0000, 0.1060],
          [0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598, 0.3751, 0.5690]]]),
 tensor([[[0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926, 0.3684, 0.9283],
          [0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666, 0.1865, 0.5766],
          [0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635, 0.0800, 0.3956],
          [0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302, 0.9967, 0.1205],
          [0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879, 0.5620, 0.9326],
          [0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822, 0.9367, 0.9096],
          [0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691, 0.6409, 0.0384],
          [0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495, 0.7684, 0.0422]]]),
 tensor([[[0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830, 0.4363, 0.1852],
          [0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853, 0.5139, 0.5424],
          [0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172, 0.4835, 0.3571],
          [0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646, 0.4750, 0.7820],
          [0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140, 0.7525, 0.2686],
          [0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000, 0.4540, 0.7455],
          [0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342, 0.8273, 0.7008],
          [0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742, 0.0062, 0.8495]]]))
# the markers 0.0000 and 1.0000 have been cyclically shifted back toward the lower right by shift_size = window_size // 2 = 2 pixels within their own channel
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
x.shape, x[..., 0], x[..., 1], x[..., 2]
(torch.Size([1, 8, 8, 3]),
 tensor([[[0.0000, 0.1060, 0.1015, 0.2196, 0.3544, 0.8485, 0.3509, 0.7353],
          [0.3751, 0.5690, 0.9630, 0.1945, 0.8999, 0.4977, 0.8007, 0.1598],
          [0.1991, 0.0081, 0.2861, 0.4539, 0.1620, 0.8776, 0.5298, 0.3748],
          [0.8627, 0.0345, 0.2435, 0.7224, 0.9310, 0.8621, 0.4113, 0.8057],
          [0.4686, 0.3372, 0.3429, 0.6660, 0.7115, 0.8560, 0.7055, 0.4709],
          [0.5104, 0.2789, 0.8015, 0.8737, 0.6784, 0.6677, 0.8233, 0.6589],
          [0.9013, 0.6980, 0.1548, 0.9066, 0.0334, 0.1617, 0.2747, 0.6150],
          [0.6927, 0.0394, 0.4180, 0.0387, 0.4488, 0.1339, 0.3340, 0.6178]]]),
 tensor([[[0.6409, 0.0384, 0.7775, 0.3550, 0.6754, 0.9210, 0.0923, 0.0691],
          [0.7684, 0.0422, 0.1605, 0.9409, 0.0397, 0.7786, 0.8475, 0.8495],
          [0.3684, 0.9283, 0.0569, 0.7790, 0.6074, 0.1229, 0.3138, 0.5926],
          [0.1865, 0.5766, 0.2497, 0.2391, 0.4254, 0.7249, 0.3116, 0.1666],
          [0.0800, 0.3956, 0.7351, 0.8919, 0.1177, 0.7949, 0.0028, 0.2635],
          [0.9967, 0.1205, 0.1785, 0.0886, 0.8664, 0.8412, 0.1258, 0.4302],
          [0.5620, 0.9326, 0.6767, 0.2432, 0.5963, 0.7276, 0.2273, 0.4879],
          [0.9367, 0.9096, 0.8327, 0.1795, 0.0361, 0.7189, 0.9292, 0.3822]]]),
 tensor([[[0.8273, 0.7008, 0.2891, 0.1136, 0.4981, 0.2119, 0.8096, 0.6342],
          [0.0062, 0.8495, 0.1382, 0.8667, 0.2436, 0.6408, 0.0238, 0.4742],
          [0.4363, 0.1852, 0.6110, 0.2923, 0.6231, 0.0668, 0.9430, 0.2830],
          [0.5139, 0.5424, 0.3008, 0.5251, 0.3518, 0.0882, 0.3335, 0.7853],
          [0.4835, 0.3571, 0.2530, 0.8452, 0.2010, 0.3866, 0.5673, 0.1172],
          [0.4750, 0.7820, 0.4262, 0.4426, 0.4161, 0.1199, 0.0477, 0.4646],
          [0.7525, 0.2686, 0.6829, 0.6753, 0.4345, 0.6609, 0.8414, 0.9140],
          [0.4540, 0.7455, 0.4464, 0.6749, 0.9152, 0.6936, 0.7035, 1.0000]]]))
x = x.view(B, H * W, C)
x.shape
torch.Size([1, 64, 3])
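
        Finally, the round trip above can be verified programmatically (a sketch using the variables from the walkthrough):

# sanity check: window_partition followed by window_reverse is the identity
windows = window_partition(shifted_x, window_size)
recovered = window_reverse(windows, window_size, H, W)
assert torch.equal(recovered, shifted_x)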

3.9 Basic Layer

        A Basic Layer is one Stage of the Swin Transformer; it contains several Swin Transformer Blocks plus other layers (such as an optional Patch Merging downsampling layer).

        Note that the number of Swin Transformer Blocks in a Stage must be even, because layers containing Window Attention (W-MSA) must alternate with layers containing Shifted Window Attention (SW-MSA).

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
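
        A usage sketch (assuming the PatchMerging class from Section 3.3 is in scope), showing how one stage halves the token grid and doubles the channels:

import torch

layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
                   window_size=7, downsample=PatchMerging)
x = torch.rand(2, 56 * 56, 96)  # (B, H*W, C)
out = layer(x)
print(out.shape)  # torch.Size([2, 784, 192]): 28x28 tokens, 2x channels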
