Exploring the prompt gating code

import torch

def promptGating(gating, adding, x):
    '''
    gating: (num_prefix, dim)   learnable scaling vectors for the prefix positions
    adding: (num_prefix, dim)   learnable bias vectors for the prefix positions
    x:      (seq_length, batch_size, dim)  hidden states, prefix tokens first
    returns a tensor of shape (seq_length, batch_size, dim)
    '''
    if gating is not None:
        # (num_prefix, dim) -> (batch_size, num_prefix, dim) -> (num_prefix, batch_size, dim)
        gating = gating.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
        # pad with ones up to seq_length so positions beyond the prefix stay unchanged
        gating = torch.cat([gating,
                            torch.ones(x.size(0) - gating.size(0), x.size(1), x.size(2),
                                       dtype=x.dtype, device=x.device)], dim=0)
        # (seq_length, batch_size, dim)
        x = x * gating  # outside the prefix: multiplied by 1

        if adding is not None:  # amounts to adding a bias on the prefix positions
            adding = adding.unsqueeze(0).expand(x.size(1), -1, -1).transpose(0, 1)
            # pad with zeros so positions beyond the prefix stay unchanged
            adding = torch.cat([adding,
                                torch.zeros(x.size(0) - adding.size(0), x.size(1), x.size(2),
                                            dtype=x.dtype, device=x.device)], dim=0)
            x = adding + x  # outside the prefix: 0 is added
    return x

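Since everything outside the prefix is multiplied by 1 and has 0 added, the torch.ones / torch.zeros padding only exists to make the gate and bias line up with the full sequence length. Below is a minimal sketch of my own (not from the original code; promptGatingSliced is a hypothetical name) that expresses the same computation by slicing out the first num_prefix positions instead, assuming the same (seq_length, batch_size, dim) layout with the prefix at the front.

def promptGatingSliced(gating, adding, x):
    # Equivalent sketch: touch only the first num_prefix positions and leave
    # the rest of the sequence alone, instead of padding the gate/bias with
    # ones/zeros up to seq_length.
    x = x.clone()  # do not modify the caller's tensor in place
    if gating is not None:
        n = gating.size(0)  # num_prefix
        # (num_prefix, dim) -> (num_prefix, 1, dim), broadcast over the batch dimension
        x[:n] = x[:n] * gating.unsqueeze(1)
        if adding is not None:
            x[:n] = x[:n] + adding.unsqueeze(1)
    return x

With the shapes used in the test block below, both versions should produce the same output.
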
if __name__ == "__main__":
    num_prefix, batch_size, seq_length, dim = 2, 8, 22, 1024
    gating = torch.randn(num_prefix, dim)
    adding = torch.randn(num_prefix, dim)
    x = torch.randn(seq_length, batch_size, dim) 

    new_x = promptGating(gating, adding, x)
    print(new_x.shape)
    # Output: torch.Size([22, 8, 1024])
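
    # Sanity check (my addition, not in the original post): positions beyond
    # the prefix should come back unchanged, and the first num_prefix positions
    # should equal x * gating + adding, broadcast over the batch dimension.
    assert torch.allclose(new_x[num_prefix:], x[num_prefix:])
    expected_prefix = x[:num_prefix] * gating.unsqueeze(1) + adding.unsqueeze(1)
    assert torch.allclose(new_x[:num_prefix], expected_prefix)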