Extend the framework's parameter IO

This commit is contained in:
qingxu fu
2023-04-09 20:42:23 +08:00
Parent ea031ab05b
Commit 0666fec86e
18 changed files with 239 additions and 203 deletions


@@ -72,14 +72,14 @@ def predict_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
         raise ConnectionAbortedError("Json解析不合常规,可能是文本过长" + response.text)
 
-def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_prompt="", observe_window=None, console_slience=False):
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
     """
     Send the query to chatGPT and wait for the reply, finishing in one pass without showing intermediate progress; internally it still uses streaming, so the connection is not dropped midway.
     inputs:
         the input of this query
     sys_prompt:
         the silent system prompt
-    top_p, temperature:
+    llm_kwargs:
         chatGPT's internal tuning parameters
     history:
         the list of previous conversation turns
@@ -87,7 +87,7 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
         used to pass the already-generated output across threads; most of the time it exists purely for a fancy visual effect and can be left empty. observe_window[0]: observation window. observe_window[1]: watchdog
     """
     watch_dog_patience = 5 # the watchdog's patience; setting it to 5 seconds is enough
-    headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt=sys_prompt, stream=True)
+    headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
     retry = 0
     while True:
         try:
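
The docstring above describes observe_window as a two-slot list: slot 0 carries the partial output across threads, slot 1 is a watchdog that the caller must keep feeding, measured against the five-second watch_dog_patience. A minimal caller-side sketch of that contract, assuming llm_kwargs holds the 'top_p' and 'temperature' keys that generate_payload reads further down (the thread setup here is illustrative, not part of the commit):

    import time, threading

    llm_kwargs = {'top_p': 1.0, 'temperature': 1.0}  # keys read by generate_payload
    observe_window = ["", time.time()]               # [0]: partial output, [1]: watchdog

    def worker():
        # the function keeps updating observe_window[0] as tokens arrive,
        # and is expected to abort if observe_window[1] goes stale too long
        observe_window[0] = predict_no_ui_long_connection(
            "Hello", llm_kwargs, history=[], sys_prompt="", observe_window=observe_window)

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    while t.is_alive():
        observe_window[1] = time.time()  # feed the watchdog so the worker is not cut off
        print(observe_window[0])         # live preview of the partial reply
        time.sleep(1)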
@@ -135,8 +135,7 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
     return result
 
-def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='',
-            stream = True, additional_fn=None):
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
     """
     Send the query to chatGPT and fetch the output as a stream.
     Used for the basic conversation feature.
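
This signature change is the core of the commit: instead of threading top_p and temperature through every call positionally, callers now pass a single llm_kwargs dict, plus a new plugin_kwargs dict for plugin options. A hedged before/after sketch of a call site (values are illustrative; plugin_kwargs is not consumed in the hunks shown here):

    inputs, chatbot, history = "Hello", [], []

    # before this commit, tuning values traveled as positional arguments:
    #   predict(inputs, top_p, temperature, chatbot, history=history)

    # after: one dict for model tuning, one for plugin options
    llm_kwargs = {'top_p': 1.0, 'temperature': 1.0}  # the keys generate_payload unpacks
    plugin_kwargs = {}                               # reserved for plugin options

    for _ in predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=history):
        pass  # predict is a generator: every yield pushes a UI refresh via update_ui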
@@ -157,9 +156,9 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
     raw_input = inputs
     logging.info(f'[raw_input] {raw_input}')
     chatbot.append((inputs, ""))
-    yield from update_ui(chatbot=chatbot, history=history, msg="等待响应")
+    yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # refresh the UI
 
-    headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
+    headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
     history.append(inputs); history.append(" ")
 
     retry = 0
@@ -172,7 +171,7 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
             retry += 1
             chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
             retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
-            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg)
+            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # refresh the UI
             if retry > MAX_RETRY: raise TimeoutError
 
     gpt_replying_buffer = ""
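
The hunk above is the timeout branch of the request loop: each timeout bumps a counter, refreshes the UI with a retry message, and gives up with TimeoutError once MAX_RETRY is exceeded. Reduced to a self-contained skeleton (make_request is a stand-in for the real requests call):

    MAX_RETRY = 2  # illustrative; the real value comes from the project's config

    def fetch_with_retry(make_request, headers, payload):
        retry = 0
        while True:
            try:
                return make_request(headers, payload)  # stand-in for requests.post(...)
            except TimeoutError:
                retry += 1
                if retry > MAX_RETRY:
                    raise  # mirrors `if retry > MAX_RETRY: raise TimeoutError`
                # the real loop also yields a "request timed out, retrying (n/MAX_RETRY)"
                # UI refresh here before trying again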
@@ -200,11 +199,11 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
                     gpt_replying_buffer = gpt_replying_buffer + json.loads(chunk.decode()[6:])['choices'][0]["delta"]["content"]
                     history[-1] = gpt_replying_buffer
                     chatbot[-1] = (history[-2], history[-1])
-                    yield from update_ui(chatbot=chatbot, history=history, msg=status_text)
+                    yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # refresh the UI
                 except Exception as e:
                     traceback.print_exc()
-                    yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规")
+                    yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # refresh the UI
                     chunk = get_full_error(chunk, stream_response)
                     error_msg = chunk.decode()
                     if "reduce the length" in error_msg:
@@ -218,10 +217,10 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
                 from toolbox import regular_txt_to_markdown
                 tb_str = '```\n' + traceback.format_exc() + '```'
                 chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk.decode()[4:])}")
-                yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg)
+                yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # refresh the UI
                 return
 
-def generate_payload(inputs, top_p, temperature, history, system_prompt, stream):
+def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
     """
     Integrate all the information, select the LLM model, and generate the http request in preparation for sending it.
     """
@@ -257,8 +256,8 @@ def generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
     payload = {
         "model": LLM_MODEL,
         "messages": messages,
-        "temperature": temperature, # 1.0,
-        "top_p": top_p, # 1.0,
+        "temperature": llm_kwargs['temperature'], # 1.0,
+        "top_p": llm_kwargs['top_p'], # 1.0,
         "n": 1,
         "stream": stream,
         "presence_penalty": 0,