【论文阅读笔记】Learning to Prompt for Continual Learning

发布时间 2023-04-11 20:14:17作者: 空口吃大蒜
Create_time: April 27, 2022 5:21 PM
Edited_by: Huang Yujun
Org: Google Research, Northeastern University

Learning to Prompt for Continual Learning

L2P(1)

[38]Learning to Prompt for Continual Learning.pdf

问题:

  • 最终输入transformer encoder的序列长度是怎么组成的,原始输入如何编码,是否要加上position embeding(已知class token为预训练模型的一部分)

0. 对 prompt 的背景知识的补充

1. Contribution

  • 提出将NLP领域做迁移学习的 Promting 引入到持续学习中,使用 prompt 来表征 task 的信息
  • 尝试不使用task id 而是通过key-value的方式,来选择task specific prompt

2. Motivation

持续学习本质是一个拟合 “随序列动态变化的数据分布” 的问题。目前的工作主要集中于调整整个模型的参数以拟合数据分布。

目前的主流方法仍面临许多挑战

  • 样本回放方法
    • 性能受 buffer 大小影响很大
    • 在严格限制额外存储空间的场景不适用
  • 任务增量的方法
    • 需要在测试时知道任务id

现有方法的都面临两个问题:

(1) Whether the form of episodic memory can go beyond buffering past data to more intelligent and succinct episodic memory system?
(2) How to automatically select relevant knowledge component for arbitrary sample without knowing its task identity?

本文从这两个问题出发,发现在NLP领域的 prompting 技术可以处理第一个问题,即(粗略的理解)使用一部分 task-specific 的参数来学习task的知识,但是保持主体网络不变(一个预训练得非常好的大模型)。在推断时一个 item 时,使用这部分 task-specific 的参数产生 hint 来引导大模型做预测。

对第二个问题,作者通过对 task-specific promt 建立键值对的索引关系,从而解决更一般的task id 不可知的情况。

3. Methodology

首先对 Prompt Tuning(PT) 在图像这个模态上做定义(实际上seq2seq的模型还可以用到其他模态上),同时也对一些符号及其内涵做定义。

一开始我们有一张 2D 的图片 \(x\in R^{H\times W \times C}\) 以及一个预训练好的 ViT 模型 \(f=f_{r}\circ f_{e}\),其中 \(f_{e}\) 是一个输入 embeding 层, \(f_{r}\) 是一组 self-attention 层。首先,这张2D图片 x 会 reshape 成 \(x_{p}\in R^{L\times (S^{2}\cdot C)}\),其中 L 是序列中包含的 token 数(如图片中 patch 的数量),S 是 patch 的边长,C 是原图的通道数,这个 \(x_{p}\) 会经过 embeding 层映射成 \(x_{e}=f_{e}(x) \in R^{L\times D}\),其中 D 是 embeding dimention。

在使用预训练好的大模型处理下游任务时,Prompt Tuning的方法,会将大模型直接固定。PT最直接的使用方式是,使用一个可学习的前缀 \(P_{e}\in R^{L_{p}\times D}\) (称为 prompt,其token数量为 \(L_{p}\) ,维度与前面词向量的维度D相同)与前面提到的图片嵌入 \(x_{e}\) 做拼接 \(x_{p}=[P_{e};x_{e}]\),得到最终输入大模型的嵌入序列。最后再由大模型 encode在经过分类器得到预测类别 \(g_{\phi}(f_{r}(x_{p}))\),其中 \(g_{\phi}\) 是分类器,\(\phi\) 是分类器参数。

3.1 From prompt to prompt pool

作者发现 Prompt Tuning 技术不能够直接用到持续学习的场景中,主要是因为3个原因:

  • task id在测试时不知道
  • 及时我们在测试时找到了对应的task-speicifc prompt ,但这种方式阻止了相似task的知识共享
  • 如果按照最朴素的思路,使用一个prompt学习所有的task,这又会导致严重的遗忘

最合理的模型应该是能够共享相似task之间的知识,同时也能够保证task之间的独立性。基于这种思路,作者将 Prompt Tuning 中只一个任务只使用一个 prompt 的方式,改成了使用一个 task 可使用多个 prompt 的方式,作者称保存下来的所有 task 对应的 prompt (可学习参数)为 prompt pool。

L2P(2)

其中,\(P_{i}\in R^{L_{p}\times D}\) 是单个 prompt 。(这里对 M 的标识有点奇怪,应该是指所有task的数量)

为了能够共享task之间的知识,作者这里不只是使用一个 prompt 与 task 对应,而是使用了多个prompt 与 一个 task 对应。下面表达式中的分号标识concat,\(\{s_{i}\}_{i=1}^{N}\) 是 P 的子集。

L2P(3)

作者认为,prompt 可以自由组合,这样就能够联合编码各种知识,如视觉特征或者task特有的特征

3.2 Instance-wise prompt query

为了解决在多数场景下,测试时没有 task id 的问题,作者设计了一种 key-value 方式索引 task 对应的 prompt。这里定义键值对为:\(\{(k_{1},P_{1}),(k_{2},P_{2}),\cdot\cdot\cdot,(k_{M},P_{M})\}\),其中 \(k_{i}\in R^{D_{k}}\) 是一个可学习参数向量。所有的 key 可定义为 \(K=\{k_{i}\}^{M}_{i=1}\)

虽然确定了以 key-value 的方式索引 task-specific prompt ,但如何建立输入图像与 key 之间的映射关系也是个难题。对此,作者希望设置一个能够区分不同 task 且不需要训练的 query function \(q:R^{H\times W\times C}\rightarrow R^{D_{k}}\) ,于是作者直接使用了大模型的 encoder 提取输入图片的特征(对应 class token 的输出)作为 query ,即 \(q(x)=f(x)[0,:]\) 。最后,使用余弦相似度计算 query 与 key 之间的相似程度,从而选出前 N 个最小的 key ,得到对应的prompt。

L2P(4)

其中,\(\gamma :R^{D_{k}}\times R^{D_{k}}\rightarrow R\) 是一个评价 query 与 key 相似度的函数(作者发现余弦相似度效果最好)

以上的公式都是针对 训练+测试 时不知道 task id 的情况下,作者设计了一个含先验知识的 query-key 匹配方式(附加模块,可加可不加,加了会更好)。这个时候需要维护一个 frequency query table \(H_{t}=[h_{1},h_{2},\cdot\cdot\cdot,h_{M}]\),其中每一项表示的是到 t-1 个task之前,每一次选中 \(P_{i}\) 的频率(已做好了归一化)。这个频率的计算方式是训练好prompt后重新过一次数据,这个时候累计各个prompt被调用的次数

L2P(5)

这里的 \(h_{s_{i}}\) 会惩罚经常使用的 prompt,使得更加多样的 key 能够被query-key 映射函数选中。

3.3 Optimization objective for L2P

以下为学习 key, prompt 以及分类器参数 的损失函数

L2P(6)

这里 \(f_{r}^{avg}=AvgPool(f_{r}(x_{p}[0:NL_{p},:]))\) ,由于之前选择了 N 个最匹配的 prompt ,在经过self-attention layer stack 后得到了 \(NL_{p}\) 个 tokens,这里会对这输出的这些 token 做 avgPooling 如下图所示:

L2P(1)

4. Experiments

下面表格中的acc均是指 average acc

对cifar100,结果均是b10i10的设置下

5-datasets,指 CIFAR-10, MNIST, Fashion-MNIST, SVHN, and notMNIST

4.1 Compare with SOTAs

L2P(7)

下表为本文方法与模型扩展方法的比较

L2P(8)

4.2 Results on domain-incremental learning

L2P(9)

4.3 Results on task-agnostic learning

L2P(10)

下面这个图展示了prompt 与 task 对应的id被选中的频率,可见,对于task之间差异不大的情况,相似的prompt会被频繁调用,而在task差异较大的情况,相似的prompt被调用得比较少

L2P(11)
L2P(12)

第一行指不适用prompt pool 而使用一个prompt去学习所有的task;

第二行指不使用可学习的key,而是使用 KNN 的方式(计算各个task的类中心)选择prompt;第三行指训练测试时不使用多组prompt,即N=1。

第三行中的去掉多组prompt的结果反而上升了,表明在task差异较大的情况下,使用单一的prompt能够减少不同task之间知识的干扰

4.5 Effect of hyperparameters for L2P

L2P(13)
Figure 4. Left-Middle: Average accuracy w.r.t prompt length $L_{p}$ and prompt selection size N for Split CIFAR-100 and 5-datasets, respectively, given M = 20. Right: Average accuracy (%) w.r.t. prompt pool size M, given $L_{p}$ = 5, N = 5 for Split CIFAR-100 and $L_{p}$ = 5, N = 4 for 5-datasets.

作者发现,prompt 的维度不能够太小也不能太大,否则会出现欠拟合和过拟合的结果。此外,prompt pool的越大,对于task差异比较大的数据集提升效果显著,但对于task差异小的数据集提升较小。