简洁优美的深度学习包-bert4keras

发布时间 2023-06-18 00:35:11作者: China Soft

新手友好bert4keras

https://spaces.ac.cn/


在鹅厂实习阶段,follow苏神(科学空间)的博客,启发了idea,成功改进了线上的一款模型。想法产出和实验进展很大一部分得益于苏神设计的bert4keras,清晰轻量、基于keras,可以很简洁的实现bert,同时附上了很多易读的example,对nlp新手及其友好!本文推荐几篇基于bert4keras的项目,均来自苏神,对新手入门bert比较合适~

项目1:测试bert的mlm
项目地址:basic_masked_language_model

tokenizer:分词器,主要方法:encode,decode。
build_transformer_model:建立bert模型,建议看源码,可以加载多种权重和模型结构(如unilm)。
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array

config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'

tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, with_mlm=True
) # 建立模型,加载权重

token_ids, segment_ids = tokenizer.encode(u'科学技术是第一生产力')

# mask掉“技术”
token_ids[3] = token_ids[4] = tokenizer._token_mask_id
token_ids, segment_ids = to_array([token_ids], [segment_ids])

# 用mlm模型预测被mask掉的部分
probas = model.predict([token_ids, segment_ids])[0]
print(tokenizer.decode(probas[3:5].argmax(axis=1))) # 结果正是“技术”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
项目2:句子对分类任务
项目地址:task_sentence_similarity_lcqmc
核心模型代码:

句子1和句子2拼接在一起输入bert。
bert模型的pooler输出经dropout和mlp投影到2维空间,做分类问题。
最终整个模型是一个标准的keras model。
class data_generator(DataGenerator):
"""数据生成器
"""
def __iter__(self, random=False):
batch_token_ids, batch_segment_ids, batch_labels = [], [], []
for is_end, (text1, text2, label) in self.sample(random):
token_ids, segment_ids = tokenizer.encode(
text1, text2, maxlen=maxlen
)
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_labels.append([label])
if len(batch_token_ids) == self.batch_size or is_end:
batch_token_ids = sequence_padding(batch_token_ids)
batch_segment_ids = sequence_padding(batch_segment_ids)
batch_labels = sequence_padding(batch_labels)
yield [batch_token_ids, batch_segment_ids], batch_labels
batch_token_ids, batch_segment_ids, batch_labels = [], [], []

# 加载预训练模型
bert = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
with_pool=True,
return_keras_model=False,
)

output = Dropout(rate=0.1)(bert.model.output)
output = Dense(
units=2, activation='softmax', kernel_initializer=bert.initializer
)(output)

model = keras.models.Model(bert.model.input, output)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
项目3:标题生成任务
项目地址:task_seq2seq_autotitle
NLG任务很方便用unilm结构实现,而bert4keras实现unilm只需一个参数。

model = build_transformer_model(
config_path,
checkpoint_path,
application='unilm',
keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表
)
1
2
3
4
5
6
NLG任务的loss是交叉熵,示例中的实现很美观:

CrossEntropy类继承Loss类,重写compute_loss。
将参与计算loss的变量过一遍CrossEntropy,这个过程中loss会被计算,具体阅读Loss类源码。
最终整个模型是一个标准的keras model。
class CrossEntropy(Loss):
"""交叉熵作为loss,并mask掉输入部分
"""
def compute_loss(self, inputs, mask=None):
y_true, y_mask, y_pred = inputs
y_true = y_true[:, 1:] # 目标token_ids
y_mask = y_mask[:, 1:] # segment_ids,刚好指示了要预测的部分
y_pred = y_pred[:, :-1] # 预测序列,错开一位
loss = K.sparse_categorical_crossentropy(y_true, y_pred)
loss = K.sum(loss * y_mask) / K.sum(y_mask)
return loss


model = build_transformer_model(
config_path,
checkpoint_path,
application='unilm',
keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表
)

output = CrossEntropy(2)(model.inputs + model.outputs)

model = Model(model.inputs, output)
model.compile(optimizer=Adam(1e-5))
model.summary()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
预测阶段自回归解码,继承AutoRegressiveDecoder类可以很容易实现beam_search。

项目4:SimBert
项目地址:SimBert
融合了unilm和对比学习,data generator和loss类的设计很巧妙,值得仔细阅读,建议看不懂的地方打开jupyter对着一行一行print来理解。

项目5:SPACES:“抽取-生成”式长文本摘要
项目地址:SPACES
一个比较全面的项目,核心部分是BioCopyNet+Unilm。

总结
bert4keras项目的优点:

build_transformer_model一句代码构建bert模型,一个参数即可切换为unilm结构。
继承Loss类,重写compute_loss方法,很容易计算loss。
深度基于keras,训练、保存和keras一致。
丰富的example!苏神的前沿算法研究也会附上bert4keras实现。
————————————————
版权声明:本文为CSDN博主「一只用R的浣熊」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_44597588/article/details/123910248