diff --git a/README.md b/README.md index a84cb49..e059b10 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,9 @@ https://wandb.ai/thinksoso/llama_med/runs/a5wgcnzt/overview?workspace=user-think 4. Q: 模型运行的结果不同、效果有限 A: 由于生成模型生成多样性的考量,多次运行的结果可能会有差异。当前开源的模型由于LLaMA及Alpaca中文语料有限,且知识结合的方式较为粗糙,目前我们在进行相关改进研究,完成后欢迎大家的关注。 - +5. Q: 模型无法运行/推理内容完全无法接受 + + A: 请确认已安装requirements中的依赖、配置好CUDA环境并添加环境变量、正确输入下载好的模型以及LoRA的存储位置;推理内容如存在重复生成或部分错误,属于LLaMA-based模型的偶发现象,与LLaMA模型的中文能力、训练数据规模以及超参设置均有一定的关系,未来我们会不断迭代以缓解此问题。如存在严重问题,请将运行的文件名、模型名、LoRA等配置信息详细描述在issue中,谢谢大家。 diff --git a/finetune.py b/finetune.py index 5104681..768c2c4 100644 --- a/finetune.py +++ b/finetune.py @@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer from utils.prompter import Prompter +from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + def train( # model/data params @@ -198,7 +201,7 @@ def train( if os.path.exists(checkpoint_name): print(f"Restarting from {checkpoint_name}") adapters_weights = torch.load(checkpoint_name) - model = set_peft_model_state_dict(model, adapters_weights) + set_peft_model_state_dict(model, adapters_weights) else: print(f"Checkpoint {checkpoint_name} not found") @@ -223,6 +226,21 @@ def train( model.is_parallelizable = True model.model_parallel = True + class SavePeftModelCallback(TrainerCallback): + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") + kwargs["model"].save_pretrained(checkpoint_folder) + pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin") + if os.path.exists(pytorch_model_path): + os.remove(pytorch_model_path) + return control + trainer = transformers.Trainer( model=model, train_dataset=train_data, @@ -251,16 +269,10 @@ def train( 
data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True ), + callbacks=[SavePeftModelCallback], ) model.config.use_cache = False - old_state_dict = model.state_dict - model.state_dict = ( - lambda self, *_, **__: get_peft_model_state_dict( - self, old_state_dict() - ) - ).__get__(model, type(model)) - if torch.__version__ >= "2" and sys.platform != "win32": model = torch.compile(model) diff --git a/requirements.txt b/requirements.txt index d253e87..4dccf77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ -accelerate -appdirs -bitsandbytes -black -black[jupyter] -datasets -fire -git+https://github.com/huggingface/peft.git -git+https://github.com/huggingface/transformers.git -gradio -sentencepiece +accelerate==0.20.1 +appdirs==1.4.4 +bitsandbytes==0.37.2 +black==23.3.0 +black[jupyter]==23.3.0 +datasets==2.11.0 +fire==0.5.0 +peft==0.3.0 +transformers==4.30.1 +gradio==3.33.1 +sentencepiece==0.1.97 +scipy==1.10.1 wandb