Merge pull request #1 from SCIR-HI/main

同步master
这个提交包含在:
FlowolfzzZ
2023-07-28 12:43:41 +08:00
提交者 GitHub
当前提交 d779da8795
共有 3 个文件被更改,包括 35 次插入20 次删除

查看文件

@@ -227,7 +227,9 @@ https://wandb.ai/thinksoso/llama_med/runs/a5wgcnzt/overview?workspace=user-think
4. Q: 模型运行的结果不同、效果有限 4. Q: 模型运行的结果不同、效果有限
A: 由于生成模型生成多样性的考量,多次运行的结果可能会有差异。当前开源的模型由于LLaMA及Alpaca中文语料有限,且知识结合的方式较为粗糙,目前我们在进行相关改进研究,完成后欢迎大家的关注。 A: 由于生成模型生成多样性的考量,多次运行的结果可能会有差异。当前开源的模型由于LLaMA及Alpaca中文语料有限,且知识结合的方式较为粗糙,目前我们在进行相关改进研究,完成后欢迎大家的关注。
5. Q: 模型无法运行/推理内容完全无法接受
A: 请确定已安装requirements中的依赖、配置好cuda环境并添加环境变量、正确输入下载好的模型以及lora的存储位置;推理内容如存在重复生成或部分错误内容属于llama-based模型的偶发现象,与llama模型的中文能力、训练数据规模以及超参设置均有一定的关系,未来我们会不断迭代缓解此问题。如存在严重问题,请将运行的文件名、模型名、lora等配置信息详细描述在issue中,谢谢大家。

查看文件

@@ -25,6 +25,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
from utils.prompter import Prompter from utils.prompter import Prompter
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
def train( def train(
# model/data params # model/data params
@@ -198,7 +201,7 @@ def train(
if os.path.exists(checkpoint_name): if os.path.exists(checkpoint_name):
print(f"Restarting from {checkpoint_name}") print(f"Restarting from {checkpoint_name}")
adapters_weights = torch.load(checkpoint_name) adapters_weights = torch.load(checkpoint_name)
model = set_peft_model_state_dict(model, adapters_weights) set_peft_model_state_dict(model, adapters_weights)
else: else:
print(f"Checkpoint {checkpoint_name} not found") print(f"Checkpoint {checkpoint_name} not found")
@@ -223,6 +226,21 @@ def train(
model.is_parallelizable = True model.is_parallelizable = True
model.model_parallel = True model.model_parallel = True
class SavePeftModelCallback(TrainerCallback):
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
kwargs["model"].save_pretrained(checkpoint_folder)
pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
if os.path.exists(pytorch_model_path):
os.remove(pytorch_model_path)
return control
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=model, model=model,
train_dataset=train_data, train_dataset=train_data,
@@ -251,16 +269,10 @@ def train(
data_collator=transformers.DataCollatorForSeq2Seq( data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
), ),
callbacks=[SavePeftModelCallback],
) )
model.config.use_cache = False model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32": if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model) model = torch.compile(model)

查看文件

@@ -1,12 +1,13 @@
accelerate accelerate==0.20.1
appdirs appdirs==1.4.4
bitsandbytes bitsandbytes==0.37.2
black black==23.3.0
black[jupyter] black[jupyter]==23.3.0
datasets datasets==2.11.0
fire fire==0.5.0
git+https://github.com/huggingface/peft.git peft==0.3.0
git+https://github.com/huggingface/transformers.git transformers==4.30.1
gradio gradio==3.33.1
sentencepiece sentencepiece==0.1.97
scipy==1.10.1
wandb wandb