fix: modify adapter_model.bin size 433 problem

这个提交包含在:
sharpbai
2023-06-11 09:40:52 +08:00
父节点 3530b5e6ba
当前提交 ee806b7243

查看文件

@@ -198,7 +198,7 @@ def train(
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
             adapters_weights = torch.load(checkpoint_name)
-            model = set_peft_model_state_dict(model, adapters_weights)
+            set_peft_model_state_dict(model, adapters_weights)
         else:
             print(f"Checkpoint {checkpoint_name} not found")
@@ -254,13 +254,6 @@ def train(
     )
     model.config.use_cache = False

-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)