1. Self-attention example
import torch
import torch.nn as nn

class Selfattention(nn.Module):
    def __init__(self, input_dim):
        super(Selfattention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.2)  # optional

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)    # use the key projection (not query)
        v = self.value(x)  # use the value projection (not query)
        # transpose K so the inner product is taken over the feature dimension
        score = torch.bmm(q, k.transpose(1, 2))
        att_weight = self.softmax(score)           # (batch_size, seq_len, seq_len)
        context_vector = torch.bmm(att_weight, v)  # (batch_size, seq_len, dim)
        output = self.dropout(context_vector)
        return output

# Create input data
batch_size = 2
seq_len = 3
input_dim = 4
x = torch.randn(batch_size, seq_len, input_dim)

self_att = Selfattention(input_dim)
output = self_att(x)
print(output)

# Output
tensor([[[-0.1971, -0.0943,  0.1433, -0.0590],
         [-0.0270,  0.0940, -0.3715, -0.4299],
         [ 0.3545, -0.3037, -1.3532, -0.9181]],

        [[-0.0363, -0.3110, -0.3745,  0.0531],
         [-0.0182, -0.3026, -0.3475,  0.0766],
         [-0.0538, -0.3153, -0.3854,  0.0132]]], grad_fn=<BmmBackward0>)
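The example above uses plain (unscaled) dot-product attention. The standard formulation additionally divides the scores by the square root of the feature dimension before the softmax, which keeps the score variance stable as the dimension grows. A minimal sketch of that scaled variant, assuming the same input shapes as above (the helper name scaled_dot_product_attention is chosen here for illustration):

import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch_size, seq_len, dim)
    d_k = q.size(-1)
    # divide by sqrt(d_k) before the softmax (scaled dot-product attention)
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_k)
    att_weight = torch.softmax(scores, dim=-1)  # each row sums to 1 over the keys
    return torch.bmm(att_weight, v)

# Usage with the same shapes as the example above
q = k = v = torch.randn(2, 3, 4)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 3, 4])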
2. Multi-head attention example
import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)  # normalize over the key dimension

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        batch_size, seq_len, _ = x.size()
        head_dim = self.input_dim // self.num_heads  # dimension of each head
        print("origin_q_size:" + str(q.size()))

        # Reshape query, key, and value to (batch_size, num_heads, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, head_dim).transpose(1, 2)
        print("trans_q_size:" + str(q.size()))

        scores = torch.matmul(q, k.transpose(-2, -1))
        attention_weights = self.softmax(scores)
        output = torch.matmul(attention_weights, v)
        # Merge the heads back into (batch_size, seq_len, input_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.input_dim)
        return output

# Create input data
batch_size = 2
seq_len = 3
input_dim = 4
num_head = 2
x = torch.randn(batch_size, seq_len, input_dim)

mh_att = MultiheadAttention(input_dim, num_head)
output = mh_att(x)
print(output)

# Output
origin_q_size:torch.Size([2, 3, 4])
trans_q_size:torch.Size([2, 2, 3, 2])
tensor([[[ 0.3837,  0.1198, -0.2621,  0.3031],
         [ 0.4219, -0.0307, -0.1981,  0.2589],
         [ 0.8475, -0.2611, -0.3065,  0.3577]],

        [[ 0.1067,  0.4801, -0.4982,  0.2262],
         [ 0.1212,  0.7254, -0.4276,  0.2498],
         [ 0.0941,  0.3030, -0.4351,  0.1991]]], grad_fn=<ViewBackward0>)
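For reference, PyTorch also ships a built-in nn.MultiheadAttention module that covers the same idea (and additionally applies the sqrt(d) score scaling and a final output projection, which the hand-written example above omits). A minimal self-attention usage sketch, assuming the same batch_size, seq_len and input_dim as above:

import torch
import torch.nn as nn

embed_dim, num_heads = 4, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 3, embed_dim)          # (batch_size, seq_len, embed_dim)
attn_output, attn_weights = mha(x, x, x)  # query = key = value for self-attention
print(attn_output.shape)   # torch.Size([2, 3, 4])
print(attn_weights.shape)  # torch.Size([2, 3, 3]), averaged over heads by default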