diff --git a/finetune.py b/finetune.py
index d696e79..768c2c4 100644
--- a/finetune.py
+++ b/finetune.py
@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from utils.prompter import Prompter
 
+from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
 
 def train(
     # model/data params
@@ -223,6 +226,21 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    class SavePeftModelCallback(TrainerCallback):
+        def on_save(
+            self,
+            args: TrainingArguments,
+            state: TrainerState,
+            control: TrainerControl,
+            **kwargs,
+        ):
+            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            kwargs["model"].save_pretrained(checkpoint_folder)
+            pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
+            if os.path.exists(pytorch_model_path):
+                os.remove(pytorch_model_path)
+            return control
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
@@ -251,6 +269,7 @@ def train(
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
+        callbacks=[SavePeftModelCallback],
     )
 
     model.config.use_cache = False
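
With this callback, each `checkpoint-<step>` folder holds only the PEFT adapter files (`adapter_model.bin` / `adapter_config.json`), since the full `pytorch_model.bin` is deleted after saving. A minimal sketch of loading such a checkpoint, assuming `peft` is installed and using placeholder paths (`decapoda-research/llama-7b-hf`, `./lora-alpaca/checkpoint-200`) that are not part of the patch:

```python
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM

# Load the base model first (placeholder model id).
base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    torch_dtype=torch.float16,
)

# Attach the LoRA adapter saved by SavePeftModelCallback; the checkpoint
# folder contains only adapter weights, not a full pytorch_model.bin.
model = PeftModel.from_pretrained(base_model, "./lora-alpaca/checkpoint-200")
```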