diff --git a/finetune.py b/finetune.py index 5104681..d696e79 100644 --- a/finetune.py +++ b/finetune.py @@ -198,7 +198,7 @@ def train( if os.path.exists(checkpoint_name): print(f"Restarting from {checkpoint_name}") adapters_weights = torch.load(checkpoint_name) - model = set_peft_model_state_dict(model, adapters_weights) + set_peft_model_state_dict(model, adapters_weights) else: print(f"Checkpoint {checkpoint_name} not found") @@ -254,13 +254,6 @@ def train( ) model.config.use_cache = False - old_state_dict = model.state_dict - model.state_dict = ( - lambda self, *_, **__: get_peft_model_state_dict( - self, old_state_dict() - ) - ).__get__(model, type(model)) - if torch.__version__ >= "2" and sys.platform != "win32": model = torch.compile(model)