Swin Transformer: Hierarchical Vision Transformer using Shifted Windows详解

发布时间 2023-12-17 16:31:57作者: InunI

初读印象

comment:: (Swin-transformer)代码:https://github. com/microsoft/Swin-Transformer

动机

将在nlp上主流的Transformer转换到cv上。存在以下困难:

  • nlp中单词标记是一个基本单元,但是视觉元素在尺度上有很大的变化。
  • 图像分辨率高,自注意力操作计算复杂度是图像大小的二次方
    提出了一种通用的Transformer主干,称为Swin Transformer,它构造了分层的特征映射,并且具有与图像大小相关的线性计算复杂度。Swin Transformer通过从小尺寸的补丁,并在更深的Transformer层中逐渐合并相邻的补丁来构造分层表示。通过这些分层特征图,Swin Transformer模型可以方便地利用先进的密集预测技术,如特征金字塔网络( FPN )或U - Net。线性计算复杂度是通过在划分图像(用红色勾画)的非重叠窗口内局部计算自注意力来实现的。每个窗口中的补丁数量是固定的,因此复杂度与图像大小成线性关系。

方法

总体架构(以简易版本Swin-T为例)

Pasted image 20221107111202

  1. Patch Parition:将整张图片切片,此处切成\(4\times 4\)的patch,每个patch看成一个token,拼接成\(4\times 4\times3=48\)的特征向量。(为了产生层次化的表示)
  2. Stage1:将token经过一个线性嵌入层变为长度为C的特征向量,整张图变成\(\frac{H}{4}\times\frac{W}{4}\times C\)的大小,经过两个Swin Transformer Block(不改变特征图大小)。
  3. Stage2: 一个Patch Merging再加上两个Swin Transformer Block。
    • Patch Merging:为了产生层次化的表示,随着网络的加深,通过补丁合并层来减少token的数量。合并相邻的\(2\times 2\)个patch,拼接得到长度为4C的token特征向量,并经过一个线性层,token的长度变为2C,整张图变为\(\frac{H}{8}\times\frac{W}{8}\times 2C\)
  4. stage3:同stage2,整张图变为\(\frac{H}{16}\times\frac{W}{16}\times 4C\)
  5. stage4:同stage2,整张图变为\(\frac{H}{32}\times\frac{W}{32}\times 8C\)

Swin Trasformer Block

将Transformer中的多头注意力模块换成了窗口多头注意力(WIndow-MSA)和移动窗口多头注意力(Shifted Window-MSA)

Pasted image 20221107114051只在窗口中做自注意力能够减少计算复杂度,但是这将减少跨窗口的连接,减少模型建模能力,因此在两个利纳许的SwinTransformer Blocks中交替使用两种划分的移动窗口。

Pasted image 20221107145113

  • 第l层:使用规则的大小为\(M\times M\)的窗口,在每个窗口内做自注意力。
  • 第l+1层:将l层的规则窗口移位\((\frac{M}{2},\frac{M}{2})\)(每个窗口是被一个井字切割开的),在每个窗口内做自注意力,以达到穿越规则边界做注意力的目的。

相对位置编码

在每个窗口中使用相对位置编码\(B \in R^{M^2 \times M^2}\)计算每个head的注意力:
Pasted image 20221107153622其中\(M^2\)是每个窗口的token数量。由于一个窗口内在同一轴上的两个位置之间的距离介于\([-M+1,M-1]\)之间,因此设置一个偏执矩阵\(\hat{B}\in R^{(2M-1)\times (2M-1)}\),\(B\)的值从\(\hat{B}\)中选。

模型变种

Pasted image 20221107160449

输入输出

输入为一张图片,输出根据任务的不同而不同。

表现

在语义分割上 的表现:

Pasted image 20221107163811消融实验:Pasted image 20221107164240

启发

  1. 通过改变窗口的位置,将局部的注意力扩充到全局;
  2. 在局部块内使用了相对位置信息。