基于TigerBot-13b训练其函数调用能力

发布时间 2023-12-23 16:06:32作者: AlphaInf

写在前面

原生的tigerbot似乎并不支持函数调用,于是我来支持一下

 

数据集

我在huggingface上找了个英文的数据集

https://huggingface.co/datasets/sadmoseby/sample-function-call

这里面包含了1k组的函数调用,这个数据集的特点如下:

1. 包含有单个/多个/没有函数调用的情形

2. 描述函数的json_schema与OpenAI格式的一致(但多函数情况下,并没有用列表框起来)

3. 数据虽然是多轮对话的数据,但是每一个都是一整条的数据,且每个的开头与tigerbot的头不太一致

 

数据转换

我写了一个数据转换的代码,具体任务如下:

1. 将多个函数时没有用列表格式框选的情况给修复了

2. 切分为了多轮的对话,有多条训练数据

3. 修改了开头的情况

代码如下

  1 import re
  2 
  3 import json
  4 import re
  5 
  6 # system_prompt中可能有多个函数,多个函数的话要转为标准的[]格式
  7 def get_function_json(input_string):
  8     # 使用正则表达式分割字符串,找出独立的 JSON 字符串
  9     json_strings = re.findall(r'\{[\s\S]+?\}\s*(?=\{|$)', input_string)
 10 
 11     # 解析每个 JSON 字符串并把它们加入到列表中
 12     json_objects = []
 13     for json_str in json_strings:
 14         input_string = input_string.replace(json_str, '')
 15         try:
 16             json_obj = json.loads(json_str)
 17             json_objects.append(json_obj)
 18         except json.JSONDecodeError as e:
 19             print(f"Error decoding JSON: {e}")
 20     # 打印结果或进行其他操作
 21     if json_objects:
 22         return input_string + json.dumps(json_objects, ensure_ascii=False, indent=4)
 23     else:
 24         return input_string
 25 
 26 # 切分读入的数据
 27 def split_string_with_keywords(s, keywords):
 28     # 将关键词列表转化为正则表达式,使用括号捕获分隔符
 29     # 比如 ['system', 'assistant'] 会被转换成 (system)|(assistant)
 30     regex_pattern = '({})'.format('|'.join(map(re.escape, keywords)))
 31 
 32     # 使用 re.split,它会返回包含分隔符的列表
 33     parts = re.split(regex_pattern, s)
 34 
 35     # 初始化结果列表
 36     result = []
 37 
 38     # 存储上一个匹配到的关键词,初始时没有关键词
 39     last_keyword = None
 40 
 41     # 遍历分割后的列表
 42     for part in parts:
 43         # 如果当前部分是关键词,记录下来并继续下一轮循环
 44         if part in keywords:
 45             last_keyword = part
 46             continue
 47         # 如果当前部分不是关键词,且上一部分是关键词,则将其作为结果加入
 48         if last_keyword:
 49             result.append((last_keyword, part.strip()))
 50             last_keyword = None  # 重置关键词
 51 
 52     return result
 53 
 54 max_len = 0
 55 
 56 
 57 def count_words_and_punctuation(s):
 58     # 使用正则表达式来匹配单词和标点符号
 59     # \w+ 匹配单词字符(字母、数字、下划线)出现一次或多次组成的单词
 60     # | 表示或,用来分隔不同的匹配规则
 61     # \s 表示空白字符
 62     # [^\w\s] 匹配任意不是单词字符和不是空白字符的字符,即标点符号
 63     matches = re.findall(r'\w+|[^\w\s]', s)
 64 
 65     # 计算匹配项的数量,即单词和标点符号的总数
 66     return len(matches)
 67 
 68 def solve(input):
 69     global max_len
 70     max_len = max(max_len , count_words_and_punctuation(input))
 71     import json
 72     # 基础替换
 73     input = input.replace('<|endoftext|>', '')
 74 
 75     replace_map = {
 76         'SYSTEM:' : '\n\n### System:\n ',
 77         'ASSISTANT:': '\n\n### Response:\n ',
 78         'USER:': '\n\n### Instruction:\n ',
 79         'FUNCTION RESPONSE:': '\n\n### Function:\n '
 80     }
 81 
 82     data = split_string_with_keywords(input, list(replace_map.keys()))
 83 
 84     # 更换函数的格式
 85     if data[0][0] == 'SYSTEM:':
 86         data[0] = (data[0][0], get_function_json(data[0][1]))
 87 
 88     return_data = []
 89     train_str = ''
 90     for element in data:
 91         train_str += replace_map[element[0]]
 92         if element[0] == 'ASSISTANT:':
 93             return_data.append({
 94                 "instruction": train_str,
 95                 "input": "",
 96                 "output": element[1]
 97             })
 98         train_str += element[1]
 99 
100     return return_data
101 
102 import pandas as pd
103 
104 train_data = []
105 
106 # 读取Parquet文件
107 df = pd.read_parquet('train-00000-of-00001.parquet')
108 column_name = df.columns[0]
109 for value in df[column_name]:
110     train_data += solve(value)
111 
112 with open('train_function_call.json', 'w', encoding='utf-8') as f:
113     json.dump(train_data, f, ensure_ascii=False, indent=4)
114 print(max_len)

 

启动训练

笔者依然在恒源云上,基于tigerbot-13b-chat-v5-4k进行训练。

考虑到vllm暂时不支持PEFT格式的adapter,此次依然采用了freeze训练。

为了尽可能地训练更多的层,笔者采用了单个A100-80G的显卡,这样可以在seq_len达到3072的情况下,训练10层的tranformer参数。

注意,此次的template和以前不太一样(因为有各种的function和自己添加的system),所以添加了一个新的模板

 1 register_template(
 2     name="null",
 3     prefix=[
 4         ""
 5     ],
 6     prompt=[
 7         "{{query}}"
 8     ],
 9     system="",
10     sep=[]
11 )

训练命令如下

 1 python src/train_bash.py \
 2     --stage sft \
 3     --model_name_or_path /hy-tmp/tigerbot-13b-chat-v5-4k \
 4     --do_train True \
 5     --finetuning_type freeze \
 6     --num_layer_trainable 10 \
 7     --template null \
 8     --dataset_dir data \
 9     --dataset train_function_call \
10     --cutoff_len 3072 \
11     --learning_rate 1e-4 \
12     --num_train_epochs 1.0 \
13     --per_device_train_batch_size 4 \
14     --gradient_accumulation_steps 2 \
15     --logging_steps 1 \
16     --save_steps 10000 \
17     --output_dir /hy-tmp/tigerbot-13b-function-call \
18     --fp16 True \
19     --plot_loss True \
20     --overwrite_output_dir