ALBEF-ITC损失部分

发布时间 2023-11-21 21:20:56作者: Laplace蒜子

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

引言

VLP目标是从大规模图片-文本对子中学习到多模态表示,一次改进下游的视觉-语言任务。

VLP框架的局限性如下:

  1. 图片特征和文字token嵌入在它们各自的空间内,使得多模态encoder难以去学习它们之间的关系。
  2. 目标decoder既需要大量标注也需要大量的计算资源,因为其在预训练时候需要边界框的标注和高分辨率的图片(600x1000)。
  3. 许多图片-文本数据集来自于网络,通常包含噪声,导致如MLM等模型可能会拟合噪声文本,降低模型的泛化能力。

 

作者提出了一种新的VLP框架ALBEF,以此解决以上问题:

1. 跨模态注意力:

作者首先使用detector-free的图片encoder(不需要检查特征点,直接匹配)和文本encoder对图片和文本编码。

然后使用多模态编码器通过跨模态注意力去融合图片特征和文本特征。

 

2. 作者提出了图片-文本对比(ITC)损失:

对齐图片特征和文本特征,使得其更容易用于多模态编码器执行跨模态学习。

帮助单模态编码器更好的理解图片和文本的语义

学习一个低维空间去嵌入图片和文本,可以使得图片-文本匹配目标挖掘更多有信息的样本。

 

3. 为了在噪声监督下学习,作者还提出了动量蒸馏(MoD):

在训练期间,通过获取模型的参数的移动平均值,保持模型的一个动量版本。然后使用动量模型生成伪目标作为额外的监督。

MoD模型不会因为产生与网络注释不同的输出而受到惩罚。

MoD不仅改进了预训练,也对下游任务的标注进行清洗。

 

方法

图-文对比学习

首先,图片编码器和文本编码器都会在图片序列和文本序列的首部加上[CLS]标签,表示学习到的图片全局表示。

之后的对比就是基于[CLS]向量的对比。

图片和文本的[CLS]分别用vclswcls表示,动量编码器的输出特征分别使用g'w(w'cls)和g'v(v'cls)表示

对比学习,是学习与动量编码器输出的相似度。

s(I,T)=gv(vcls)T g'w(w'cls)

s(T,I)= gw(wcls)Tg'v(v'cls)

对于每个图片和文本,计算归一化的图片对文本的相似度和文本对图片的相似度。

τ是温度超参数。Tm是动量编码器输出的所有图片的[CLS],Im是动量编码器输出的所有文本[CLS]。

图文对比学习损失ITC如下:

其中H为交叉熵损失,y为Ground Truth标签。(在实际预训练中,代码中y采用的是伪标签)

已知交叉熵损失

代入ITC损失,得到

其中预测概率p为

  

其中s(I,T)是当前Image与一个Text的相似度。最终需要计算当前Image与所有Text的相似度,所以在源码中,是直接计算I与动量编码器Text队列中所有的Text的相似度。s(T,I)也是如此。

代入到ITC损失中得到源代码中的计算公式(对应源代码中不蒸馏部分)

#源码中,映射图片和文本的全连接层输出embedding_dim为256
#一批输入中,有N个图片和N个句子
#图片和文本队列大小都为57600,维度为256,也就是是可以保存57600个维度为256的[CLS]
#为了存储方便,队列形状设置为256 x 57600
image_embeds = self.visual_encoder(image) 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 
#image_feat形状为:(N,256)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')            
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)               
#text_feat形状为:(N,256)

idx = idx.view(-1,1)
#idx为图片-文本对的标签,分为一致2,中性1,对立0。
#原本形状为(N,),现在变为(N,1)
#idx转置成形状(1,N),idx_queue形状为(1,57600)
#然后将idx拼接到队列的头部得到idx_all,形状为(1,N+57600)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
pos_idx = torch.eq(idx, idx_all).float()
#idx形状为(N,1),idx_all形状为(1,N+57600)
#比较之后,比较矩阵为(N,N+57600),表示N个标签分别与N+57600个的比较结果。
#由于队列的头部是新添加的标签,新标签与其比较时,自然而然对角线为1。
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) #硬标签

with torch.no_grad():
    self._momentum_update()#更新动量编码器
    image_embeds_m = self.visual_encoder_m(image) 
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
    image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
    #image_feat_m转置后形状为:256 x 2 , Image队列的形状为256 x 57600
    #上述拼接操作是将队列复制一份,并将image_feat_m拼接到队列的头部!。
    
    text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')    
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
    #文本也一致    
    #text_feat_all和image_feat_all分别为text队列和image队列中所有的[CLS]集合

#计算图文特征分别对队列中所有特征的相似度
sim_i2t = image_feat @ text_feat_all / self.temp 
sim_t2i = text_feat @ image_feat_all / self.temp

loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()#计算与硬标签的交叉熵损失
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()

loss_ita = (loss_i2t+loss_t2i)/2

self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

 动量蒸馏图文对比损失

带蒸馏则需要动量编码器输出的新样本与队列中所有样本的相似度

然后最小化q和p之间的KL散度

代入原始式子

最小化KL散度的等价关系如下

最小化原式子等价于

得到如下公式(对应于源代码中的式子):

 

源代码如下:

#源码中,映射图片和文本的全连接层输出embedding_dim为256
#一批输入中,有N个图片和N个句子
#图片和文本队列大小都为57600,维度为256,也就是是可以保存57600个维度为256的[CLS]
#为了存储方便,队列形状设置为256 x 57600
image_embeds = self.visual_encoder(image) 
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 
#image_feat形状为:(N,256)
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')            
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)               
#text_feat形状为:(N,256)

idx = idx.view(-1,1)
#idx为图片-文本对的标签,分为一致2,中性1,对立0。
#原本形状为(N,),现在变为(N,1)
#idx转置成形状(1,N),idx_queue形状为(1,57600)
#然后将idx拼接到队列的头部得到idx_all,形状为(1,N+57600)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)  
pos_idx = torch.eq(idx, idx_all).float()
#idx形状为(N,1),idx_all形状为(1,N+57600)
#比较之后,比较矩阵为(N,N+57600),表示N个标签分别与N+57600个的比较结果。
#由于队列的头部是新添加的标签,新标签与其比较时,自然而然对角线为1。
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
        
with torch.no_grad():
    self._momentum_update()#更新动量编码器
    image_embeds_m = self.visual_encoder_m(image) 
    image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)  
    image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)                                         
    #image_feat_m转置后形状为:256 x 2 , Image队列的形状为256 x 57600
    #上述拼接操作是将队列复制一份,并将image_feat_m拼接到队列的头部!。
    
    text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,return_dict = True, mode = 'text')    
    text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 
    text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
    #文本也一致
    
    #text_feat_all和image_feat_all分别为text队列和image队列中所有的[CLS]集合
    #动量蒸馏,创建软标签
    sim_i2t_m = image_feat_m @ text_feat_all / self.temp 
    sim_t2i_m = text_feat_m @ image_feat_all / self.temp   
    sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
    sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 

#计算图文特征分别对队列中所有特征的相似度
sim_i2t = image_feat @ text_feat_all / self.temp 
sim_t2i = text_feat @ image_feat_all / self.temp

#动量蒸馏,计算与软标签的等价KL散度。
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 
    
loss_ita = (loss_i2t+loss_t2i)/2

self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)