Attention Examples

Published 2023-10-09 18:11:46 · Author: glowwormss
1. Self-Attention Example
import torch
import torch.nn as nn

class Selfattention(nn.Module):
    def __init__(self, input_dim):
        super(Selfattention, self).__init__()
        self.query = nn.Linear(input_dim,input_dim)
        self.key = nn.Linear(input_dim,input_dim)
        self.value = nn.Linear(input_dim,input_dim)
        self.softmax= nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.2)  # optional
        
    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        score = torch.bmm(q, k.transpose(1, 2))   # transpose K so the inner product runs over the feature dim
        att_weight = self.softmax(score)           # (batch_size, seq_len, seq_len)
        context_vector = torch.bmm(att_weight, v)  # (batch_size, seq_len, dim)
        output = self.dropout(context_vector)
        
        return output
    
# Create input data
batch_size = 2
seq_len = 3
input_dim = 4


x = torch.randn(batch_size, seq_len, input_dim) 
self_att = Selfattention(input_dim)
output = self_att(x)  # calling the module runs forward()

print(output)


# Sample output (no random seed is set, so values vary between runs)
tensor([[[-0.1971, -0.0943,  0.1433, -0.0590],
         [-0.0270,  0.0940, -0.3715, -0.4299],
         [ 0.3545, -0.3037, -1.3532, -0.9181]],

        [[-0.0363, -0.3110, -0.3745,  0.0531],
         [-0.0182, -0.3026, -0.3475,  0.0766],
         [-0.0538, -0.3153, -0.3854,  0.0132]]], grad_fn=<BmmBackward0>)
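
The example above omits the 1/√d scaling used in standard scaled dot-product attention; without it, the softmax can saturate as the feature dimension grows. Below is a minimal sketch of the same module with the scaling added (the class name ScaledSelfattention is illustrative, not from the original code):

import math
import torch
import torch.nn as nn

class ScaledSelfattention(nn.Module):
    def __init__(self, input_dim):
        super(ScaledSelfattention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # divide by sqrt(d) so the scores keep a reasonable scale before the softmax
        score = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
        att_weight = self.softmax(score)           # (batch_size, seq_len, seq_len)
        return torch.bmm(att_weight, v)            # (batch_size, seq_len, input_dim)

Aside from the scaling, the shapes and the softmax over the last dimension match the example above.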

2. Multi-Head Attention Example

import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)  # normalize over the key dimension

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        batch_size, seq_len, input_dim = x.size()
        head_dim = input_dim // self.num_heads  # per-head dimension; input_dim must be divisible by num_heads
        
        print ("origin_q_size:"+str(q.size()))
        # Reshape query, key, and value tensors
        q = q.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        
        print ("trans_q_size:"+str(q.size()))

        scores = torch.matmul(q, k.transpose(-2, -1))   # (batch, heads, seq_len, seq_len)
        attention_weights = self.softmax(scores)

        output = torch.matmul(attention_weights, v)     # (batch, heads, seq_len, head_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, input_dim)  # merge the heads back

        return output

# Create input data
batch_size = 2
seq_len = 3
input_dim = 4
num_head = 2


x = torch.randn(batch_size, seq_len, input_dim) 
mh_att = MultiheadAttention(input_dim, num_head)
output = mh_att(x)  # calling the module runs forward()

print(output)

# Sample output (no random seed is set, so values vary between runs)
origin_q_size:torch.Size([2, 3, 4])
trans_q_size:torch.Size([2, 2, 3, 2])
tensor([[[ 0.3837,  0.1198, -0.2621,  0.3031],
         [ 0.4219, -0.0307, -0.1981,  0.2589],
         [ 0.8475, -0.2611, -0.3065,  0.3577]],

        [[ 0.1067,  0.4801, -0.4982,  0.2262],
         [ 0.1212,  0.7254, -0.4276,  0.2498],
         [ 0.0941,  0.3030, -0.4351,  0.1991]]], grad_fn=<ViewBackward0>)
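
For comparison, PyTorch ships a built-in torch.nn.MultiheadAttention module that fuses the query/key/value projections, applies the 1/√head_dim scaling, and supports attention and padding masks. A minimal sketch with the same shapes as above, assuming batch_first=True to keep the (batch, seq, dim) layout:

import torch
import torch.nn as nn

batch_size, seq_len, input_dim, num_heads = 2, 3, 4, 2
x = torch.randn(batch_size, seq_len, input_dim)

# embed_dim must be divisible by num_heads
mha = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)

# self-attention: query, key, and value are all x
attn_output, attn_weights = mha(x, x, x)

print(attn_output.shape)   # torch.Size([2, 3, 4])
print(attn_weights.shape)  # torch.Size([2, 3, 3]); weights are averaged over heads by default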