Common Deep Learning Interview Code: MHA / MQA / GQA / LN / BN / Positional Encoding

Published 2023-12-11 16:57:08 | Author: 自私的人

Common deep learning code

1. MHA (MultiHeadAttention) Implementation


# 1. MHA implementation
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask=None):
        K_T = K.transpose(-1, -2)  # transpose of K
        d_k = Q.size(-1)
        # 1. Dot product of Q and K^T, scaled by sqrt(d_k), gives the attention score matrix
        scores = torch.matmul(Q, K_T) / math.sqrt(d_k)
        # 2. If a mask is given, set the masked positions to a large negative value
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 3. Softmax over the last dimension turns scores into attention weights in [0, 1]
        attn_weights = self.softmax(scores)
        # 4. Multiply the attention weights by V to get the final output
        output = torch.matmul(attn_weights, V)

        return output, attn_weights
      
    
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention Layer
    Args:
        d_model: Dimensions of the input embedding vector, equal to input and output dimensions of each head
        n_head: number of heads, which is also the number of parallel attention layers
    """
    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)  # linear projection for Q
        self.w_k = nn.Linear(d_model, d_model)  # linear projection for K
        self.w_v = nn.Linear(d_model, d_model)  # linear projection for V
        self.fc = nn.Linear(d_model, d_model)   # output linear projection

    def forward(self, q, k, v, mask=None):
        # 1. dot product with weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)  # size is [batch_size, seq_len, d_model]
        # 2. split by number of heads (n_head); size becomes [batch_size, n_head, seq_len, d_model//n_head]
        q, k, v = self.split(q), self.split(k), self.split(v)
        # 3. compute attention
        sa_output, attn_weights = self.attention(q, k, v, mask)
        # 4. concat heads and apply the output linear transformation
        concat_tensor = self.concat(sa_output)
        mha_output = self.fc(concat_tensor)

        return mha_output

    def split(self, tensor):
        """
        Split tensor by the number of heads (n_head).

        :param tensor: [batch_size, seq_len, d_model]
        :return: [batch_size, n_head, seq_len, d_model//n_head]; the output is 4-D,
                 with the head dimension as the second axis.

        Equivalent to reshaping Q, K, V into n_head heads, e.g.
        q = q.reshape(batch_size, seq_len, n_head, d_model // n_head)
        """

        batch_size, seq_len, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        split_tensor = tensor.view(batch_size, seq_len, self.n_head, d_tensor).transpose(1, 2)
        # it is similar with group convolution (split by number of heads)

        return split_tensor
      
    
    def concat(self, sa_output):
        """ merge multiple heads back together

        Args:
            sa_output: [batch_size, n_head, seq_len, d_tensor]
            return: [batch_size, seq_len, d_model]
        """
        batch_size, n_head, seq_len, d_tensor = sa_output.size()
        d_model = n_head * d_tensor
        concat_tensor = sa_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        return concat_tensor
      
# Example: d_model = 8, n_head = 2
MHA = MultiHeadAttention(8, 2)
q = torch.ones([2, 3, 8])     # [batch_size, seq_len, d_model]
k = torch.ones([2, 3, 8])
v = torch.ones([2, 3, 8])
output = MHA(q, k, v)
print(output.shape)           # torch.Size([2, 3, 8])

2. MQA (MultiQueryAttention) Implementation

# 2. MQA implementation
from typing import Optional

class MultiQueryAttention(nn.Module):
    """Multi-Query self-attention.

    Using the torch or triton attention implementation also lets the user
    add an additive bias.
    """
    def __init__(self, d_model: int, n_heads: int, device: Optional[str] = None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(                  # [Key point] how Multi-Query Attention builds its projections:
            d_model,
            d_model + 2 * self.head_dim,        # only the query gets n_heads head vectors (one full d_model),
            device=device,                      # while key and value each share a single head_dim vector
        )
        # scaled_multihead_dot_product_attention comes from an external implementation (see the sketch below)
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore

    def forward(self, x):
        qkv = self.Wqkv(x)                                    # (1, 512, 960)
        query, key, value = qkv.split(                        # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],     # key   -> (1, 512, 96)
            dim=2                                             # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(
            query, key, value, self.n_heads, multiquery=True
        )
        return self.out_proj(context), attn_weights, past_key_value
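The class above depends on a `scaled_multihead_dot_product_attention` function that is not defined in this snippet (it comes from an external codebase). The following is a minimal sketch of what the multi-query path could look like, assuming a query of shape [batch, seq, d_model] and a single shared K/V head of size head_dim when multiquery=True; the function name, signature, and the omitted past_key_value handling are assumptions made for illustration, not the original implementation:

# Minimal sketch (assumed signature) of a multi-query scaled dot-product attention helper
import math
import torch
import torch.nn.functional as F

def scaled_multihead_dot_product_attention(query, key, value, n_heads, multiquery=False):
    # query: [b, s, d_model]; key/value: [b, s, head_dim] when multiquery=True
    b, s_q, d_model = query.shape
    head_dim = d_model // n_heads
    kv_heads = 1 if multiquery else n_heads

    q = query.view(b, s_q, n_heads, head_dim).transpose(1, 2)   # [b, n_heads, s_q, head_dim]
    k = key.view(b, -1, kv_heads, head_dim).transpose(1, 2)     # [b, 1, s_k, head_dim] for multi-query
    v = value.view(b, -1, kv_heads, head_dim).transpose(1, 2)

    # The single K/V head broadcasts across all n_heads query heads
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    attn_weights = F.softmax(scores, dim=-1)
    context = torch.matmul(attn_weights, v)                     # [b, n_heads, s_q, head_dim]
    context = context.transpose(1, 2).contiguous().view(b, s_q, d_model)
    return context, attn_weights, None   # KV caching (past_key_value) omitted in this sketch

With this sketch in place, `MultiQueryAttention(768, 8)(torch.randn(1, 512, 768))` produces the (1, 512, 960) / (1, 512, 768) / (1, 512, 96) shapes annotated in the comments above.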

3. GQA (GroupQueryAttention) Implementation

# 3. GQA implementation
# Reference: llama2 source code
# https://zhuanlan.zhihu.com/p/649756898?utm_id=0

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def repeat_kv(x, n_rep):
    bs, slen, n_kv_heads, head_dim = x.size()
    # Expand the KV heads n_rep times so they match the number of query heads
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, n_heads, n_kv_heads, dim, head_dim, max_batch_size, max_seq_len, model_parallel_size):
        super().__init__()

        self.n_local_heads = n_heads // model_parallel_size        # number of Q heads       [model parallelism]
        self.n_local_kv_heads = n_kv_heads // model_parallel_size  # number of KV heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads   # how many times each KV head is repeated
        self.head_dim = head_dim

        self.wq = nn.Linear(dim, n_heads * head_dim)       # e.g. dim=768, n_heads=8, head_dim=96 -> [768, 768]
        self.wk = nn.Linear(dim, n_kv_heads * head_dim)    # KV projections are smaller: n_kv_heads * head_dim
        self.wv = nn.Linear(dim, n_kv_heads * head_dim)
        self.wo = nn.Linear(n_heads * head_dim, dim)

    def forward(self, x, mask=None):
        bsz, seqlen, _ = x.size()
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk)       # RoPE positional encoding (see the sketch below)

        # KV cache update omitted here; we attend over the current sequence only
        keys, values = xk, xv

        # repeat K/V heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)          # [bsz, seqlen, n_kv_heads * n_rep, head_dim]
        values = repeat_kv(values, self.n_rep)      # [bsz, seqlen, n_kv_heads * n_rep, head_dim]

        # Move the head dimension in front of the sequence dimension for attention
        xq = xq.transpose(1, 2)                     # [bsz, n_heads, seqlen, head_dim]
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores + mask

        scores = F.softmax(scores, dim=-1)
        output = torch.matmul(scores, values)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
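`apply_rotary_emb` is also not defined in the snippet; in the llama2 source it applies rotary position embeddings (RoPE) to the query and key, with the rotation frequencies (`freqs_cis`) precomputed separately. Below is a minimal self-contained sketch that matches the two-argument call used above by computing the angles on the fly; the `theta` default and the internal angle computation are illustrative assumptions, not the original code:

# Minimal RoPE sketch, assuming angles are computed inside the function
import torch

def apply_rotary_emb(xq, xk, theta=10000.0):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    xq, xk: [batch, seq_len, n_heads, head_dim]; head_dim must be even.
    """
    bs, slen, _, head_dim = xq.shape
    # Rotation frequencies theta^(-2i/d), one angle per (position, dimension pair)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=xq.device).float() / head_dim))
    angles = torch.outer(torch.arange(slen, device=xq.device).float(), freqs)   # [slen, head_dim//2]
    freqs_cis = torch.polar(torch.ones_like(angles), angles)                    # unit complex numbers
    freqs_cis = freqs_cis.view(1, slen, 1, head_dim // 2)                       # broadcast over batch/heads

    # View the last dimension as complex pairs, rotate, then flatten back
    xq_ = torch.view_as_complex(xq.float().reshape(bs, slen, -1, head_dim // 2, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(bs, slen, -1, head_dim // 2, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)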

4. KV Cache Implementation

# 4. KV cache implementation
import torch
import torch.nn as nn

class IncrementalAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(IncrementalAttention, self).__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // self.num_heads

        # Linear projections for Q, K, V
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

        # Empty K and V caches to start with
        self.k_cache = None
        self.v_cache = None

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Project Q and reshape to [batch_size, num_heads, seq_len, depth]
        q = self.WQ(q).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Compute K and V for the new tokens only
        k_new = self.WK(k).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        v_new = self.WV(v).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Append the new tokens' K/V to the cache (along the sequence dimension)
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k_new], dim=2)
            v = torch.cat([self.v_cache, v_new], dim=2)
        else:
            k = k_new
            v = v_new

        # Update the cache for the next decoding step
        self.k_cache = k
        self.v_cache = v

        # Attention (simplified for brevity; per-head output, no output projection)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.depth ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = self.softmax(scores)
        output = torch.matmul(attn_weights, v)

        return output
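A quick usage sketch (hypothetical sizes) showing how the cache grows by one entry per decoding step; note this simplified module returns per-head outputs and has no output projection:

attn = IncrementalAttention(d_model=512, num_heads=8)
for step in range(4):
    x = torch.randn(1, 1, 512)           # one new token per step: [batch, 1, d_model]
    out = attn(x, x, x)                  # out: [1, 8, 1, 64], one slice per head
    print(step, attn.k_cache.shape)      # cache grows along the sequence dim: [1, 8, step + 1, 64]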

5. Transformer Embedding Implementation


# 5. Embedding implementation: PositionalEncoding + TokenEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """
    def __init__(self, d_model, max_len, device):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        :param device: hardware device setting
        """
        super(PositionalEncoding, self).__init__()

        # same size with input matrix (for adding with input matrix)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradient

        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent word's position

        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i' indexes the embedding dimension (e.g. if d_model = 50, 'i' ranges over [0, 50) in steps of 2)
        # "step=2" generates only the even indices, i.e. 2 * i

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # compute positional encoding to consider positional information of words

    def forward(self, x):
        # self.encoding
        # [max_len = 512, d_model = 512]

        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30]

        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512]
        # it will add with tok_emb : [128, 30, 512]
        
        
class TokenEmbedding(nn.Embedding):
    """
    Token embedding using torch.nn.Embedding:
    maps each token id to a dense vector via a learned weight matrix.
    """

    def __init__(self, vocab_size, d_model):
        """
        class for token embedding (a learnable lookup table)
        :param vocab_size: size of vocabulary
        :param d_model: dimensions of model
        """
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)

class TransformerEmbedding(nn.Module):
    """
    token embedding + positional encoding (sinusoid)
    positional encoding gives positional information to the network
    """

    def __init__(self, vocab_size, max_len, d_model, drop_prob, device):
        """
        class for word embedding that includes positional information
        :param vocab_size: size of vocabulary
        :param max_len: max sequence length
        :param d_model: dimensions of model
        :param drop_prob: dropout probability
        :param device: hardware device setting
        """
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)
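A quick usage sketch with hypothetical hyperparameters; the input is a batch of token ids, and the sinusoidal encoding is broadcast across the batch when added to the token embeddings:

device = torch.device("cpu")
emb = TransformerEmbedding(vocab_size=1000, max_len=512, d_model=512, drop_prob=0.1, device=device)
token_ids = torch.randint(0, 1000, (2, 30))   # [batch_size, seq_len]
out = emb(token_ids)
print(out.shape)                              # torch.Size([2, 30, 512])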

6. LN (LayerNorm) Implementation

# 6. LN implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)                 # '-1' means the last dimension
        var = x.var(-1, unbiased=False, keepdim=True)   # nn.LayerNorm uses the biased (population) variance

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta

        return out

# NLP example
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)

# 1. The built-in nn.LayerNorm module
layer_norm1 = nn.LayerNorm(embedding_dim)
pytorch_ln_out = layer_norm1(embedding)

# 2. Our custom LayerNorm module
layer_norm2 = LayerNorm(embedding_dim)
my_ln_out = layer_norm2(embedding)

# Compare the results
print(torch.allclose(pytorch_ln_out, my_ln_out, rtol=0.1, atol=0.01))  # prints True