写在前面
原生的tigerbot似乎并不支持函数调用,于是我来支持一下
数据集
我在huggingface上找了个英文的数据集
https://huggingface.co/datasets/sadmoseby/sample-function-call
这里面包含了1k组的函数调用,这个数据集的特点如下:
1. 包含有单个/多个/没有函数调用的情形
2. 描述函数的json_schema与OpenAI格式的一致(但多函数情况下,并没有用列表框起来)
3. 数据虽然是多轮对话的数据,但是每一个都是一整条的数据,且每个的开头与tigerbot的头不太一致
数据转换
我写了一个数据转换的代码,具体任务如下:
1. 将多个函数时没有用列表格式框选的情况给修复了
2. 切分为了多轮的对话,有多条训练数据
3. 修改了开头的情况
代码如下
1 import re 2 3 import json 4 import re 5 6 # system_prompt中可能有多个函数,多个函数的话要转为标准的[]格式 7 def get_function_json(input_string): 8 # 使用正则表达式分割字符串,找出独立的 JSON 字符串 9 json_strings = re.findall(r'\{[\s\S]+?\}\s*(?=\{|$)', input_string) 10 11 # 解析每个 JSON 字符串并把它们加入到列表中 12 json_objects = [] 13 for json_str in json_strings: 14 input_string = input_string.replace(json_str, '') 15 try: 16 json_obj = json.loads(json_str) 17 json_objects.append(json_obj) 18 except json.JSONDecodeError as e: 19 print(f"Error decoding JSON: {e}") 20 # 打印结果或进行其他操作 21 if json_objects: 22 return input_string + json.dumps(json_objects, ensure_ascii=False, indent=4) 23 else: 24 return input_string 25 26 # 切分读入的数据 27 def split_string_with_keywords(s, keywords): 28 # 将关键词列表转化为正则表达式,使用括号捕获分隔符 29 # 比如 ['system', 'assistant'] 会被转换成 (system)|(assistant) 30 regex_pattern = '({})'.format('|'.join(map(re.escape, keywords))) 31 32 # 使用 re.split,它会返回包含分隔符的列表 33 parts = re.split(regex_pattern, s) 34 35 # 初始化结果列表 36 result = [] 37 38 # 存储上一个匹配到的关键词,初始时没有关键词 39 last_keyword = None 40 41 # 遍历分割后的列表 42 for part in parts: 43 # 如果当前部分是关键词,记录下来并继续下一轮循环 44 if part in keywords: 45 last_keyword = part 46 continue 47 # 如果当前部分不是关键词,且上一部分是关键词,则将其作为结果加入 48 if last_keyword: 49 result.append((last_keyword, part.strip())) 50 last_keyword = None # 重置关键词 51 52 return result 53 54 max_len = 0 55 56 57 def count_words_and_punctuation(s): 58 # 使用正则表达式来匹配单词和标点符号 59 # \w+ 匹配单词字符(字母、数字、下划线)出现一次或多次组成的单词 60 # | 表示或,用来分隔不同的匹配规则 61 # \s 表示空白字符 62 # [^\w\s] 匹配任意不是单词字符和不是空白字符的字符,即标点符号 63 matches = re.findall(r'\w+|[^\w\s]', s) 64 65 # 计算匹配项的数量,即单词和标点符号的总数 66 return len(matches) 67 68 def solve(input): 69 global max_len 70 max_len = max(max_len , count_words_and_punctuation(input)) 71 import json 72 # 基础替换 73 input = input.replace('<|endoftext|>', '') 74 75 replace_map = { 76 'SYSTEM:' : '\n\n### System:\n ', 77 'ASSISTANT:': '\n\n### Response:\n ', 78 'USER:': '\n\n### Instruction:\n ', 79 'FUNCTION RESPONSE:': '\n\n### Function:\n ' 80 } 81 82 data = split_string_with_keywords(input, list(replace_map.keys())) 83 84 # 更换函数的格式 85 if data[0][0] == 'SYSTEM:': 86 data[0] = (data[0][0], get_function_json(data[0][1])) 87 88 return_data = [] 89 train_str = '' 90 for element in data: 91 train_str += replace_map[element[0]] 92 if element[0] == 'ASSISTANT:': 93 return_data.append({ 94 "instruction": train_str, 95 "input": "", 96 "output": element[1] 97 }) 98 train_str += element[1] 99 100 return return_data 101 102 import pandas as pd 103 104 train_data = [] 105 106 # 读取Parquet文件 107 df = pd.read_parquet('train-00000-of-00001.parquet') 108 column_name = df.columns[0] 109 for value in df[column_name]: 110 train_data += solve(value) 111 112 with open('train_function_call.json', 'w', encoding='utf-8') as f: 113 json.dump(train_data, f, ensure_ascii=False, indent=4) 114 print(max_len)
启动训练
笔者依然在恒源云上,基于tigerbot-13b-chat-v5-4k进行训练。
考虑到vllm暂时不支持PEFT格式的adapter,此次依然采用了freeze训练。
为了尽可能地训练更多的层,笔者采用了单个A100-80G的显卡,这样可以在seq_len达到3072的情况下,训练10层的tranformer参数。
注意,此次的template和以前不太一样(因为有各种的function和自己添加的system),所以添加了一个新的模板
1 register_template( 2 name="null", 3 prefix=[ 4 "" 5 ], 6 prompt=[ 7 "{{query}}" 8 ], 9 system="", 10 sep=[] 11 )
训练命令如下
1 python src/train_bash.py \ 2 --stage sft \ 3 --model_name_or_path /hy-tmp/tigerbot-13b-chat-v5-4k \ 4 --do_train True \ 5 --finetuning_type freeze \ 6 --num_layer_trainable 10 \ 7 --template null \ 8 --dataset_dir data \ 9 --dataset train_function_call \ 10 --cutoff_len 3072 \ 11 --learning_rate 1e-4 \ 12 --num_train_epochs 1.0 \ 13 --per_device_train_batch_size 4 \ 14 --gradient_accumulation_steps 2 \ 15 --logging_steps 1 \ 16 --save_steps 10000 \ 17 --output_dir /hy-tmp/tigerbot-13b-function-call \ 18 --fp16 True \ 19 --plot_loss True \ 20 --overwrite_output_dir