[论文速览] Hard Patches Mining for Masked Image Modeling

发布时间 2023-07-12 10:58:36作者: NoNoe

Pre

title: Hard Patches Mining for Masked Image Modeling
accepted: CVPR 2023
paper: https://arxiv.org/abs/2304.05919
code: https://github.com/Haochen-Wang409/HPM
ref: CVPR 2023 | 挖掘困难样本的 MIM 框架: Hard Patches Mining for Masked Image Modeling

关键词:MIM, self-supervised, 自监督掩码学习
阅读理由:CVPR, 方法似乎简单有效

Idea

先让模型根据重建损失的大小自己生成困难mask,再像传统方法那样训练模型去预测masked patches

Motivation&Solution

m: NLP的词已经高度语义化,而CV里图片则存在空间信息的冗余,各种自监督掩码学习方法的性能强烈依赖于手工定义的掩码策略

s: 提出一种新的困难样本挖掘策略,让模型自主地掩码困难样本

Background

对比学习能学到视角不变的特征。
SimMIM发现大mask核在不同mask比例下更加鲁棒

Method(Model)

Overview

HPM 包含一个学生模型和一个教师模型,它们共享网络结构,包含 encoder \(f_{\theta}\),图像重建 decoder \(d_{\phi}\),损失预测 decoder \(d_{\psi}\)。教师模型的参数是由学生模型指数平滑更新而来的。

每次迭代,一张图片首先打成 patch,并经过教师模型,得到每个 patch 预测的重建损失。进而,基于该预测,产生当前的二元 mask \(\mathbf{M}\) 用于 MIM 任务,0表示遮掩。损失函数包含两项:

\[\mathcal{L} = \mathcal{L}_{\mathrm{rec}} + \mathcal{L}_{\mathrm{pred}}, \tag{2} \]

其中 \(\mathcal{L}_{\mathrm{rec}}\) 表示重建损失,是标准的 MIM 损失;而 \(\mathcal{L}_{\mathrm{pred}}\) 表示的是重建损失预测损失。

\[\mathcal{L}_{\mathrm{rec}} = \mathcal{M} \left( d_{\phi_s}(f_{\theta_s}(\mathbf{x} \odot \mathbf{M})), \mathcal{T}(\mathbf{x} \odot (1 - \mathbf{M})) \right), \tag{3} \]

其中 \(\mathbf{M} \in \{0,1\}^N\) 表示产生的二元 mask, \(\odot\) 表示 element-wise dot product,因此 \(\mathbf{x} \odot \mathbf{M}\) 表示可见的 patches。 \(\mathcal{T}(\cdot)\) 表示 target 的 transformation,例如 MAE 中就是一个恒等映射,而 BEiT 中则是将图像转化为离散的 token。 \(\mathcal{M}(\cdot, \cdot)\) 表示某种度量,如 MAE 中用的 $\ell_2 $ 距离,SimMIM 中用的 smooth \(\ell_1\) 距离。

重建损失L_pred

Absolute loss.
一种最直观的方法就是直接最小化真实重建 loss \(\mathcal{L}_{\mathrm{rec}}\) 和预测的重建 loss 之间的 MSE,即

\[\mathcal{L}_{\mathrm{pred}} = \left( d_{\psi_s}(f_{\theta_s} (\mathbf{x} \odot \mathbf{M})) - \mathcal{L}_{\mathrm{rec}} \right)^2 \odot (1 - \mathbf{M}), \tag{4} \]

其中 \(d_{\psi_s}\) 表示的是学生模型的 loss predictor,而 $\mathcal{L}_{\mathrm{rec}} $ 需截断梯度。

然而,这里的目标是确定图像中的困难样本,需要 patch 之间重建损失的相对大小,因此 MSE 并不是最合适的选择,因为 \(\mathcal{L}_{\mathrm{rec}}\) 将随着训练的进行而减少, \(\mathcal{L}_{\mathrm{pred}}\) 也会变小,但这不代表它学到了东西,为此作者提出了一种基于二元交叉熵的相对损失。

Relative loss.
给定一张含有N个 patch 的图片,其真实的重建损失为 \(\mathcal{L}_{\mathrm{rec}} \in \mathbb{R}^N\) ,目的是预测这N个 patch 之间重建损失的相对大小,即 \(\texttt{argsort}(\mathcal{L}_{\mathrm{rec}})\) ,但 \(\texttt{argsort}(\cdot)\) 不可导,因此作者将其转换为dense relation comparison问题,预测patch两两之间的大小关系:

\[\begin{aligned} \mathcal{L}_{\mathrm{pred}} = &-\sum_{i=1}^N \sum_{j=1 \atop j\neq i}^N \mathbb{I}^{+}_{ij} \log \left( \sigma(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j) \right) \\ &-\sum_{i=1}^N \sum_{j=1 \atop j\neq i}^N \mathbb{I}^{-}_{ij} \log \left( 1 - \sigma(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j) \right), \end{aligned} \tag{5} \]

其中 \(\hat{\mathcal{L}}^s = d_{\psi_s}(f_{\theta_s}(\mathbf{x} \odot \mathbf{M})) \in \mathbb{R}^N\) 是学生模型输出的损失预测值,而 \(i, j=1,2,\dots,N\) 是 patch indexes。 \(\sigma(z) = e^z / (e^z + 1)\)\(\texttt{sigmoid}\) 函数。 \(\mathbb{I}^{+}_{ij}\)\(\mathbb{I}^{-}_{ij}\) 是两个指示函数,表示 patch i 和 patch j 的真实重建损失大小,定义如下:

\[\mathbb{I}^{+}_{ij} = \left\{ \begin{aligned} &1, &&\mathcal{L}_{\mathrm{rec}}(i) > \mathcal{L}_{\mathrm{rec}}(j) \mathrm{\ and\ } \mathbf{M}_i=\mathbf{M}_j=0, \\ &0, &&\mathrm{otherwise}, \end{aligned} \right. \\ \mathbb{I}^{-}_{ij} = \left\{ \begin{aligned} &1, &&\mathcal{L}_{\mathrm{rec}}(i) < \mathcal{L}_{\mathrm{rec}}(j) \mathrm{\ and\ } \mathbf{M}_i=\mathbf{M}_j=0, \\ &0, &&\mathrm{otherwise}, \end{aligned} \right. \]

其中 \(\mathbf{M}_i=\mathbf{M}_j=0\) 表示对应的patch i, j应当被mask。
对于 \(\mathbb{I}^{+}_{ij}\) ,+表示此时i的损失应当大于j,值为1表示损失需要计算,此时 \(\hat{\mathcal{L}}^s_i - \hat{\mathcal{L}}^s_j\) 以0为界,越大表示预测出来i的损失比j越大,符合target的 \(\mathbf{M}_i=\mathbf{M}_j=0\) ,则损失越小。

以前半部分损失为例,本质上损失通过\(\hat{\mathcal{L}}^s\)之间的关系定义,当 \(\mathbb{I}^{+}_{ij}\) 有值1表示i的真实损失应当大于j,此时如果 \(\hat{\mathcal{L}}^s_i > \hat{\mathcal{L}}^s_j\) 且差值越大,则经过sigmoid函数后值越是比0.5大且接近1,再过log最终会是一个负值,配合最前面的负号,会是一个正值且向0靠近。

Easy-to-Hard Mask Generation

一个自然的想法就是每次迭代过程中,先基于教师模型计算 \(\texttt{argsort}[d_{\psi_t}(f_{\theta_t} (\mathbf{x}))]\) ,然后 mask 掉 top-75% 的 patch。然而,在早期训练阶段,学到的大多是纹理,重建损失与判别性(能决定图像类别的前景主体?)还没有建立起相应的关系。为此作者提出了一种由易到难的掩码生成方式,提供一些合理的提示,引导模型一步一步地重建掩码的硬块。

具体来说,假设 mask ratio 为 \(\gamma\) ,则在 t 次迭代只 mask 掉最大的 \(\alpha_t\gamma N\) 个 patch,剩余 \((1-\alpha_t)\gamma N\) 个需要 mask 的 patch 则随机产生,其中 \(\alpha_t = \alpha_0 + \frac{t}{T}(\alpha_T - \alpha_0).\) 。也就是随着训练的推进,逐渐降低随机mask的比例。

算法1 Pytorch风格的HPM伪代码

Experiment

Training Detail

以 ViT-B/16 为 backbone,预训练 200 epochs

Dataset

ImageNet-1K

Results

Ablation Study

表1 在不同重建目标上的消融研究。第一行是MAE baseline,以自回归地生成RGB像素的形式训练,后面三个以教师模型的特征为target(知识蒸馏),架构都一样(ViT-B/16)。

表5 下游任务上的消融 从表1取了两个预训练模型做下游任务

重建目标的消融。可以看到,不管以什么为重建目标,加入作为\(\mathcal{L}_{\mathrm{pred}}\)额外的损失,并基于此进一步产生更难的mask都能获得性能提升。仅仅引入 \(\mathcal{L}_{\mathrm{pred}}\) 也能够带来性能提升,表明挖掘困难样本的能力本身就能够促使学到更好的特征表示这一点不仅在分类任务上得到体现,下游任务(检测分割)也有相应的体现。

表2 不同mask策略的消融,较大的$\alpha_T$表明代理任务(pretext task)更加困难,但该策略的随机性就会下降。

Esay-to-hard: \(\alpha_t = \alpha_0 + \frac{t}{T}(\alpha_T - \alpha_0).\)

难度大的代理任务确实能够带来性能提升,但保留一定的随机性也是同样必要的。直接掩盖那些预测损失最高的 patch 虽然带来了最难的问题,但图像可判别部分几乎被被掩盖了,意味着 可见的patch 几乎都是背景(见图2)。在没有任何提示的情况下,强迫模型只根据这些背景来重建前景是没有意义的。

表3 不同mask策略的消融。验证在预测的重建损失上使用 argmax(·) 的有效性,argmin(·) 表示每次mask简单的patch。 $\alpha_0 > \alpha_T$ 表示困难patch的使用逐渐增加,是hard-to-easy的方式

进一步地,探究困难的代理任务对于 MIM 是否有帮助。其中, argmin 表示这个任务甚至简单于 random masking,跟 hard-to-easy 一样都会导致性能退化。

表4 预测损失形式上的消融,对比了不加和加上公式4、5的两种损失

MSE 相较于 baseline 能够有提升,但 BCE 是一个更好的选择。

表6 在ImageNet-1K上对比SOTA,一个横线的表示作者实现的版本,两个横线的表示从其他论文抄过来的数据 eff. ep. 表示 Effective pre-training epoch

表6将HPM跟其他方法对比,用于对比的分三类:对比学习、像素回归的MIM、特征蒸馏的MIM

表7 在ADE20k上对比SOTA,两个横线的表示从其他论文抄过来的数据

图4 COCO验证集的可视化,该数据集训练时没见过,右边是预测的重建损失

Conclusion

HPM作为一种即插即用的模块可以无缝接入现有的框架中,性能都能得到提升。MIM的常见问题是线性探测和k-NN分类性能不如对比学习方法。此外HPM由于有个额外的解码器,会有更大的计算开销,比起MAE baseline,在训练ViT-L时会花费1.1倍的时间。将来的方向可以是设计一种更好的损失预测任务,不借助额外辅助解码器。

Critique

思想其实挺简单的,但做了非常多的实验来证明,附录还有一大堆。感觉这类还是看了大概就好,除非要去用它的代码。

Unknown