Merge pull request #62 from sharpbai/main

Fix installation and adapter_model.bin saving errors
This commit is contained in:
Haochun Wang
2023-06-11 21:38:12 +08:00
Committed by GitHub
Commit 9ead770cc7
2 changed files with 32 additions and 19 deletions


@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 from utils.prompter import Prompter
+from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+

 def train(
     # model/data params
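
The new imports support the checkpoint callback added further down: TrainerCallback, TrainingArguments, TrainerState, and TrainerControl are the types its on_save hook receives, and PREFIX_CHECKPOINT_DIR is the constant transformers uses when naming checkpoint folders. A quick sketch of what that constant expands to:

    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    # PREFIX_CHECKPOINT_DIR is the literal string "checkpoint", so the Trainer
    # writes each checkpoint to a folder named checkpoint-<global_step>.
    print(PREFIX_CHECKPOINT_DIR)           # checkpoint
    print(f"{PREFIX_CHECKPOINT_DIR}-200")  # checkpoint-200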
@@ -198,7 +201,7 @@ def train(
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
             adapters_weights = torch.load(checkpoint_name)
-            model = set_peft_model_state_dict(model, adapters_weights)
+            set_peft_model_state_dict(model, adapters_weights)
         else:
             print(f"Checkpoint {checkpoint_name} not found")
@@ -223,6 +226,21 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True

+    class SavePeftModelCallback(TrainerCallback):
+        def on_save(
+            self,
+            args: TrainingArguments,
+            state: TrainerState,
+            control: TrainerControl,
+            **kwargs,
+        ):
+            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            kwargs["model"].save_pretrained(checkpoint_folder)
+
+            pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
+            if os.path.exists(pytorch_model_path):
+                os.remove(pytorch_model_path)
+            return control
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
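
SavePeftModelCallback is the core of the saving fix. Trainer fires on_save right after writing a checkpoint-<step> folder; the callback then calls save_pretrained on the PeftModel, which emits only the small adapter files, and deletes the full pytorch_model.bin that Trainer itself wrote, which for a LoRA run is redundant. (Passing the class itself in callbacks=[SavePeftModelCallback], as done below, is fine; Trainer instantiates callback classes.) A hedged sketch of what a checkpoint folder should contain afterwards, with an illustrative path:

    import os

    # Illustrative path; real runs use args.output_dir from the training config.
    ckpt = os.path.join("output", "checkpoint-200")

    # With the callback active, the folder should hold the small adapter files
    # plus Trainer state, and no multi-gigabyte pytorch_model.bin.
    for name in sorted(os.listdir(ckpt)):
        print(name)
    # Expected, roughly: adapter_config.json, adapter_model.bin,
    # optimizer.pt, rng_state.pth, scheduler.pt, trainer_state.json, ...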
@@ -251,16 +269,10 @@ def train(
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
+        callbacks=[SavePeftModelCallback],
     )
     model.config.use_cache = False

-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
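
The deleted block had monkey-patched model.state_dict so that Trainer's normal save path would emit only the PEFT weights. With the newer pinned libraries that trick apparently broke down and produced the adapter_model.bin storage error named in the PR title (an adapter file too small to hold real weights), so the work moves into the callback above and the patch is dropped. A hedged after-training sanity check, where path and threshold are illustrative:

    import os

    # A healthy LoRA adapter is typically megabytes; the failure mode was a
    # near-empty file. Both the path and the 1 MB threshold are assumptions.
    adapter_path = os.path.join("output", "adapter_model.bin")
    size_mb = os.path.getsize(adapter_path) / 1e6
    print(f"{adapter_path}: {size_mb:.1f} MB")
    assert size_mb > 1, "adapter_model.bin looks empty; saving is still broken"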


@@ -1,12 +1,13 @@
-accelerate
-appdirs
-bitsandbytes
-black
-black[jupyter]
-datasets
-fire
-git+https://github.com/huggingface/peft.git
-git+https://github.com/huggingface/transformers.git
-gradio
-sentencepiece
+accelerate==0.20.1
+appdirs==1.4.4
+bitsandbytes==0.37.2
+black==23.3.0
+black[jupyter]==23.3.0
+datasets==2.11.0
+fire==0.5.0
+peft==0.3.0
+transformers==4.30.1
+gradio==3.33.1
+sentencepiece==0.1.97
+scipy==1.10.1
 wandb
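
This is the installation half of the fix: the unpinned git-HEAD installs of peft and transformers are replaced with fixed releases, and every other dependency gets an exact version, so a fresh environment reproduces the library behavior the saving code above depends on. A quick sketch for checking that an environment matches the pins:

    # Both packages expose __version__; the saving fix assumes these releases.
    import peft
    import transformers

    print(peft.__version__)          # expected: 0.3.0
    print(transformers.__version__)  # expected: 4.30.1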