Update Huozi-based model

Major update. Please try our new Huozi-based model, which is much better.
This commit is contained in:
s65b40
2023-08-07 21:46:05 +08:00
parent b51d25e1ee
commit 5ae846fb74
6 changed files with 81 additions and 82 deletions

@@ -21,7 +21,7 @@ from peft import (
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils.prompter import Prompter
@@ -111,14 +111,14 @@ def train(
     if len(wandb_log_model) > 0:
         os.environ["WANDB_LOG_MODEL"] = wandb_log_model
-    model = LlamaForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
         torch_dtype=torch.float16,
         device_map=device_map,
     )
-    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    tokenizer = AutoTokenizer.from_pretrained(base_model)
     tokenizer.pad_token_id = (
         0  # unk. we want this to be different from the eos token
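
This pair of changes is the whole migration: the Auto* classes read the checkpoint's config.json and dispatch to the concrete architecture class registered for it, so the script no longer assumes a Llama checkpoint and can load the new Huozi base model. A minimal sketch of that loading path, under the same 8-bit settings as the diff above; the "path/to/huozi-base-model" string is a placeholder for the real --base_model value, not an actual repo id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: point this at the actual Huozi checkpoint directory or hub id
# that you would pass as --base_model to the training script.
base_model = "path/to/huozi-base-model"

# AutoModelForCausalLM inspects the checkpoint's config.json and instantiates
# the matching architecture class, so the same call works whether or not the
# checkpoint is Llama-based.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,          # same 8-bit loading as in the diff above
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# The concrete classes are resolved at load time from the config.
print(type(model).__name__, type(tokenizer).__name__)

This is also why only the import line and the two constructor calls change in the diff: the rest of the LoRA training loop is architecture-agnostic.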