打败VIT?Swin Transformer是怎么做到的

发布时间 2023-11-23 14:13:39作者: 水木清扬

https://mp.weixin.qq.com/s/C5ZDYKPdHazR2bR9I9KFjQ

在之前的文章中,我们介绍过VIT(Vision Transformer) ,它将NLP中常用的Transformer架构用于图像分类预训练任务上,取得了比肩ResNet的效果,成功证明了Transformer在NLP和CV上的大一统能力,进而成为后续许多工作的骨架模型。

今天我们要介绍的Swin Transformer(Shifted Window Transformer) 和VIT一样,也是一个采用Transformer架构做图像预训练的模型,但得益于它引入了CNN的归纳偏置,使得其在各类数据集上(尤其是分割、检测这样的预测密集型任务上)的效果都要优于VIT。但效果上有所提升的背后代价是:它对NLP和CV任务的统一能力下降了。这一点我们会在正文中细说。

本文在写作时,假设大家已经具备VIT相关知识。如果对VIT原理有不了解的朋友,可以参考这篇文章

【全文目录如下】

一、Swin Transformer的诞生背景
1.1 VIT的缺陷
1.2 Swin Transformer的改进

二、Swin Transformer的整体架构

三、Patch Merging

四、W-MSA与SW-MSA
4.1 W-MSA
4.2 朴素SW-MSA
4.3 环状SW-MSA
4.4 Masked Attention
4.5 复习Swin Transformer Block

五、窗口attention计算量分析
5.1 矩阵相乘的计算量
5.2 全局attention计算量
5.3 窗口attention计算量
5.4 全局MSA与窗口MSA

六、参考

【绘图与码字不易,点赞和在看,也是持续更新的动力❤️~】

一、Swin Transformer的诞生背景

1.1 VIT的缺陷

之前在介绍VIT原理时,我们提过VIT的一个重要意义是:证明Transformer对CV和NLP的大一统性。因此VIT几乎是将Transformer encoder部分完全搬运过来(也可理解为和Bert几乎一致),然后将图片分割成pacth的形式,每个patch即等同于NLP中的一个token向量,如此一来完全以训练语料的方式做图片分类的预训练,整个过程如下:

图片

而这样的极致统一,对于CV任务来说,有两个显著的缺陷:

  • 同一实体的尺寸变化问题。 例如一张街景图,其中有很多辆车,每辆车的大小各不相同,而我们的目标是要把这些大小不同的车都检测出来。在CNN架构中,我们会对输入数据做尺寸变化,在模型的每一层输出不同大小的特征图(例如常见的UNet架构),以此教会模型探查出不同大小的实体。但是VIT由于追求和NLP任务保持一致性,它每一层的patch数量和patch大小都保持不变。这就使得VIT在分割、检测这样预测密集型的任务上注定存在弱势。

  • 高分辨率图像引起的计算复杂度上升问题。在VIT中,尽管我们通过做patch的方式,减少了输入序列的长度。但VIT的attention计算仍是全局的,它的计算复杂度和图片尺寸大小呈平方关系。这也意味着VIT对高分辨率图像的处理是昂贵的。在后文中,我们会对这个复杂度的计算做更详细的说明。

1.2 Swin Transformer的改进

针对上述两点关于VIT的缺陷,Swin Transformer提出了如下改进:

图片

图中演示的是不同layer的情况。其中,红色的方框表示窗口(local window),灰色的方框表示patch。我们规定只在窗口的范围内计算attention。

我们先来看图(b),它表示VIT的架构,具有以下特点:

  • 在不同层中,保持patch数量和patch大小(16*16)不变。
  • 一整张图就是一个窗口,即VIT是在整张图的维度上计算attention的

我们再来看图(a),它表示Swin Transformer的架构,具有以下特点:

  • 在窗口内计算attention。注意到在每一层中,它将图片切分成若干个窗口(红框),各个窗口内含有固定数量的patch(灰框,实操中,patch_size一般设为4*4,每个window内的patch数量一般设为49,演示图中patch数量为16),各个窗口内独立计算attention(而不是基于全图)。这项优化对应解决上文说的“高分辨率图像引起计算复杂度上升”的问题。
  • 在层间做patch合并以获得多尺度特征。这个操作在论文中被称为patch merging。在这一步中,我们会通过某种方法,改变原始图片大小,然后继续按相同的方式划分窗口(此时每个窗口内依然有49个patch,每个patch对应的原始图片的尺寸还是4*4,只是窗口的数量变少了)。通过这种方式,我们模拟了CNN架构中产生不同尺寸特征图的过程,以此更好解决CV中的预测密集型任务(检测、分割等)。这项优化对应解决上文说的“同一实体尺寸变化”的问题

看到这里如果觉得迷惑,也没关系,我们会在后文对这几块做详细阐述。不难发现,Swin Transformer做的这两项优化,是有针对性的:都是针对CV任务做的专门优化。如果你再仔细感受一下,可能会发现它甚至借鉴了CNN架构对图像任务的归纳偏置(一种先验假设)的契合方法(参见这篇文章6.2.1) 。而在NLP任务中,其实并不关心多尺度特征这类问题。由于Swin Transformer这种对图像任务的针对性,导致它无法像VIT那样用统一的方式表示不同模态的特征,这也是文章作者留给后人填的一个坑

好,讲完了Swin Transformer在图像任务上做的针对性改进,接下来我们整体看一下Swin Transformer的架构和运作流程。在掌握整体流程之后,我们再分块对细节进行详细阐述。

二、Swin Transformer的整体架构和运作流程

图片

我们先来看(a),它描绘了Swin Transformer的整体架构和前向过程。

(1)首先,假设我们有一张224*224*3的输入图片,即这里H=224, W=224

(2)然后,划分patch,即图中所示的patch partition。实操中我们设置patch的大小为4*4,如果把一张图想象成一个长方体。那么划分后,将得到56*56个大小为4*4*3的小方块,如果我们把这些小方块拉平,就得到56*56*48的立方体,也就是图中在patch partition后所示的

但是,如果我们想让通道数变大点,比如从48变到96,这时我们就可以再增加一个linear embedding层来做线性转换。也就是图中stage1的linear embedding过后,最终的图片大小为56*56*96

在实际代码操作中,以上两个步骤可以用一个nn.Conv2d完成,我们在VIT原理篇中画过示例,这里我们再展示一次,将图中小patch的尺寸换成4*4*3,将卷积核变成96个4*4*3的卷积核,stride=patch_size=4,padding=0,就能得到一个56*56*96大小的方块了。那么拉平后,类比于NLP任务,我们就可得到56*56=3136个token,每个token的维度为96。

图片

(3)和VIT一样,秉持着一个patch就是一个token的原则,我们可以把数据输入到Swin Transformer中了。Swin Transformer具体的样子在(b)中已绘制出来,我们放在后文细谈。经过Swin Transformer block的处理,我们得到输出结果,其尺寸为56*56*96,对应着演示图中stage1的输出结果

(4)好,现在我们进入Stage2了,这时我们要做一个重要的操作:Patch Merging。我们在第一部分中已经说过它的大致作用。经过Patch Merging后,我们的输入数据尺寸变成28*28*192,再经过Swin Transformer block后的输出结果尺寸也一样,对应着图中。关于patch merging的细节,我们也在下文介绍。

(5)以此类推,stage3的输入输出尺寸为14*14*384,stage4的输入输出尺寸为7*7*768。拿着stage4的输出特征,我们就可以做后续的处理,用到下游的分割、检测、分类任务上了。

现在,我们回头再来看(b)

图(b)描绘了Swin Transformer block做的事情:输入数据先正常在窗口范围内做attention(W-MSA, W表示Window) ,然后经过LN和MLP一系列模块后,又做了一次移动窗口的attention(SW-MSA, SW表示Shifted Window) ,然后才是最终的输出结果。因此你可以发现演示图中的block都是偶数的。因为两个相连的block要依次完成W-MSA和SW-MSA的步骤。这个SW-MSA,就是Swin Transformer的灵魂所在了(连名字都是这么来的),大白话就是移动窗口下的attention计算。我们在后文来详细看它的实现细节。

好,目前为止,我们把Swin Transformer的整体流程过完了,我们回顾一下其中涉及到的几个细节问题:

  • patch分割(patch partition)怎么做?
  • patch融合(patch merging)怎么做?
  • Swin Transformer Block中的W-MSA和SW-MSA怎么做?

其中,第一个问题我们在本部分中已给出了详细的解答。所以我们集中在后面两个问题上。

三、Patch Merging

3.1 Patch Merging操作方法

图片

配合图例,我们来看patch merging的详细过程。我们以第二部分流程图中stage1的输出结果为例。

上图中HWC的图片对应到stage1的输出结果,尺寸就是56*56*96,为了绘图方便,C维度没有画出。

不难发现此时每一个小格子即表示一个patch,我们对patch进行编号,接下来我们就可以按照设定好的编号,把固定位置的patch取出来,拼成四份小图。

接着,我们在C维度concat这四份小图,生成一份尺寸为H * W * 4C的数据。以上两步骤的代码操作如下:

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)

最后,我们引入线性层 nn.linear(4C, 2C) ,减少通道数,使得最终输出的数据维度为H/2 * W/2 * 2C

完整的代码流程如下,大家注意看注释,注意看数据的尺寸变化:

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops

3.2 Patch Merging注意事项

关于patch merging,这里再强调两点。

(1)做merge的本质上是patch,而不是一个pixel向量。

你可能想问,在3.1的图中,为什么那些红黄蓝绿的小方框,表示的是一个patch,而不是图中的一个像素点向量呢?

我们回到第二部分对patch partition的介绍图中,可以发现,partition后图中每个位置上的向量(维度为1*1*96),其实是由原始图片中一块4*4*3的patch经过Conv2d处理而来。

所以,这些小方框,在处理过后的图中可理解成一个1*1*96维度的pixel向量,但是本质上,它对应着原图中的一个4*4*3的patch。 经过patch merging的concat和linear操作后,它可能不严格对应着原图中的某块patch,而是若干块patch在通道维度拼接而成,但抽象意义上,它仍表示4*4的patch范围。

那么经过patch merging后得到的 H/2 * W/2 * 2C 的特征图,还要再进行patch partition吗?不需要了这一点在第二部分整体架构流程图中也有体现。因为正如上面所说,此时特征图中每个1 * 1 * 2C的向量,其抽象表示范围依然对应着最原始输入图片中4*4的patch。所以这里,我们只需正常再做window划分(下文会说),而不用再去做patch partition了。

所以我们又说,在Swin Transformer中,patch是基本计算单位,这一点不仅表现在它类比于token参与Transformer block计算,也表现在它的抽象意义上。现在回头看1.2中的(a)图,是不是理解更深刻了些?

(2)patch merging和CNN中的池化操作相似

不难发现,其实patch merging就非常像CNN中的池化操作,只是这里将池化操作常用的min/max/average替换成了按固定位置取patch。

四、W-MSA与SW-MSA

好,接下来,我们就来看Swin Transformer是如何在窗口范围内做attention的。

再回顾一下标题中的两个英文符号的含义:

  • W-MSA:Window-Multi Head Attention,表示在窗口范围内做attention
  • SW-MSA:Shifted Window-Multi Head Attention,表示在移动窗口的范围内做attention

我们先来看W-MSA和SW-MSA是什么。再来看相比于VIT的全局Attention,这两个操作带来的好处是什么。

4.1 W-MSA

4.1.1 W-MSA的流程

图片

W-MSA的流程如左图所示。在VIT中,左图中灰色的patch类比于一个个token,我们对全部的token做attention计算。在Swin Transformer中,我们先划分窗口(红框),每个窗口内默认有49个patch。我们只在窗口的范围内做attention

4.1.2 W-MSA的缺点

但是在窗口内计算attention会有一个明显的缺点:patch的感受野变小了。

什么意思呢?就是原来在VIT里,一个patch可以和全局的patch做attention,它能看到整张图的情况。而现在却只能看到窗口内的视野了。为了解决这个缺陷,我们引入了SW-MSA,即在移动窗口的范围内做attention。

4.2 朴素SW-MSA

4.1.1的右图刻画了一种朴素粗暴的移动窗口的办法,你可以理解成将原始的窗口向右下角移动2个patch,就可得到图中的右图。

这样移动的好处是,对一些patch来说,两次计算attention时,它们分属不同的窗口,这样可以让它们和不同的patch交互,扩大感受野。

但这样做的缺点也是明显的:原来我就4个窗口,每个窗口内固定都是49个patch。经过你这么移动,我有9个窗口,每个窗口内的patch数量还不一致,这不是加重了模型的计算复杂度,以及我写代码的难度吗?

所以,作者在这里创新性地提出了一种环形移动窗口法(cyclic shift) 。

4.3 环状SW-MSA

图片

这个“环状”听起来玄学,其实理解起来不复杂。如上图,最开始是9个不均匀的patch。接着,为了让它均匀,我们将原来的A、B、C部分做一个旋转拼接,就又可以恢复成一个有4个window,且patch数量均匀的数据了。

那么恢复过后呢?还是直接在窗口内计算attention吗?当然不是,这样做会有一个显著的问题:假设这张图片绘制的是蓝天和大地,其中C中包含了蓝天的信息。那么此时在左下角那个窗口中,如果你做attention,势必就会让大地和天空一起做attention,那么这时模型大概率可能会学到天空在大地的下面,这当然是不合理的。

那怎么解决这个问题呢?作者在这里创新性地提出了masked MSA(掩码MSA) ,也就是上图大括号后面的部分(这个部分上图画得比较粗糙,我们在后文4.4中会给出更详细的说明)

做完masked MSA后,我们再把A、B、C还原回原来的位置,就可以了。

4.4 Masked Attention

图片

以上图例来自Swin Transformer github issue,由一作绘制得出,它非常直观地展现出了掩码attention的计算过程。我们可以将其和4.3中做完环状移动窗口后的图对比来看。

先看Window0,在这个窗口中,patch没有任何拼接处理(左图),因此它不需要做任何掩码操作,可以正常做attention。右图中window0的颜色为全黑,意味着无需任何掩码操作。

再看Window2,在这个窗口中,标号为3的部分来自同一批patch,标号为6的部分来自环状移动过来的另一批patch。因此3和6之间是不能做attention的,这意味着如果attention score是由3的一块patch和6的一块patch计算得来,我们就需要把这个score设置成-100,这样一来在后续做softmax时,对应位置的结果就可以小到忽略不计,以此来取得遮掩(masked)效果。右图中黄window2的黄色部分表示不需要做mask的分数,黑色部分表示需要做mask的分数。

如果你还有疑惑,那么再来看一眼下面这张图,应该可以帮你解答疑惑:

图片

同理我们可以推知Window1和Window3的结果。

4.5 复习 Swin Transformer Block

图片

现在我们再回头复习一遍Swin Transformer Block的运作流程。

(1)首先,来了一排可以被当作token的patch,我们对这些patch划分好了Window。

(2)先做W-MSA,在窗口范围内计算attention

(3)再做SW-MSA,以环状的方式移动窗口,目的是为了让窗口数量和W-MSA保持一致。移动窗口的目的是为了扩大每个patch的感受野。

(4)因为采用了环状移动的方式,导致patch间的相对位置变动了,因此我们引入mask attention,对被打乱位置的patch间的attention score做掩码操作。

(5)因为每次都必须进行W-MSA + SW-MSA的操作,因此block的数量必须为偶数。

好!到此为止,我们就把整个Swin Transformer的核心技术说完了。但别忘了,还有个重要问题没解决:为什么要从全局attention变成窗口attention呢?

五、窗口attention计算量分析

我们直接对上面这个问题抛出回答:相比于全局attention,窗口attention能有效降低计算量。

那接下来我们肯定就要问,那到底能降多少呢?这点在论文中只给出了结论,而没有给出详细的推导过程。接下来,我们就来看看计算推导细节。

5.1 矩阵相乘的计算量

我们一般用FLOPs(floating point operations,浮点运算次数) 来表示运算量的大小。对于“两矩阵相乘”这个操作而言,其运算量 = 乘法运算的次数 + 加法运算的次数

来看一个具体例子:

图片

两矩阵相乘,为了获取图中深橘色部分的元素,我们一共需要进行n次乘法运算和n-1次加法运算。

那么现在结果矩阵中,一共有m*p个橘色方块,则意味着我们需要进行:m*p*(n + n - 1)次浮点计算。

再进一步,假设此时在蓝色和绿色的矩阵外,我们还有一个bias矩阵,意味着计算单个橘色方块时我们需要进行n次乘法和n-1+1次加法运算,那么此时总计算量为:m*p*(n+n) = 2mnp。当然,即使不加这个bias,我们也可以把-1项给忽略,得到相同的结果。

所以这里我们总结下,假设有两个矩阵A和B,它们的维度分别为(m, n)和(n, p),则这两矩阵相乘的运算量为2mnp

一般在矩阵运算中,**乘法运算的时间要高于加法运算的时间,因此有时在统计运算量时,我们只考虑乘法运算的次数,则此时两矩阵相乘的运算量可近似为mnp,而Swin Transformer的论文中采用的就是这种近似。

5.2 全局attention计算量

有了5.1的结论,我们就可以来算attention的计算量了,毕竟attention说白了也是一系列矩阵操作。我们先来看全局attention的计算量。

我们忽略batch维度(认为b = 1),则此时输入数据的维度为(hw, C),其中hw(即h * w)表示patch数量,也就等同于token序列的长度,C表示token embedding维度。

(1)计算Attention的第一步,我们需要让输入数据和这三个矩阵相乘,得到Q、K、V三个结果矩阵。而这三个矩阵的维度都是(C, C),那么按照5.1中给出的公式:(hw, C) * (C, C) = (hw, C),此时的运算量为

(2) 接下来,我们需要将Q和K相乘,得到attention score。则:(hw, C) * (C, hw) = (hw, hw),此时的运算量为

(3)接下来,我们要将attention score和V相乘,则:(hw, hw) * (hw, C) = (hw, C),此时的运算量为

(4)最后,我们要将结果过一层线性映射层,则:(hw, C) * (C, C) = (hw, C),此时运算量为

以上四步的运算量加总,则全局attention的总运算量为

5.3 窗口attention计算量

前面我们说过,一个窗口内固定有49个patch,也就是窗口可以拆成7*7的矩阵。

我们先来算一个窗口内做attention的计算量,然后再乘上窗口总数,就可以得到总计算量了。

我们设窗口的H和W的值为:H=M,W=M,在默认操作中M=7,则代入5.1的公式可求得单个窗口attention的计算量为 

而此时,我们的窗口总数为: 

则在窗口下,attention的计算总量为: 

5.4 全局MSA VS 窗口MSA

经过5.2和5.3这么一算,两者的计算量就出来了:

图片

你看,在全局attention中,计算量和图像尺寸大小h*w是平方关系,但是变成窗口计算之后,就变成线性关系了。 这大大降低了运算复杂度,这也使得Swin Transformer能够被用于处理高分辨率的图像。

如果你对复杂度还没有那么直观的感受,你可以将M=7,h=w=56带入上面的式子算一算,就能感受到窗口MSA的效果了。

总结来说,通过“窗口MSA + 移动窗口”这种方式,Swin Transformer既做到了节省计算量,又做到了扩充每个patch的感受野。而通过“patch merging”这种方式,它模仿了CNN架构中输出不同尺寸特征图的过程,使得模型能够好处理预测密集型任务(检测、分割)。

关于实验部分,这里就不展开细说了,感兴趣的朋友可以翻阅原始论文。

六、参考

1、https://arxiv.org/pdf/2103.14030.pdf

2、https://github.com/microsoft/Swin-Transformer

3、https://www.bilibili.com/video/BV13L4y1475U/