Extending the LLaMA Vocabulary for Chinese with SentencePiece

Published 2023-09-25 09:57:48 | Author: sunshine丶23

SentencePiece is Google's open-source text tokenizer toolkit. Its core idea is to use a statistical algorithm to learn a segmentation model directly from a raw corpus, and the resulting model can then encode text into tokens. Compared with conventional open-source word segmenters, it treats frequently occurring character sequences as vocabulary entries and segments text against that learned vocabulary, so its segmentation granularity tends to be coarser. The tokenizers of most current large language models are built on this tool.

Because the original LLaMA was trained almost entirely on English, with comparatively little Chinese data, it encodes and decodes Chinese inefficiently: a single Chinese character often ends up as several tokens. Extending the LLaMA vocabulary with Chinese tokens noticeably improves this efficiency, and since each token then covers more Chinese text, it also effectively increases how much Chinese fits into the model's fixed context window.
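To make "encode/decode efficiency" concrete, one simple measure is how many characters each tokenizer packs into a token. Below is a minimal sketch, assuming the original tokenizer path and the merged tokenizer directory produced later in this post:

from transformers import LlamaTokenizer

# Paths as used later in this post: the original tokenizer and the extended one.
orig_tok = LlamaTokenizer.from_pretrained("/mnt/models/Baichuan-7B")
new_tok = LlamaTokenizer.from_pretrained("merged_tokenizer_hf_test")

text = "麒麟,是中国古代神话中的一种瑞兽"
for name, tok in [("original", orig_tok), ("extended", new_tok)]:
    n = len(tok.tokenize(text))
    print(f"{name}: {n} tokens, {len(text) / n:.2f} chars per token")

The higher the characters-per-token ratio, the more Chinese text fits into the same number of context tokens.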


Install sentencepiece


pip install sentencepiece

Vocabulary training code:

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input='./file_name.txt', input_format='text',
    model_prefix='bpe_test', model_type='bpe', vocab_size=10000,
    character_coverage=0.9995, num_threads=32,
    split_digits=True, byte_fallback=True, max_sentence_length=24000)

Running this code generates two files in the current directory, bpe_test.model and bpe_test.vocab (training progress is logged to stderr rather than to a file).
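As a quick sanity check, a minimal sketch like the one below, using the freshly trained bpe_test.model and the same test sentence used at the end of this post, loads the model and shows how it segments Chinese text:

import sentencepiece as spm

# Load the model trained above and inspect its segmentation.
sp = spm.SentencePieceProcessor(model_file="bpe_test.model")
print(sp.vocab_size())                                           # expected: 10000
print(sp.encode("麒麟,是中国古代神话中的一种瑞兽", out_type=str))  # subword pieces
print(sp.encode("麒麟,是中国古代神话中的一种瑞兽", out_type=int))  # token ids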


Code to merge into the LLaMA vocabulary:


import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm

# Paths
llama_tokenizer_dir = "/mnt/models/Baichuan-7B"  # replace with the path to your own model
chinese_sp_model_file = "./bpe_test.model"  # the SentencePiece model trained above

# Load both tokenizers
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(chinese_sp_model_file)
llama_spm = sp_pb2_model.ModelProto()
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
chinese_spm = sp_pb2_model.ModelProto()
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())


# Print the sizes of the two vocabularies and the original LLaMA special tokens
print(len(llama_tokenizer), len(chinese_sp_model))
print(llama_tokenizer.all_special_tokens)
print(llama_tokenizer.all_special_ids)
print(llama_tokenizer.special_tokens_map)


# Add the new pieces to the LLaMA vocabulary; you can also add words of your own here,
# e.g. domain-specific terms
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
print(len(llama_spm_tokens_set))
print(f"Before:{len(llama_spm_tokens_set)}")
for p in chinese_spm.pieces:
    piece = p.piece
    if piece not in llama_spm_tokens_set:
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        llama_spm.pieces.append(new_p)
print(f"New model pieces: {len(llama_spm.pieces)}")

# Save the merged model
output_sp_dir = 'merged_tokenizer_sp_test'
output_hf_dir = 'merged_tokenizer_hf_test'
os.makedirs(output_sp_dir, exist_ok=True)
with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
    f.write(llama_spm.SerializeToString())
tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')

tokenizer.save_pretrained(output_hf_dir)
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

# Check the result
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)


text = "The excellence of a translation can only be judged by noting"
print("Test text:\n", text)
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
print(f"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}")
print(f"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}")
print(f"Tokenized length by chinese_llama tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}")

text = "麒麟,是中国古代神话中的一种瑞兽"
print("Test text:\n", text)
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
print(f"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}")
print(f"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}")
print(f"Tokenized length by chinese_llama tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}")
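
The script above only produces the extended tokenizer. To actually use it with a LLaMA model (for example, for continued pre-training on Chinese data), the model's embedding matrix and output head must also be resized to the new vocabulary size. A minimal sketch, assuming a local LLaMA checkpoint at a hypothetical path:

from transformers import LlamaForCausalLM, LlamaTokenizer

llama_model_dir = "/path/to/llama-7b"  # hypothetical path to the LLaMA weights
tokenizer = LlamaTokenizer.from_pretrained("merged_tokenizer_hf_test")

model = LlamaForCausalLM.from_pretrained(llama_model_dir)
# Grow the input embeddings and lm_head to match the extended vocabulary;
# the rows added for new tokens are randomly initialized and need further training.
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained("llama-7b-extended-vocab")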