Mirrored from https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese.git
Synced 2025-12-10 00:16:49 +00:00
fix: fix checkpoint >6G problem
Showing 1 changed file with 19 additions and 0 deletions

finetune.py (+19)
@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from utils.prompter import Prompter
 
+from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
 
 def train(
     # model/data params
@@ -223,6 +226,21 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    class SavePeftModelCallback(TrainerCallback):
+        def on_save(
+            self,
+            args: TrainingArguments,
+            state: TrainerState,
+            control: TrainerControl,
+            **kwargs,
+        ):
+            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            kwargs["model"].save_pretrained(checkpoint_folder)
+            pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
+            if os.path.exists(pytorch_model_path):
+                os.remove(pytorch_model_path)
+            return control
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
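
What the added callback does: at the time of this commit, transformers' Trainer wrote the full model state dict (pytorch_model.bin, more than 6 GB for a LLaMA-7B base) into every checkpoint folder, even though only the small LoRA adapter is being trained. on_save fires after each checkpoint is written; the callback re-saves just the PEFT adapter weights with save_pretrained and deletes the redundant full dump, so each checkpoint shrinks from gigabytes to megabytes. A minimal sketch of reloading such an adapter-only checkpoint (the base-model name and checkpoint path below are assumptions for illustration, not taken from this commit):

# Sketch: load an adapter-only checkpoint produced by SavePeftModelCallback.
# The base model is loaded separately, since the checkpoint no longer
# contains the full >6G state dict.
from peft import PeftModel
from transformers import LlamaForCausalLM

base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf"  # assumed base model, not from this commit
)
model = PeftModel.from_pretrained(
    base_model,
    "output/checkpoint-200",  # hypothetical folder holding adapter weights only
)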
@@ -251,6 +269,7 @@ def train(
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
+        callbacks=[SavePeftModelCallback],
     )
     model.config.use_cache = False
 
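
Passing the class itself in callbacks is valid: Trainer accepts either TrainerCallback classes or instances and instantiates a class on its own. One way to sanity-check the fix after a training run (a sketch; the checkpoint path is hypothetical) is to list a checkpoint folder and confirm the multi-gigabyte pytorch_model.bin is gone while the small adapter files remain:

# Sketch: verify a checkpoint folder holds only small adapter files.
import os

ckpt = "output/checkpoint-200"  # hypothetical checkpoint folder
for name in sorted(os.listdir(ckpt)):
    size_mb = os.path.getsize(os.path.join(ckpt, name)) / 1e6
    print(f"{name}: {size_mb:.1f} MB")

# The full state dict removed by the callback should be absent.
assert not os.path.exists(os.path.join(ckpt, "pytorch_model.bin"))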