镜像自地址
https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese.git
已同步 2025-12-05 22:16:49 +00:00
28
finetune.py
28
finetune.py
@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from utils.prompter import Prompter
|
||||
|
||||
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def train(
|
||||
# model/data params
|
||||
@@ -198,7 +201,7 @@ def train(
|
||||
if os.path.exists(checkpoint_name):
|
||||
print(f"Restarting from {checkpoint_name}")
|
||||
adapters_weights = torch.load(checkpoint_name)
|
||||
model = set_peft_model_state_dict(model, adapters_weights)
|
||||
set_peft_model_state_dict(model, adapters_weights)
|
||||
else:
|
||||
print(f"Checkpoint {checkpoint_name} not found")
|
||||
|
||||
@@ -223,6 +226,21 @@ def train(
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||
kwargs["model"].save_pretrained(checkpoint_folder)
|
||||
pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
|
||||
if os.path.exists(pytorch_model_path):
|
||||
os.remove(pytorch_model_path)
|
||||
return control
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=train_data,
|
||||
@@ -251,16 +269,10 @@ def train(
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||
),
|
||||
callbacks=[SavePeftModelCallback],
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (
|
||||
lambda self, *_, **__: get_peft_model_state_dict(
|
||||
self, old_state_dict()
|
||||
)
|
||||
).__get__(model, type(model))
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
model = torch.compile(model)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
accelerate
|
||||
appdirs
|
||||
bitsandbytes
|
||||
black
|
||||
black[jupyter]
|
||||
datasets
|
||||
fire
|
||||
git+https://github.com/huggingface/peft.git
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
gradio
|
||||
sentencepiece
|
||||
accelerate==0.20.1
|
||||
appdirs==1.4.4
|
||||
bitsandbytes==0.37.2
|
||||
black==23.3.0
|
||||
black[jupyter]==23.3.0
|
||||
datasets==2.11.0
|
||||
fire==0.5.0
|
||||
peft==0.3.0
|
||||
transformers==4.30.1
|
||||
gradio==3.33.1
|
||||
sentencepiece==0.1.97
|
||||
scipy==1.10.1
|
||||
wandb
|
||||
|
||||
在新工单中引用
屏蔽一个用户