diff --git a/finetune.py b/finetune.py
index 5104681..768c2c4 100644
--- a/finetune.py
+++ b/finetune.py
@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from utils.prompter import Prompter
 
+from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
 
 def train(
     # model/data params
@@ -198,7 +201,7 @@ def train(
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
             adapters_weights = torch.load(checkpoint_name)
-            model = set_peft_model_state_dict(model, adapters_weights)
+            set_peft_model_state_dict(model, adapters_weights)
         else:
             print(f"Checkpoint {checkpoint_name} not found")
 
@@ -223,6 +226,21 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    class SavePeftModelCallback(TrainerCallback):
+        def on_save(
+            self,
+            args: TrainingArguments,
+            state: TrainerState,
+            control: TrainerControl,
+            **kwargs,
+        ):
+            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            kwargs["model"].save_pretrained(checkpoint_folder)
+            pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
+            if os.path.exists(pytorch_model_path):
+                os.remove(pytorch_model_path)
+            return control
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
@@ -251,16 +269,10 @@ def train(
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
+        callbacks=[SavePeftModelCallback],
     )
     model.config.use_cache = False
 
-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
diff --git a/requirements.txt b/requirements.txt
index d253e87..4dccf77 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,13 @@
-accelerate
-appdirs
-bitsandbytes
-black
-black[jupyter]
-datasets
-fire
-git+https://github.com/huggingface/peft.git
-git+https://github.com/huggingface/transformers.git
-gradio
-sentencepiece
+accelerate==0.20.1
+appdirs==1.4.4
+bitsandbytes==0.37.2
+black==23.3.0
+black[jupyter]==23.3.0
+datasets==2.11.0
+fire==0.5.0
+peft==0.3.0
+transformers==4.30.1
+gradio==3.33.1
+sentencepiece==0.1.97
+scipy==1.10.1
 wandb