镜像自地址
https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese.git
已同步 2025-12-06 14:36:49 +00:00
init code
这个提交包含在:
15
.gitignore
vendored
普通文件
15
.gitignore
vendored
普通文件
@@ -0,0 +1,15 @@
|
|||||||
|
out/
|
||||||
|
7B/
|
||||||
|
13B/
|
||||||
|
__pycache__/
|
||||||
|
checkpoint**
|
||||||
|
minimal-llama**
|
||||||
|
lora-llama-med
|
||||||
|
upload.py
|
||||||
|
lora-**
|
||||||
|
*.out
|
||||||
|
*result
|
||||||
|
*ckpt
|
||||||
|
wandb
|
||||||
|
todo.txt
|
||||||
|
.vscode/
|
||||||
2
LICENSE
2
LICENSE
@@ -186,7 +186,7 @@
|
|||||||
same "printed page" as the copyright notice for easier
|
same "printed page" as the copyright notice for easier
|
||||||
identification within third-party archives.
|
identification within third-party archives.
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[**中文**](./README.md) | [**English**](./README_EN.md)
|
[**中文**](./README.md) | [**English**](./README_EN.md)
|
||||||
|
|
||||||
# 华驼(HuaTuo): 基于中文医学知识的Llama微调模型
|
# 华驼(HuaTuo): 基于中文医学知识的LLaMA微调模型
|
||||||
### HuaTuo: Tuning LLaMA Model With Chinese Medical Instructions
|
### HuaTuo: Tuning LLaMA Model With Chinese Medical Instructions
|
||||||
[](https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese/blob/main/LICENSE)
|
[](https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese/blob/main/LICENSE)
|
||||||
[](https://www.python.org/downloads/release/python-390/)
|
[](https://www.python.org/downloads/release/python-390/)
|
||||||
|
|||||||
5
data/infer.json
普通文件
5
data/infer.json
普通文件
@@ -0,0 +1,5 @@
|
|||||||
|
{"instruction": "小张最近感觉身体不适,出现心悸、气促等症状。体检发现心脏扩大、搏动减弱。", "input": "", "output": "小张可能患有心肌炎,建议进行心电图和心脏超声等检查来确定诊断。治疗方案包括使用泼尼松、生脉饮和三磷酸腺苷等药物,同时建议适当控制体温,保持良好的营养状况。"}
|
||||||
|
{"instruction": "小王被确诊为肝炎双重感染,最可能的并发症是什么?", "input": "", "output": "最可能的并发症是肝癌和肝硬化。"}
|
||||||
|
{"instruction": "麻风病和儿童哮喘的病因是否一致?", "input": "", "output": "不一致,麻风病的病因是麻风分枝杆菌,而儿童哮喘的病因是气候、药物、吸入过敏原等。"}
|
||||||
|
{"instruction": "一位患者出现了持续性的胀痛和黄疸,经影像学检查后确诊为肝胆管结石病,应该采取哪些治疗方案?", "input": "", "output": "可以考虑采用残石的处理、微创保肝取石术或手术治疗等治疗方案。"}
|
||||||
|
{"instruction": "一位35岁女性患者出现了关节疼痛和糖耐量异常,请问她可能患了什么疾病?", "input": "", "output": "该患者可能患上了慢性自身免疫性胰腺炎,伴有慢性风湿性关节炎和糖耐量异常的症状。建议到消化内科进行检查和治疗。"}
|
||||||
8658
data/llama_data.json
普通文件
8658
data/llama_data.json
普通文件
文件差异内容过多而无法显示
加载差异
19
data/tmp.py
普通文件
19
data/tmp.py
普通文件
@@ -0,0 +1,19 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
# 假设数据存储在 data.json 文件中
|
||||||
|
data = []
|
||||||
|
with open('llama_data.json', 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
# 遍历每个对象,找到 output 属性不为字符串类型的对象
|
||||||
|
new_data = []
|
||||||
|
for obj in data:
|
||||||
|
if isinstance(obj['output'], str):
|
||||||
|
new_data.append(obj)
|
||||||
|
with open("llama_data_1.json","w") as f:
|
||||||
|
for n in new_data:
|
||||||
|
f.write(json.dumps(n,ensure_ascii=False))
|
||||||
|
f.write("\n")
|
||||||
57
export_hf_checkpoint.py
普通文件
57
export_hf_checkpoint.py
普通文件
@@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
||||||
|
|
||||||
|
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||||
|
assert (
|
||||||
|
BASE_MODEL
|
||||||
|
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||||
|
|
||||||
|
base_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
BASE_MODEL,
|
||||||
|
load_in_8bit=False,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map={"": "cpu"},
|
||||||
|
)
|
||||||
|
|
||||||
|
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
|
||||||
|
first_weight_old = first_weight.clone()
|
||||||
|
|
||||||
|
lora_model = PeftModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
"tloen/alpaca-lora-7b",
|
||||||
|
device_map={"": "cpu"},
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_weight = lora_model.base_model.model.model.layers[
|
||||||
|
0
|
||||||
|
].self_attn.q_proj.weight
|
||||||
|
|
||||||
|
assert torch.allclose(first_weight_old, first_weight)
|
||||||
|
|
||||||
|
# merge weights
|
||||||
|
for layer in lora_model.base_model.model.model.layers:
|
||||||
|
layer.self_attn.q_proj.merge_weights = True
|
||||||
|
layer.self_attn.v_proj.merge_weights = True
|
||||||
|
|
||||||
|
lora_model.train(False)
|
||||||
|
|
||||||
|
# did we do anything?
|
||||||
|
assert not torch.allclose(first_weight_old, first_weight)
|
||||||
|
|
||||||
|
lora_model_sd = lora_model.state_dict()
|
||||||
|
deloreanized_sd = {
|
||||||
|
k.replace("base_model.model.", ""): v
|
||||||
|
for k, v in lora_model_sd.items()
|
||||||
|
if "lora" not in k
|
||||||
|
}
|
||||||
|
|
||||||
|
LlamaForCausalLM.save_pretrained(
|
||||||
|
base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
|
||||||
|
)
|
||||||
125
export_state_dict_checkpoint.py
普通文件
125
export_state_dict_checkpoint.py
普通文件
@@ -0,0 +1,125 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: E402
|
||||||
|
|
||||||
|
BASE_MODEL = os.environ.get("BASE_MODEL", None)
|
||||||
|
assert (
|
||||||
|
BASE_MODEL
|
||||||
|
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
|
||||||
|
|
||||||
|
base_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
BASE_MODEL,
|
||||||
|
load_in_8bit=False,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map={"": "cpu"},
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_model = PeftModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
"tloen/alpaca-lora-7b",
|
||||||
|
device_map={"": "cpu"},
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
|
# merge weights
|
||||||
|
for layer in lora_model.base_model.model.model.layers:
|
||||||
|
layer.self_attn.q_proj.merge_weights = True
|
||||||
|
layer.self_attn.v_proj.merge_weights = True
|
||||||
|
|
||||||
|
lora_model.train(False)
|
||||||
|
|
||||||
|
lora_model_sd = lora_model.state_dict()
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"dim": 4096,
|
||||||
|
"multiple_of": 256,
|
||||||
|
"n_heads": 32,
|
||||||
|
"n_layers": 32,
|
||||||
|
"norm_eps": 1e-06,
|
||||||
|
"vocab_size": -1,
|
||||||
|
}
|
||||||
|
n_layers = params["n_layers"]
|
||||||
|
n_heads = params["n_heads"]
|
||||||
|
dim = params["dim"]
|
||||||
|
dims_per_head = dim // n_heads
|
||||||
|
base = 10000.0
|
||||||
|
inv_freq = 1.0 / (
|
||||||
|
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def permute(w):
|
||||||
|
return (
|
||||||
|
w.view(n_heads, dim // n_heads // 2, 2, dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(dim, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unpermute(w):
|
||||||
|
return (
|
||||||
|
w.view(n_heads, 2, dim // n_heads // 2, dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(dim, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def translate_state_dict_key(k): # noqa: C901
|
||||||
|
k = k.replace("base_model.model.", "")
|
||||||
|
if k == "model.embed_tokens.weight":
|
||||||
|
return "tok_embeddings.weight"
|
||||||
|
elif k == "model.norm.weight":
|
||||||
|
return "norm.weight"
|
||||||
|
elif k == "lm_head.weight":
|
||||||
|
return "output.weight"
|
||||||
|
elif k.startswith("model.layers."):
|
||||||
|
layer = k.split(".")[2]
|
||||||
|
if k.endswith(".self_attn.q_proj.weight"):
|
||||||
|
return f"layers.{layer}.attention.wq.weight"
|
||||||
|
elif k.endswith(".self_attn.k_proj.weight"):
|
||||||
|
return f"layers.{layer}.attention.wk.weight"
|
||||||
|
elif k.endswith(".self_attn.v_proj.weight"):
|
||||||
|
return f"layers.{layer}.attention.wv.weight"
|
||||||
|
elif k.endswith(".self_attn.o_proj.weight"):
|
||||||
|
return f"layers.{layer}.attention.wo.weight"
|
||||||
|
elif k.endswith(".mlp.gate_proj.weight"):
|
||||||
|
return f"layers.{layer}.feed_forward.w1.weight"
|
||||||
|
elif k.endswith(".mlp.down_proj.weight"):
|
||||||
|
return f"layers.{layer}.feed_forward.w2.weight"
|
||||||
|
elif k.endswith(".mlp.up_proj.weight"):
|
||||||
|
return f"layers.{layer}.feed_forward.w3.weight"
|
||||||
|
elif k.endswith(".input_layernorm.weight"):
|
||||||
|
return f"layers.{layer}.attention_norm.weight"
|
||||||
|
elif k.endswith(".post_attention_layernorm.weight"):
|
||||||
|
return f"layers.{layer}.ffn_norm.weight"
|
||||||
|
elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
print(layer, k)
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
print(k)
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
new_state_dict = {}
|
||||||
|
for k, v in lora_model_sd.items():
|
||||||
|
new_k = translate_state_dict_key(k)
|
||||||
|
if new_k is not None:
|
||||||
|
if "wq" in new_k or "wk" in new_k:
|
||||||
|
new_state_dict[new_k] = unpermute(v)
|
||||||
|
else:
|
||||||
|
new_state_dict[new_k] = v
|
||||||
|
|
||||||
|
os.makedirs("./ckpt", exist_ok=True)
|
||||||
|
|
||||||
|
torch.save(new_state_dict, "./ckpt/consolidated.00.pth")
|
||||||
|
|
||||||
|
with open("./ckpt/params.json", "w") as f:
|
||||||
|
json.dump(params, f)
|
||||||
277
finetune.py
普通文件
277
finetune.py
普通文件
@@ -0,0 +1,277 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import wandb
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unused imports:
|
||||||
|
import torch.nn as nn
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
"""
|
||||||
|
|
||||||
|
from peft import (
|
||||||
|
LoraConfig,
|
||||||
|
get_peft_model,
|
||||||
|
get_peft_model_state_dict,
|
||||||
|
prepare_model_for_int8_training,
|
||||||
|
set_peft_model_state_dict,
|
||||||
|
)
|
||||||
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
from utils.prompter import Prompter
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
# model/data params
|
||||||
|
base_model: str = "", # the only required argument
|
||||||
|
data_path: str = "yahma/alpaca-cleaned",
|
||||||
|
output_dir: str = "./lora-alpaca",
|
||||||
|
# training hyperparams
|
||||||
|
batch_size: int = 128,
|
||||||
|
micro_batch_size: int = 8,
|
||||||
|
num_epochs: int = 10,
|
||||||
|
learning_rate: float = 3e-4,
|
||||||
|
cutoff_len: int = 256,
|
||||||
|
val_set_size: int = 500,
|
||||||
|
# lora hyperparams
|
||||||
|
lora_r: int = 8,
|
||||||
|
lora_alpha: int = 16,
|
||||||
|
lora_dropout: float = 0.05,
|
||||||
|
lora_target_modules: List[str] = [
|
||||||
|
"q_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
# llm hyperparams
|
||||||
|
train_on_inputs: bool = False, # if False, masks out inputs in loss
|
||||||
|
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
||||||
|
# wandb params
|
||||||
|
wandb_project: str = "llama_med",
|
||||||
|
wandb_run_name: str = "",
|
||||||
|
wandb_watch: str = "", # options: false | gradients | all
|
||||||
|
wandb_log_model: str = "", # options: false | true
|
||||||
|
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||||
|
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
|
||||||
|
):
|
||||||
|
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
||||||
|
print(
|
||||||
|
f"Training Alpaca-LoRA model with params:\n"
|
||||||
|
f"base_model: {base_model}\n"
|
||||||
|
f"data_path: {data_path}\n"
|
||||||
|
f"output_dir: {output_dir}\n"
|
||||||
|
f"batch_size: {batch_size}\n"
|
||||||
|
f"micro_batch_size: {micro_batch_size}\n"
|
||||||
|
f"num_epochs: {num_epochs}\n"
|
||||||
|
f"learning_rate: {learning_rate}\n"
|
||||||
|
f"cutoff_len: {cutoff_len}\n"
|
||||||
|
f"val_set_size: {val_set_size}\n"
|
||||||
|
f"lora_r: {lora_r}\n"
|
||||||
|
f"lora_alpha: {lora_alpha}\n"
|
||||||
|
f"lora_dropout: {lora_dropout}\n"
|
||||||
|
f"lora_target_modules: {lora_target_modules}\n"
|
||||||
|
f"train_on_inputs: {train_on_inputs}\n"
|
||||||
|
f"group_by_length: {group_by_length}\n"
|
||||||
|
f"wandb_project: {wandb_project}\n"
|
||||||
|
f"wandb_run_name: {wandb_run_name}\n"
|
||||||
|
f"wandb_watch: {wandb_watch}\n"
|
||||||
|
f"wandb_log_model: {wandb_log_model}\n"
|
||||||
|
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
|
||||||
|
f"prompt template: {prompt_template_name}\n"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
base_model
|
||||||
|
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||||
|
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||||
|
|
||||||
|
prompter = Prompter(prompt_template_name)
|
||||||
|
|
||||||
|
device_map = "auto"
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
ddp = world_size != 1
|
||||||
|
if ddp:
|
||||||
|
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
||||||
|
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
||||||
|
|
||||||
|
# Check if parameter passed or if set within environ
|
||||||
|
use_wandb = len(wandb_project) > 0 or (
|
||||||
|
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
|
||||||
|
)
|
||||||
|
# Only overwrite environ if wandb param passed
|
||||||
|
if len(wandb_project) > 0:
|
||||||
|
os.environ["WANDB_PROJECT"] = wandb_project
|
||||||
|
if len(wandb_watch) > 0:
|
||||||
|
os.environ["WANDB_WATCH"] = wandb_watch
|
||||||
|
if len(wandb_log_model) > 0:
|
||||||
|
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
|
||||||
|
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
load_in_8bit=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map=device_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||||
|
|
||||||
|
tokenizer.pad_token_id = (
|
||||||
|
0 # unk. we want this to be different from the eos token
|
||||||
|
)
|
||||||
|
tokenizer.padding_side = "left" # Allow batched inference
|
||||||
|
|
||||||
|
def tokenize(prompt, add_eos_token=True):
|
||||||
|
# there's probably a way to do this with the tokenizer settings
|
||||||
|
# but again, gotta move fast
|
||||||
|
result = tokenizer(
|
||||||
|
prompt,
|
||||||
|
truncation=True,
|
||||||
|
max_length=cutoff_len,
|
||||||
|
padding=False,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
result["input_ids"][-1] != tokenizer.eos_token_id
|
||||||
|
and len(result["input_ids"]) < cutoff_len
|
||||||
|
and add_eos_token
|
||||||
|
):
|
||||||
|
result["input_ids"].append(tokenizer.eos_token_id)
|
||||||
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
|
result["labels"] = result["input_ids"].copy()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def generate_and_tokenize_prompt(data_point):
|
||||||
|
full_prompt = prompter.generate_prompt(
|
||||||
|
data_point["instruction"],
|
||||||
|
data_point["input"],
|
||||||
|
data_point["output"],
|
||||||
|
)
|
||||||
|
tokenized_full_prompt = tokenize(full_prompt)
|
||||||
|
if not train_on_inputs:
|
||||||
|
user_prompt = prompter.generate_prompt(
|
||||||
|
data_point["instruction"], data_point["input"]
|
||||||
|
)
|
||||||
|
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
|
||||||
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
|
|
||||||
|
tokenized_full_prompt["labels"] = [
|
||||||
|
-100
|
||||||
|
] * user_prompt_len + tokenized_full_prompt["labels"][
|
||||||
|
user_prompt_len:
|
||||||
|
] # could be sped up, probably
|
||||||
|
return tokenized_full_prompt
|
||||||
|
|
||||||
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
|
config = LoraConfig(
|
||||||
|
r=lora_r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
target_modules=lora_target_modules,
|
||||||
|
lora_dropout=lora_dropout,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
model = get_peft_model(model, config)
|
||||||
|
|
||||||
|
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
|
||||||
|
data = load_dataset("json", data_files=data_path)
|
||||||
|
else:
|
||||||
|
data = load_dataset(data_path)
|
||||||
|
|
||||||
|
if resume_from_checkpoint:
|
||||||
|
# Check the available weights and load them
|
||||||
|
checkpoint_name = os.path.join(
|
||||||
|
resume_from_checkpoint, "pytorch_model.bin"
|
||||||
|
) # Full checkpoint
|
||||||
|
if not os.path.exists(checkpoint_name):
|
||||||
|
checkpoint_name = os.path.join(
|
||||||
|
resume_from_checkpoint, "adapter_model.bin"
|
||||||
|
) # only LoRA model - LoRA config above has to fit
|
||||||
|
resume_from_checkpoint = (
|
||||||
|
False # So the trainer won't try loading its state
|
||||||
|
)
|
||||||
|
# The two files above have a different name depending on how they were saved, but are actually the same.
|
||||||
|
if os.path.exists(checkpoint_name):
|
||||||
|
print(f"Restarting from {checkpoint_name}")
|
||||||
|
adapters_weights = torch.load(checkpoint_name)
|
||||||
|
model = set_peft_model_state_dict(model, adapters_weights)
|
||||||
|
else:
|
||||||
|
print(f"Checkpoint {checkpoint_name} not found")
|
||||||
|
|
||||||
|
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
||||||
|
|
||||||
|
if val_set_size > 0:
|
||||||
|
train_val = data["train"].train_test_split(
|
||||||
|
test_size=val_set_size, shuffle=True, seed=2023
|
||||||
|
)
|
||||||
|
train_data = (
|
||||||
|
train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||||
|
)
|
||||||
|
val_data = (
|
||||||
|
train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
||||||
|
val_data = None
|
||||||
|
|
||||||
|
if not ddp and torch.cuda.device_count() > 1:
|
||||||
|
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
||||||
|
model.is_parallelizable = True
|
||||||
|
model.model_parallel = True
|
||||||
|
|
||||||
|
trainer = transformers.Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataset=train_data,
|
||||||
|
eval_dataset=val_data,
|
||||||
|
args=transformers.TrainingArguments(
|
||||||
|
per_device_train_batch_size=micro_batch_size,
|
||||||
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
|
warmup_ratio=0.1,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
fp16=True,
|
||||||
|
logging_steps=8,
|
||||||
|
optim="adamw_torch",
|
||||||
|
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_steps=32 if val_set_size > 0 else None,
|
||||||
|
save_steps=32,
|
||||||
|
output_dir=output_dir,
|
||||||
|
save_total_limit=5,
|
||||||
|
load_best_model_at_end=True if val_set_size > 0 else False,
|
||||||
|
ddp_find_unused_parameters=False if ddp else None,
|
||||||
|
group_by_length=group_by_length,
|
||||||
|
report_to="wandb" if use_wandb else None,
|
||||||
|
run_name=wandb_run_name if use_wandb else None,
|
||||||
|
),
|
||||||
|
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
|
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
|
old_state_dict = model.state_dict
|
||||||
|
model.state_dict = (
|
||||||
|
lambda self, *_, **__: get_peft_model_state_dict(
|
||||||
|
self, old_state_dict()
|
||||||
|
)
|
||||||
|
).__get__(model, type(model))
|
||||||
|
|
||||||
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
model.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"\n If there's a warning about missing keys above, please disregard :)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(train)
|
||||||
172
generate.py
普通文件
172
generate.py
普通文件
@@ -0,0 +1,172 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import gradio as gr
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
from utils.prompter import Prompter
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
try:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
except: # noqa: E722
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
load_8bit: bool = False,
|
||||||
|
base_model: str = "",
|
||||||
|
lora_weights: str = "tloen/alpaca-lora-7b",
|
||||||
|
prompt_template: str = "med_template", # The prompt template to use, will default to alpaca.
|
||||||
|
server_name: str = "0.0.0.0", # Allows to listen on all interfaces by providing '0.0.0.0'
|
||||||
|
share_gradio: bool = True,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
base_model
|
||||||
|
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
||||||
|
|
||||||
|
prompter = Prompter(prompt_template)
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||||
|
if device == "cuda":
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
load_in_8bit=load_8bit,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
lora_weights,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
elif device == "mps":
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
device_map={"": device},
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
lora_weights,
|
||||||
|
device_map={"": device},
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
base_model, device_map={"": device}, low_cpu_mem_usage=True
|
||||||
|
)
|
||||||
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
lora_weights,
|
||||||
|
device_map={"": device},
|
||||||
|
)
|
||||||
|
|
||||||
|
# unwind broken decapoda-research config
|
||||||
|
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
||||||
|
model.config.bos_token_id = 1
|
||||||
|
model.config.eos_token_id = 2
|
||||||
|
|
||||||
|
if not load_8bit:
|
||||||
|
model.half() # seems to fix bugs for some users.
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
instruction,
|
||||||
|
input=None,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.75,
|
||||||
|
top_k=40,
|
||||||
|
num_beams=4,
|
||||||
|
max_new_tokens=128,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
prompt = prompter.generate_prompt(instruction, input)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
num_beams=num_beams,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
generation_output = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
generation_config=generation_config,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
)
|
||||||
|
s = generation_output.sequences[0]
|
||||||
|
output = tokenizer.decode(s)
|
||||||
|
return prompter.get_response(output)
|
||||||
|
|
||||||
|
gr.Interface(
|
||||||
|
fn=evaluate,
|
||||||
|
inputs=[
|
||||||
|
gr.components.Textbox(
|
||||||
|
lines=2,
|
||||||
|
label="Instruction",
|
||||||
|
placeholder="Tell me about alpacas.",
|
||||||
|
),
|
||||||
|
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
|
||||||
|
gr.components.Slider(
|
||||||
|
minimum=0, maximum=1, value=0.1, label="Temperature"
|
||||||
|
),
|
||||||
|
gr.components.Slider(
|
||||||
|
minimum=0, maximum=1, value=0.75, label="Top p"
|
||||||
|
),
|
||||||
|
gr.components.Slider(
|
||||||
|
minimum=0, maximum=100, step=1, value=40, label="Top k"
|
||||||
|
),
|
||||||
|
gr.components.Slider(
|
||||||
|
minimum=1, maximum=4, step=1, value=4, label="Beams"
|
||||||
|
),
|
||||||
|
gr.components.Slider(
|
||||||
|
minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
gr.inputs.Textbox(
|
||||||
|
lines=5,
|
||||||
|
label="Output",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
title="🦙🌲 Alpaca-LoRA",
|
||||||
|
description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).", # noqa: E501
|
||||||
|
).launch(server_name=server_name, share=share_gradio)
|
||||||
|
# Old testing code follows.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# testing code for readme
|
||||||
|
for instruction in [
|
||||||
|
"Tell me about alpacas.",
|
||||||
|
"Tell me about the president of Mexico in 2019.",
|
||||||
|
"Tell me about the king of France in 2019.",
|
||||||
|
"List all Canadian provinces in alphabetical order.",
|
||||||
|
"Write a Python program that prints the first 10 Fibonacci numbers.",
|
||||||
|
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501
|
||||||
|
"Tell me five words that rhyme with 'shock'.",
|
||||||
|
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
|
||||||
|
"Count up from 1 to 500.",
|
||||||
|
]:
|
||||||
|
print("Instruction:", instruction)
|
||||||
|
print("Response:", evaluate(instruction))
|
||||||
|
print()
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
||||||
124
infer.py
普通文件
124
infer.py
普通文件
@@ -0,0 +1,124 @@
|
|||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import gradio as gr
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
from utils.prompter import Prompter
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
def load_instruction(instruct_dir):
|
||||||
|
input_data = []
|
||||||
|
with open(instruct_dir, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
d = json.loads(line)
|
||||||
|
input_data.append(d)
|
||||||
|
return input_data
|
||||||
|
|
||||||
|
|
||||||
|
def main(
|
||||||
|
load_8bit: bool = False,
|
||||||
|
base_model: str = "",
|
||||||
|
# the infer data, if not exists, infer the default instructions in code
|
||||||
|
instruct_dir: str = "",
|
||||||
|
use_lora: bool = True,
|
||||||
|
lora_weights: str = "tloen/alpaca-lora-7b",
|
||||||
|
# The prompt template to use, will default to alpaca.
|
||||||
|
prompt_template: str = "med_template",
|
||||||
|
):
|
||||||
|
prompter = Prompter(prompt_template)
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
||||||
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
load_in_8bit=load_8bit,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
if use_lora:
|
||||||
|
print(f"using lora {lora_weights}")
|
||||||
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
lora_weights,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
# unwind broken decapoda-research config
|
||||||
|
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
||||||
|
model.config.bos_token_id = 1
|
||||||
|
model.config.eos_token_id = 2
|
||||||
|
if not load_8bit:
|
||||||
|
model.half() # seems to fix bugs for some users.
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
instruction,
|
||||||
|
input=None,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.75,
|
||||||
|
top_k=40,
|
||||||
|
num_beams=4,
|
||||||
|
max_new_tokens=256,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
prompt = prompter.generate_prompt(instruction, input)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
num_beams=num_beams,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
generation_output = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
generation_config=generation_config,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
)
|
||||||
|
s = generation_output.sequences[0]
|
||||||
|
output = tokenizer.decode(s)
|
||||||
|
return prompter.get_response(output)
|
||||||
|
|
||||||
|
def infer_from_json(instruct_dir):
|
||||||
|
input_data = load_instruction(instruct_dir)
|
||||||
|
for d in input_data:
|
||||||
|
instruction = d["instruction"]
|
||||||
|
output = d["output"]
|
||||||
|
print("###infering###")
|
||||||
|
model_output = evaluate(instruction)
|
||||||
|
print("###instruction###")
|
||||||
|
print(instruction)
|
||||||
|
print("###golden output###")
|
||||||
|
print(output)
|
||||||
|
print("###model output###")
|
||||||
|
print(model_output)
|
||||||
|
|
||||||
|
if instruct_dir != "":
|
||||||
|
infer_from_json(instruct_dir)
|
||||||
|
else:
|
||||||
|
for instruction in [
|
||||||
|
"一位50岁女性出现不适、厌油腻、肝囊肿等症状,检查后发现为胆囊癌,并且病情十分严重,应该如何进行治疗?",
|
||||||
|
"一个患有肝衰竭综合征的病人,除了常见的临床表现外,还有哪些特殊的体征?",
|
||||||
|
"急性阑尾炎和缺血性心脏病的多发群体有何不同?",
|
||||||
|
]:
|
||||||
|
print("Instruction:", instruction)
|
||||||
|
print("Response:", evaluate(instruction))
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(main)
|
||||||
12
requirements.txt
普通文件
12
requirements.txt
普通文件
@@ -0,0 +1,12 @@
|
|||||||
|
accelerate
|
||||||
|
appdirs
|
||||||
|
bitsandbytes
|
||||||
|
black
|
||||||
|
black[jupyter]
|
||||||
|
datasets
|
||||||
|
fire
|
||||||
|
git+https://github.com/huggingface/peft.git
|
||||||
|
git+https://github.com/huggingface/transformers.git
|
||||||
|
gradio
|
||||||
|
sentencepiece
|
||||||
|
wandb
|
||||||
11
scripts/finetune.sh
普通文件
11
scripts/finetune.sh
普通文件
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
exp_tag="e1"
|
||||||
|
python finetune.py \
|
||||||
|
--base_model 'decapoda-research/llama-7b-hf' \
|
||||||
|
--data_path './data/llama_data.json' \
|
||||||
|
--output_dir './lora-llama-med-'$exp_tag \
|
||||||
|
--prompt_template_name 'med_template' \
|
||||||
|
--micro_batch_size 128 \
|
||||||
|
--batch_size 128 \
|
||||||
|
--wandb_run_name $exp_tag
|
||||||
12
scripts/infer.sh
普通文件
12
scripts/infer.sh
普通文件
@@ -0,0 +1,12 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# If inferring with the llama model, set 'use_lora' to 'False' and 'prompt_template' to 'ori_template'.
|
||||||
|
# If inferring with the default alpaca model, set 'use_lora' to 'True', 'lora_weights' to 'tloen/alpaca-lora-7b', and 'prompt_template' to 'alpaca'.
|
||||||
|
# If inferring with the llama-med model, download the LORA weights and set 'lora_weights' to './lora-llama-med' (or the exact directory of LORA weights) and 'prompt_template' to 'med_template'.
|
||||||
|
|
||||||
|
python infer.py \
|
||||||
|
--base_model 'decapoda-research/llama-7b-hf' \
|
||||||
|
--lora_weights './lora-llama-med' \
|
||||||
|
--use_lora True \
|
||||||
|
--instruct_dir './data/infer.json' \
|
||||||
|
--prompt_template 'med_template'
|
||||||
46
templates/README.md
普通文件
46
templates/README.md
普通文件
@@ -0,0 +1,46 @@
|
|||||||
|
# Prompt templates
|
||||||
|
|
||||||
|
This directory contains template styles for the prompts used to finetune LoRA models.
|
||||||
|
|
||||||
|
## Format
|
||||||
|
|
||||||
|
A template is described via a JSON file with the following keys:
|
||||||
|
|
||||||
|
- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
|
||||||
|
- `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders.
|
||||||
|
- `description`: A short description of the template, with possible use cases.
|
||||||
|
- `response_split`: The text to use as separator when cutting real response from the model output.
|
||||||
|
|
||||||
|
No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
|
||||||
|
|
||||||
|
## Example template
|
||||||
|
|
||||||
|
The default template, used unless otherwise specified, is `alpaca.json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"description": "Template used by Alpaca-LoRA.",
|
||||||
|
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
||||||
|
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
|
||||||
|
"response_split": "### Response:"
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Current templates
|
||||||
|
|
||||||
|
### alpaca
|
||||||
|
|
||||||
|
Default template used for generic LoRA fine tunes so far.
|
||||||
|
|
||||||
|
### alpaca_legacy
|
||||||
|
|
||||||
|
Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
|
||||||
|
|
||||||
|
### alpaca_short
|
||||||
|
|
||||||
|
A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short tempalte as well. More experiments are welcome.
|
||||||
|
|
||||||
|
### vigogne
|
||||||
|
|
||||||
|
The default alpaca template, translated to french. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning.
|
||||||
6
templates/alpaca.json
普通文件
6
templates/alpaca.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "Template used by Alpaca-LoRA.",
|
||||||
|
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
||||||
|
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
|
||||||
|
"response_split": "### Response:"
|
||||||
|
}
|
||||||
6
templates/alpaca_legacy.json
普通文件
6
templates/alpaca_legacy.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "Legacy template, used by Original Alpaca repository.",
|
||||||
|
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
|
||||||
|
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
|
||||||
|
"response_split": "### Response:"
|
||||||
|
}
|
||||||
6
templates/alpaca_short.json
普通文件
6
templates/alpaca_short.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "A shorter template to experiment with.",
|
||||||
|
"prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
||||||
|
"prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
|
||||||
|
"response_split": "### Response:"
|
||||||
|
}
|
||||||
6
templates/med_template.json
普通文件
6
templates/med_template.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "Template used by Med Instruction Tuning",
|
||||||
|
"prompt_input": "下面是一个问题,运用医学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
|
||||||
|
"prompt_no_input": "下面是一个问题,运用医学知识来正确回答提问.\n### 问题:\n{instruction}\n### 回答:\n",
|
||||||
|
"response_split": "### 回答:"
|
||||||
|
}
|
||||||
6
templates/ori_template.json
普通文件
6
templates/ori_template.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "Template used by Llama without sft",
|
||||||
|
"prompt_input": "问题:{instruction} 回答:",
|
||||||
|
"prompt_no_input": "问题:{instruction} 回答:",
|
||||||
|
"response_split": "回答:"
|
||||||
|
}
|
||||||
6
templates/vigogne.json
普通文件
6
templates/vigogne.json
普通文件
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"description": "French template, used by Vigogne for finetuning.",
|
||||||
|
"prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
|
||||||
|
"prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
|
||||||
|
"response_split": "### Réponse:"
|
||||||
|
}
|
||||||
7
utils/README.md
普通文件
7
utils/README.md
普通文件
@@ -0,0 +1,7 @@
|
|||||||
|
# Directory for helpers modules
|
||||||
|
|
||||||
|
## prompter.py
|
||||||
|
|
||||||
|
Prompter class, a template manager.
|
||||||
|
|
||||||
|
`from utils.prompter import Prompter`
|
||||||
0
utils/__init__.py
普通文件
0
utils/__init__.py
普通文件
51
utils/prompter.py
普通文件
51
utils/prompter.py
普通文件
@@ -0,0 +1,51 @@
|
|||||||
|
"""
|
||||||
|
A dedicated helper to manage templates and prompt building.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
class Prompter(object):
|
||||||
|
__slots__ = ("template", "_verbose")
|
||||||
|
|
||||||
|
def __init__(self, template_name: str = "", verbose: bool = False):
|
||||||
|
self._verbose = verbose
|
||||||
|
if not template_name:
|
||||||
|
# Enforce the default here, so the constructor can be called with '' and will not break.
|
||||||
|
template_name = "alpaca"
|
||||||
|
file_name = osp.join("templates", f"{template_name}.json")
|
||||||
|
if not osp.exists(file_name):
|
||||||
|
raise ValueError(f"Can't read {file_name}")
|
||||||
|
with open(file_name) as fp:
|
||||||
|
self.template = json.load(fp)
|
||||||
|
if self._verbose:
|
||||||
|
print(
|
||||||
|
f"Using prompt template {template_name}: {self.template['description']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None,
|
||||||
|
label: Union[None, str] = None,
|
||||||
|
) -> str:
|
||||||
|
# returns the full prompt from instruction and optional input
|
||||||
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
|
if input:
|
||||||
|
res = self.template["prompt_input"].format(
|
||||||
|
instruction=instruction, input=input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
res = self.template["prompt_no_input"].format(
|
||||||
|
instruction=instruction
|
||||||
|
)
|
||||||
|
if label:
|
||||||
|
res = f"{res}{label}"
|
||||||
|
if self._verbose:
|
||||||
|
print(res)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_response(self, output: str) -> str:
|
||||||
|
return output.split(self.template["response_split"])[1].strip()
|
||||||
在新工单中引用
屏蔽一个用户