Mirrored from
https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese.git
Synced 2025-12-06 06:26:48 +00:00
README.md

@@ -227,7 +227,9 @@ https://wandb.ai/thinksoso/llama_med/runs/a5wgcnzt/overview?workspace=user-think
 
 4. Q: Results differ between runs and the quality is limited
 
 A: Because generative models sample for output diversity, repeated runs may produce different results. The quality of the currently released model is constrained by the limited Chinese corpora behind LLaMA and Alpaca, and by the fairly crude way knowledge is incorporated; we are working on improvements and welcome your continued attention once they are released.
 
+5. Q: The model will not run / the inference output is completely unacceptable
+
+A: Please make sure the dependencies in requirements are installed, the CUDA environment is configured with its environment variables set, and the storage locations of the downloaded model and the LoRA weights are entered correctly. Repetitive generation or partially incorrect content during inference is an occasional behavior of llama-based models, related to LLaMA's Chinese ability, the training data scale, and the hyperparameter settings; we will keep iterating to mitigate this. If you hit a serious problem, please describe the script name, model name, LoRA weights, and other configuration details in an issue. Thank you.
 
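The checklist in FAQ 5 (dependencies, CUDA, model and LoRA paths) can be verified before filing an issue. Below is a minimal pre-flight sketch, assuming a CUDA build of PyTorch; the two directory names are hypothetical placeholders for wherever the base model and LoRA weights were actually downloaded, not paths mandated by the repo.

# Quick pre-flight check for the failure modes listed in FAQ 5.
# NOTE: "./base_model" and "./lora-weights" are hypothetical placeholders.
import os

import torch

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"device: {torch.cuda.get_device_name(0)}")

for path in ("./base_model", "./lora-weights"):
    state = "found" if os.path.isdir(path) else "MISSING"
    print(f"{path}: {state}")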
finetune.py (28 changed lines)
@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from utils.prompter import Prompter
 
+from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
 
 def train(
     # model/data params
@@ -198,7 +201,7 @@ def train(
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
             adapters_weights = torch.load(checkpoint_name)
-            model = set_peft_model_state_dict(model, adapters_weights)
+            set_peft_model_state_dict(model, adapters_weights)
         else:
             print(f"Checkpoint {checkpoint_name} not found")
 
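The dropped assignment is the point of this hunk: in recent peft releases (including, as far as I can tell, the 0.3.0 pinned below), set_peft_model_state_dict loads the adapter weights into the model in place and no longer returns the model, so keeping `model = set_peft_model_state_dict(...)` would replace the trained PeftModel with the function's return value. A sketch of the corrected resume logic, with the checkpoint path left as a caller-supplied placeholder:

# Sketch of resuming LoRA training from a Trainer checkpoint (assumes peft
# and torch are installed; `checkpoint_name` is a placeholder path).
import os

import torch
from peft import set_peft_model_state_dict


def resume_adapter(model, checkpoint_name: str):
    """Load saved adapter weights into `model` in place, if they exist."""
    if os.path.exists(checkpoint_name):
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        # Mutates `model`; the return value is deliberately ignored.
        set_peft_model_state_dict(model, adapters_weights)
    else:
        print(f"Checkpoint {checkpoint_name} not found")
    return model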
@@ -223,6 +226,21 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    class SavePeftModelCallback(TrainerCallback):
+        def on_save(
+            self,
+            args: TrainingArguments,
+            state: TrainerState,
+            control: TrainerControl,
+            **kwargs,
+        ):
+            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            kwargs["model"].save_pretrained(checkpoint_folder)
+            pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
+            if os.path.exists(pytorch_model_path):
+                os.remove(pytorch_model_path)
+            return control
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
@@ -251,16 +269,10 @@ def train(
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
+        callbacks=[SavePeftModelCallback],
     )
     model.config.use_cache = False
 
-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
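Taken together, the finetune.py hunks swap one checkpointing mechanism for another. The old code monkeypatched model.state_dict so that the Trainer's own saves contained only the LoRA weights; the new SavePeftModelCallback instead lets on_save fire after each checkpoint, writes the adapter via save_pretrained, and deletes the full pytorch_model.bin the unpatched Trainer now produces, so checkpoints stay adapter-sized. A standalone sketch of the callback's effect on disk, using a hypothetical StubPeftModel rather than a real training run:

# Standalone sketch of what SavePeftModelCallback.on_save leaves on disk.
# "StubPeftModel" is a hypothetical stand-in for a real peft PeftModel.
import os
import tempfile


class StubPeftModel:
    def save_pretrained(self, folder):
        # A real PeftModel writes adapter_config.json plus adapter_model.bin.
        os.makedirs(folder, exist_ok=True)
        open(os.path.join(folder, "adapter_model.bin"), "wb").close()


def on_save_effect(output_dir, global_step, model):
    # Mirrors the callback body: re-save the adapter, drop the full dump.
    folder = os.path.join(output_dir, f"checkpoint-{global_step}")
    model.save_pretrained(folder)
    full_dump = os.path.join(folder, "pytorch_model.bin")
    if os.path.exists(full_dump):
        os.remove(full_dump)  # keeps the checkpoint LoRA-sized


with tempfile.TemporaryDirectory() as out:
    step_dir = os.path.join(out, "checkpoint-10")
    os.makedirs(step_dir)
    # Simulate the full-model file the unpatched Trainer writes first.
    open(os.path.join(step_dir, "pytorch_model.bin"), "wb").close()
    on_save_effect(out, 10, StubPeftModel())
    print(sorted(os.listdir(step_dir)))  # -> ['adapter_model.bin']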
requirements.txt

@@ -1,12 +1,13 @@
-accelerate
-appdirs
-bitsandbytes
-black
-black[jupyter]
-datasets
-fire
-git+https://github.com/huggingface/peft.git
-git+https://github.com/huggingface/transformers.git
-gradio
-sentencepiece
+accelerate==0.20.1
+appdirs==1.4.4
+bitsandbytes==0.37.2
+black==23.3.0
+black[jupyter]==23.3.0
+datasets==2.11.0
+fire==0.5.0
+peft==0.3.0
+transformers==4.30.1
+gradio==3.33.1
+sentencepiece==0.1.97
+scipy==1.10.1
 wandb
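Pinning exact releases, and replacing the two git+https HEAD installs with the peft 0.3.0 and transformers 4.30.1 wheels, makes the training environment reproducible. A small sketch for double-checking that installed versions match the pins; the dict below copies a few of the pins above and is not an exhaustive list:

# Verify a few of the pins from requirements.txt against the live environment.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "peft": "0.3.0",
    "transformers": "4.30.1",
    "bitsandbytes": "0.37.2",
    "datasets": "2.11.0",
    "accelerate": "0.20.1",
}

for pkg, wanted in PINS.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (want {wanted})")
        continue
    mark = "ok" if installed == wanted else f"MISMATCH (want {wanted})"
    print(f"{pkg}=={installed}  {mark}")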