这个提交包含在:
binary-husky
2023-09-09 18:56:10 +08:00
父节点 f5357f67ca
当前提交 5c0a0882c8
共有 52 个文件被更改,包括 2710 次插入591 次删除

查看文件

@@ -398,6 +398,22 @@ if "spark" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
})
except:
print(trimmed_format_exc())
if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
try:
from .bridge_spark import predict_no_ui_long_connection as spark_noui
from .bridge_spark import predict as spark_ui
model_info.update({
"sparkv2": {
"fn_with_ui": spark_ui,
"fn_without_ui": spark_noui,
"endpoint": None,
"max_token": 4096,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
}
})
except:
print(trimmed_format_exc())
if "llama2" in AVAIL_LLM_MODELS: # llama2
try:
from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui

查看文件

@@ -63,9 +63,9 @@ class GetGLMFTHandle(Process):
# if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
# with open(conf, 'r', encoding='utf8') as f:
# model_args = json.loads(f.read())
ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
assert os.path.exists(ChatGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
CHATGLM_PTUNING_CHECKPOINT, = get_conf('CHATGLM_PTUNING_CHECKPOINT')
assert os.path.exists(CHATGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
conf = os.path.join(CHATGLM_PTUNING_CHECKPOINT, "config.json")
with open(conf, 'r', encoding='utf8') as f:
model_args = json.loads(f.read())
if 'model_name_or_path' not in model_args:
@@ -78,9 +78,9 @@ class GetGLMFTHandle(Process):
config.pre_seq_len = model_args['pre_seq_len']
config.prefix_projection = model_args['prefix_projection']
print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):

查看文件

@@ -137,6 +137,12 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
chatbot.append((inputs, ""))
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
# check mis-behavior
if raw_input.startswith('private_upload/') and len(raw_input) == 34:
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需要点击“函数插件区”按钮进行处理,而不是点击“提交”按钮。")
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
time.sleep(2)
try:
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
except RuntimeError as e:
@@ -178,7 +184,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
return
chunk_decoded = chunk.decode()
if is_head_of_the_stream and (r'"object":"error"' not in chunk_decoded) and (r"choices" not in chunk_decoded):
if is_head_of_the_stream and (r'"object":"error"' not in chunk_decoded) and (r"content" not in chunk_decoded):
# 数据流的第一帧不携带content
is_head_of_the_stream = False; continue

查看文件

@@ -49,16 +49,17 @@ def get_access_token():
def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
conversation_cnt = len(history) // 2
if system_prompt == "": system_prompt = "Hello"
messages = [{"role": "user", "content": system_prompt}]
messages.append({"role": "assistant", "content": 'Certainly!'})
if conversation_cnt:
for index in range(0, 2*conversation_cnt, 2):
what_i_have_asked = {}
what_i_have_asked["role"] = "user"
what_i_have_asked["content"] = history[index]
what_i_have_asked["content"] = history[index] if history[index]!="" else "Hello"
what_gpt_answer = {}
what_gpt_answer["role"] = "assistant"
what_gpt_answer["content"] = history[index+1]
what_gpt_answer["content"] = history[index+1] if history[index]!="" else "Hello"
if what_i_have_asked["content"] != "":
if what_gpt_answer["content"] == "": continue
if what_gpt_answer["content"] == timeout_bot_msg: continue

查看文件

@@ -2,11 +2,17 @@
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from toolbox import update_ui, get_conf, update_ui_lastest_msg
from multiprocessing import Process, Pipe
model_name = '星火认知大模型'
def validate_key():
XFYUN_APPID, = get_conf('XFYUN_APPID', )
if XFYUN_APPID == '00000000' or XFYUN_APPID == '':
return False
return True
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
"""
⭐多线程方法
@@ -15,6 +21,9 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
watch_dog_patience = 5
response = ""
if validate_key() is False:
raise RuntimeError('请配置讯飞星火大模型的XFYUN_APPID, XFYUN_API_KEY, XFYUN_API_SECRET')
from .com_sparkapi import SparkRequestInstance
sri = SparkRequestInstance()
for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
@@ -30,6 +39,11 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
函数的说明请见 request_llm/bridge_all.py
"""
chatbot.append((inputs, ""))
yield from update_ui(chatbot=chatbot, history=history)
if validate_key() is False:
yield from update_ui_lastest_msg(lastmsg="[Local Message]: 请配置讯飞星火大模型的XFYUN_APPID, XFYUN_API_KEY, XFYUN_API_SECRET", chatbot=chatbot, history=history, delay=0)
return
if additional_fn is not None:
from core_functional import handle_core_functionality

查看文件

@@ -58,11 +58,13 @@ class Ws_Param(object):
class SparkRequestInstance():
def __init__(self):
XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
if XFYUN_APPID == '00000000' or XFYUN_APPID == '': raise RuntimeError('请配置讯飞星火大模型的XFYUN_APPID, XFYUN_API_KEY, XFYUN_API_SECRET')
self.appid = XFYUN_APPID
self.api_secret = XFYUN_API_SECRET
self.api_key = XFYUN_API_KEY
self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
self.time_to_yield_event = threading.Event()
self.time_to_exit_event = threading.Event()
@@ -83,7 +85,12 @@ class SparkRequestInstance():
def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url)
if llm_kwargs['llm_model'] == 'sparkv2':
gpt_url = self.gpt_url_v2
else:
gpt_url = self.gpt_url
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
@@ -167,7 +174,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
},
"parameter": {
"chat": {
"domain": "general",
"domain": "generalv2" if llm_kwargs['llm_model'] == 'sparkv2' else "general",
"temperature": llm_kwargs["temperature"],
"random_threshold": 0.5,
"max_tokens": 4096,