Bert Pytorch 源码分析:四、编解码器

发布时间 2023-06-26 15:25:25作者: 绝不原创的飞龙
# Bert 编码器模块
# 由一个嵌入层和 NL 个 TF 层组成
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
		# 嵌入大小 ES
        self.hidden = hidden
		# TF 层数 NL
        self.n_layers = n_layers
		# 头部数量 HC
        self.attn_heads = attn_heads

        # FFN 层中的隐藏单元数量,记为 FF,一般是 ES 的四倍
        self.feed_forward_hidden = hidden * 4

        # 嵌入层,嵌入矩阵尺寸 VS * ES
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # NL 个 TF 层
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # 为`<pad>`(ID = 0)设置掩码
		# 尺寸为 BS * 1 * ML * ML,以便与相似性矩阵 S 匹配
		# 在每个 BS 的 ML * ML 矩阵中,`<pad>`标记对应的行为 1,其余为零
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # 单词 ID 传入嵌入层得到词向量
        x = self.embedding(x, segment_info)

        # 依次传入每个 TF 层,得到编码器输出
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        return x

# 解码器结构根据具体任务而定
# 任务一般有三种:(1)序列分类,(2)标记分类,(3)序列生成
# 但一般都是全连接的

# 用于下个句子判断的解码器
# 序列分类任务,输入两个句子,输出一个标签,1表示是相邻句子,0表示不是
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
		# 将向量压缩到两维, 尺寸为 ES * 2
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
		# 输入 -> 取第一个向量 -> LL -> softmax -> 输出
		# 输出相邻句子和非相邻句子的概率
        return self.softmax(self.linear(x[:, 0]))

# 用于完型填空的解码器
# 序列生成任务,输入是带有`<mask>`的句子,输出是完整句子
class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
		# 将输入压缩到词汇表大小
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
		# 输入 -> LL -> softmax -> 输出
		# 输出序列中每个词是词汇表中每个词的概率
        return self.softmax(self.linear(x))