[论文阅读] Momentum contrast for unsupervised visual representation learning

发布时间 2023-08-31 17:03:18作者: silly丶

Momentum contrast for unsupervised visual representation learning

Introduction

我们提出了动量对比(MoCo)作为一种构建具有对比损失的无监督学习的大型一致字典的方法(图1)。

我们将字典维护为数据样本队列:当前小批的编码表示被排队,最旧的被踢出queue。队列将字典大小与mini-batch大小解耦,使字典能够很大。此外,由于字典键来自于前面的几个mini-batch,因此提出了一种缓慢更新的键编码器以保持一致性,具体实现为查询编码器的基于动量的移动平均。

image-20230831095206795

Figure 1

图1所示。动量对比(MoCo)通过使用对比损失将编码查询q与编码键的字典进行匹配来训练视觉表示编码器。字典键 \(\{k_0,k_1,k_2\}\) 由一组数据样本动态定义。字典被构建为一个队列,当前的mini-batch进入队列,最旧的mini-batch退出队列,将其与mini-batch大小解耦。键由一个缓慢推进的编码器编码,由查询编码器的动量更新驱动。这种方法为学习视觉表示提供了一个大而一致的字典。

Method

Contrastive Learning as Dictionary Look-up

考虑一个编码查询 \(q\) 和一组编码样本 \(\{k_0,k_1,k_2,\dots\}\) 作为字典的键。假设字典中有一个键(表示为 \(k_+\)),与 \(q\) 匹配。对比损失是一个函数,当 \(q\) 与其正键 \(k_+\) 相似并且与所有其他键(对于\(q\) 被视为负键)不相似时,其值较低。通过点积来度量相似性,本文考虑了一种对比损失函数的形式,称为 InfoNCE:

\[\mathcal{L}_1=-\log\frac{exp(q\cdot k_+/ \tau)}{\sum_{i=0}^Kexp(q\cdot k_i / \tau)} \tag{1} \]

其中τ是温度超参数。求和是在一个正样本和 \(K\) 个负样本上进行的。直观地说,这个损失是基于 (K+1) 类 softmax 分类器的对数损失,该分类器试图将 \(q\) 分类为 \(k+\)

对比损失用作训练提取queries和keys表示的编码器网络 的无监督目标函数。一般来说,查询表示为 \(q=f_q(x^q)\),其中 \(f_q\) 是一个编码器网络,\(x^q\) 是一个查询样本(同样,\(k=f_k(x^k)\))。它们的实例化取决于特定的预文本任务。输入 \(x^q\)\(x^k\) 可以是图像、patches 或由一组patches组成的上下文。网络 \(f_q\)\(f_k\) 可以是相同的、部分共享的或不同的。

Momentum Contrast

我们的假设是,好的特征可以通过一个包含丰富负样本集的大字典来学习,而且字典键的编码器尽管在进化中仍尽可能保持一致。

Dictionary as a queue

我们方法的核心是将字典维护为数据样本的队列。这使我们能够重复使用来自前几个最近的mini-batch的编码键。引入队列的概念将字典大小与mini-batch大小解耦。我们的字典大小可以比典型的mini-batch大小大得多,并且可以作为超参数进行灵活独立地设置。字典中的样本逐渐被替换。当前的mini-batch被入队到字典中,而队列中最旧的mini-batch被移除。

Momentum update

使用队列可以使字典变得很大,但它也使得通过反向传播来更新键编码器变得难以处理(梯度应该传播到队列中的所有样本)。我们提出了一个动量更新来解决这个问题。

形式上,将 \(f_k\) 的参数记为 \(θ_k\)\(f_q\) 的参数记为 \(θ_q\),我们将\(θ_k\)更新为

\[\theta_k\leftarrow m\theta_k+(1-m)\theta_q \tag{2} \]

这里的 \(m \in [0,1)\) 是一个动量系数。只有参数 \(θ_q\) 通过反向传播进行更新。方程(2)中的动量更新使得 \(θ_k\) 的演变比 \(θ_q\) 更加平稳。因此,尽管队列中的键是由不同的编码器编码的(在不同的mini-batch中),但这些编码器之间的差异可以变得很小。在实验中,相对较大的动量值(例如,m = 0.999,我们的默认值)比较小的值(例如,m = 0.9)要好得多,这表明缓慢更新的键编码器是充分利用队列的关键。

image-20230831162622010