merge success

这个提交包含在:
binary-husky
2023-07-18 19:51:13 +08:00
父节点 babb775cfb
当前提交 fd549fb986
共有 4 个文件被更改,包括 47 次插入53 次删除

查看文件

@@ -12,6 +12,22 @@ load_message = f"{model_name}尚未加载,加载需要一段时间。注意,
def try_to_import_special_deps():
import sentencepiece
user_prompt = "<|User|>:{user}<eoh>\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
def combine_history(prompt, hist):
messages = hist
total_prompt = ""
for message in messages:
cur_content = message
cur_prompt = user_prompt.replace("{user}", cur_content[0])
total_prompt += cur_prompt
cur_prompt = robot_prompt.replace("{robot}", cur_content[1])
total_prompt += cur_prompt
total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
return total_prompt
@Singleton
@@ -44,10 +60,10 @@ class GetInternlmHandle(Process):
else:
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()
self._model = self._model.eval()
model = model.eval()
return model, tokenizer
def llm_stream_generator(self, kwargs):
def llm_stream_generator(self, **kwargs):
import torch
import logging
import copy
@@ -63,12 +79,16 @@ class GetInternlmHandle(Process):
max_length = kwargs['max_length']
top_p = kwargs['top_p']
temperature = kwargs['temperature']
return model, tokenizer, prompt, max_length, top_p, temperature
history = kwargs['history']
real_prompt = combine_history(prompt, history)
return model, tokenizer, real_prompt, max_length, top_p, temperature
model, tokenizer, prompt, max_length, top_p, temperature = adaptor()
prefix_allowed_tokens_fn = None
logger = logging.get_logger(__name__)
logits_processor = None
stopping_criteria = None
additional_eos_token_id = 103028
generation_config = None
# 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
# 🏃‍♂️🏃‍♂️🏃‍♂️ https://github.com/InternLM/InternLM/blob/efbf5335709a8c8faeac6eaf07193973ff1d56a1/web_demo.py#L25
@@ -98,7 +118,7 @@ class GetInternlmHandle(Process):
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
logging.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
@@ -108,7 +128,7 @@ class GetInternlmHandle(Process):
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "input_ids"
logger.warning(
logging.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."