【论文阅读】CrossViT:Cross-Attention Multi-Scale Vision Transformer for Image Classification

发布时间 2023-07-10 09:37:44作者: 睡晚不猿序程

?前言

  • ?博客主页:?睡晚不猿序程?
  • ⌚首发时间:23.7.10
  • ⏰最近更新时间:23.7.10
  • ?本文由 睡晚不猿序程 原创
  • ?作者是蒻蒟本蒟,如果文章里有任何错误或者表述不清,请 tt 我,万分感谢!orz


1. 内容简介

论文标题:CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

发布于:ICCV 2021

自己认为的关键词:多尺度、ViT

是否开源?https://github.com/IBM/CrossViT


2. 论文速览

论文动机

  1. CNN 和 ViT 混合模型取得了不错的成果,但是相比起纯 ViT 计算量较大

    可能是因为在 Token 上用了普通卷积了

  2. ViT 需要大量的数据集进行训练,否则效果不佳

  3. ViT 无法学习多尺度信息

本文工作

  1. 一个双分支结构的 ViT,可以学习多尺度信息
  2. 一个 Token 混合模块,可以让不同尺度的信息之间进行交互

完成效果

  1. 在参数量和运算量都更小的情况下,ImageNet1K 分类正确率比 DeiT 多了 2%

3. 图片、表格浏览

图一

image-20230614213720704

效果图,可以看到和 ViT 的参数差距非常大

图二

image-20230614214036220

模型架构图

可以看到由两个分支组成,但是这个中间的交叉注意力的计算没有细说,最后两个分支的输出分别使用MLP头处理然后相加

图三

image-20230614214053680

四种不同的多尺度混合方式:

  1. 不考虑 CLS 的特点,直接当成 token 一起处理
  2. CLS 可以当成是全局特征的表示,二者进行融合即可
  3. 各自进行混合
  4. 这个就是所谓的 cross-attention,CLS 调换位置,然后进行融合

脑洞有点大了

图四

image-20230614214112358

cross-attention 的实例,以大 patch 为例,小 patch 同样这样做即可

使用了一个线性映射来讲 CLS 变换为对应的维度,然后再变换回来就可以了

所以两个分支就是通过 CLS 进行信息交互的


4. 引言浏览

Vision Transformer 需要使用大量的数据集进行训练,DeiT 证实了数据增强技术和模型正则化可以让 ViT 的训练效果更好

DeiT:使用了知识蒸馏的策略,仅使用 ImageNet-1K 的数据集就可以达到 SOTA

本文工作:研究如何让 ViT 学习到多尺度的特征表示,探索适合于 Transformer 的特征融合机制

模型结构:双分支 Transformer,两个分支的 Patch 大小不同,并且二者会进行交互

本文贡献

  1. 双分支 ViT,用于提取多尺度特征,并且提出了一个简单且高效的特征融合机制
  2. SOTA

自由阅读

5. 方法

ViT:将图像划分为确定大小的 patch,添加一个 CLS Token 并且经过位置嵌入之后送给 Transformer 做处理,最后使用 CLS token 作为分类依据

Feed Forwad Network 由两层全连接网络组成,使用 GELU 以及 LayerNorm

和 CNN 的区别:CNN 最后的表示一般是通过平均池化来表示,但是 Transformer 使用的是 CLS,这个 CLS 在 Transformer block 中会与其他的 patch 进行交互

因为有了 CLS 的存在,所以作者提出了一种基于 CLS 的双路多尺度 ViT 模型

Multi-Scale ViT

patch 的划分大小影响着 Transformer 的性能以及计算开销,作者提出了双分支 ViT 结构

Cross ViT

  1. L-Branch:patch 大小更大,包含更多的 Encoder block,并且有更深的 embedding dimensions
  2. S-Branch:patch 更小,拥有更少的 Encoder block 以及更少的 embedding dimension

这样设计应该是想要平衡性能以及计算开销

两个分支将会混合 L 次,并且最后两个输出都会用于预测。并且两个分支都使用了自学习的位置嵌入

Multi-Scale Feature Fusion

为了让两个分支的数据可以进行融合交互,提出了多种方案

  1. All-Attention: 直接两个分支拿过来一起计算注意力【计算开销大】
  2. Class Token Fusion:只是用 Class Token 进行混合(直接使用加法)
  3. Pairwise Fusion:基于 patch 所属的空间位置进行混合——这里会先进行插值来对其空间大小,然后再进行混合(CLS 相加,patch 也相加)
  4. Cross-Attention Fusion:利用 CLS 来交互信息。

Cross-Attention Fusion

将 CLS 当成是一个分支的抽象信息,那么只需要交换两个分支的 CLS,然后送入 Transformer 中,两个分支的信息就可以进行交互了,这样有助于在另一个分支中引入不同尺度的信息

image-20230614214151778

上图为实例,就是使用一个 Transformer block 来生成新的 CLS。例子是 Large Branch 的。

首先 CLS 先经过映射对齐维度,然后 CLS 作为 Q,另一个分支的 patch 加上本分支的 CLS 作为 K,V,经过自注意力运算,最后映射回到原维度,生成了最后的 CLS 结果

这个结果是通过与另一个分支交互得到的

因为自注意力运算只计算一个向量,所以计算复杂度变小

并且在计算 Cross-Attention 时,没有使用 FFN,而是通过残差连接加映射直接得到结果

6. 实验

6.1 实验结果

backbone:DeiT

dataset:ImageNet

使用了多种数据增强技术:rand augmentation,mixup,cutmix,random erasing

image-20230614214208559

6.2 消融实验、

不同分支信息融合技术

image-20230614214232070

可以看到 CA 得到了最好的效果

体现出来的其他特点也挺有意思的,单分支的分类准确率 L 远远大于 S,应该是因为分类任务中,小尺度信息利用不多。

引入 CA 之后,L 分支的准确率变低了,S 分支的准确率变高了,这个应该是信息交互带来的结果,但是最后的输出准确率提升——这样可以得出结论:他们可以学习到图像中的不同特征

patch size

image-20230614214302045

出乎意料的是,(12,16) 得到的效果比(8,16)得到的效果要好——作者表示,这可能是两个分支之间的细粒度差异很大,使得特征的平滑学习很困难,所以使得(12,16)效果更好

S 分支的通道维度以及深度

同样是上图,B,C 行中分别测试了 S 分支在不同参数下的结果,可以看到改变不大,作者认为应是因为分类主要使用的是 L 分支的信息,S 分支的信息被作为额外信息来使用

CA 的深度以及 Multi-Scale Transformer encoder

更频繁的使用 CA 没有带来更好的结果,而且让 CA 更深也没有效果,因为另一个分支的 CLS 不会因此改变

CLS 的重要性

作者把 CLS 去除,用所有 token 的平均来作为 CLS,最后比 CrossViT 差了 1%,得出结果:CLS 是特征的摘要信息

6. 总结、预告

6.1 总结

  1. 尝试构建了一个多尺度 Transformer,由两个分支组成
  2. 提出了两个分支的交互形式,使用 CLS 进行 Cross Attention 计算,计算开销小效果好

应该是从 CNN 的多尺度结构中进行借鉴得到的模型,将图像根据不同的 patch 大小进行划分,得到一个双分支结构,并且两个分支通过 CLS 进行交互,但是消融实验也有需要注意的点:

  1. patch 大小相差过大导致效果变差:作者说是两个分支的尺度差异太大,并且由于是通过 CA 进行学习的,所以导致模型学习困难
  2. 主要分类信息由 L 分支提供,而 S 分支主要提供额外的辅助信息
  3. 可以使用 CNN 作为嵌入方式,不一定使用线性嵌入