fix: return value count and response content-type handling (#2129)

This commit is contained in:
Steven Moder
2025-02-07 21:33:06 +08:00
Committed by GitHub
parent 6dda2061dd
commit cf7c81170c


@@ -1,16 +1,13 @@
 import json
 import time
 import traceback
 import requests
 from loguru import logger

 # config_private.py放自己的秘密如API和代理网址
 # 读取时首先看是否存在私密的config_private配置文件不受git管控,如果有,则覆盖原config文件
-from toolbox import (
-    get_conf,
-    update_ui,
-    is_the_upload_folder,
-)
+from toolbox import get_conf, is_the_upload_folder, update_ui

 proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf(
     "proxies", "TIMEOUT_SECONDS", "MAX_RETRY"
@@ -39,27 +36,35 @@ def decode_chunk(chunk):
     用于解读"content"和"finish_reason"的内容(如果支持思维链也会返回"reasoning_content"内容)
     """
     chunk = chunk.decode()
-    respose = ""
+    response = ""
     reasoning_content = ""
     finish_reason = "False"
+    # 考虑返回类型是 text/json 和 text/event-stream 两种
+    if chunk.startswith("data: "):
+        chunk = chunk[6:]
+    else:
+        chunk = chunk
     try:
-        chunk = json.loads(chunk[6:])
+        chunk = json.loads(chunk)
     except:
-        respose = ""
+        response = ""
         finish_reason = chunk
         # 错误处理部分
         if "error" in chunk:
-            respose = "API_ERROR"
+            response = "API_ERROR"
             try:
                 chunk = json.loads(chunk)
                 finish_reason = chunk["error"]["code"]
             except:
                 finish_reason = "API_ERROR"
-        return respose, finish_reason
+        return response, reasoning_content, finish_reason
     try:
         if chunk["choices"][0]["delta"]["content"] is not None:
-            respose = chunk["choices"][0]["delta"]["content"]
+            response = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
@@ -71,7 +76,7 @@ def decode_chunk(chunk):
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, reasoning_content, finish_reason
+    return response, reasoning_content, finish_reason


 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
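The decode_chunk changes above do two things: the function now strips an optional "data: " prefix, so it can parse both text/event-stream chunks and plain text/json bodies, and every exit path (including the error early-return) yields the same three-value tuple (response, reasoning_content, finish_reason). A minimal sketch of the new calling convention, using made-up chunks rather than real API output:

    # Hypothetical chunks for illustration only; real ones come from response.iter_lines()
    sse_chunk = b'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}'
    plain_chunk = b'{"choices": [{"delta": {"content": "Hello"}, "finish_reason": "stop"}]}'
    broken_chunk = b"upstream error: 502 Bad Gateway"  # not JSON, but contains "error"

    for raw in (sse_chunk, plain_chunk, broken_chunk):
        # every branch now unpacks cleanly into three values
        response_text, reasoning_content, finish_reason = decode_chunk(raw)
        print(repr(response_text), repr(reasoning_content), repr(finish_reason))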
@@ -106,7 +111,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
     what_i_ask_now["role"] = "user"
     what_i_ask_now["content"] = input
     messages.append(what_i_ask_now)
-    playload = {
+    payload = {
         "model": model,
         "messages": messages,
         "temperature": temperature,
@@ -114,7 +119,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
         "max_tokens": max_output_token,
     }

-    return headers, playload
+    return headers, payload


 def get_predict_function(
@@ -141,7 +146,7 @@ def get_predict_function(
         history=[],
         sys_prompt="",
         observe_window=None,
-        console_slience=False,
+        console_silence=False,
     ):
         """
         发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
@@ -162,7 +167,7 @@
             raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
         if inputs == "":
             inputs = "你好👋"
-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -182,7 +187,7 @@
                     endpoint,
                     headers=headers,
                     proxies=None if disable_proxy else proxies,
-                    json=playload,
+                    json=payload,
                     stream=True,
                     timeout=TIMEOUT_SECONDS,
                 )
@@ -198,7 +203,7 @@
         result = ""
         finish_reason = ""
         if reasoning:
-            resoning_buffer = ""
+            reasoning_buffer = ""

         stream_response = response.iter_lines()
         while True:
@@ -226,12 +231,12 @@
             if chunk:
                 try:
                     if finish_reason == "stop":
-                        if not console_slience:
+                        if not console_silence:
                             print(f"[response] {result}")
                         break
                     result += response_text
                     if reasoning:
-                        resoning_buffer += reasoning_content
+                        reasoning_buffer += reasoning_content
                     if observe_window is not None:
                         # 观测窗,把已经获取的数据显示出去
                         if len(observe_window) >= 1:
@@ -247,9 +252,9 @@
                     logger.error(error_msg)
                     raise RuntimeError("Json解析不合常规")
         if reasoning:
-            # reasoning 的部分加上框 (>)
-            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + \
-                '\n\n' + result
+            return f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_buffer.split('\n')])}
+</div>\n\n''' + result
         return result

     def predict(
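Both here and in predict below, the chain-of-thought text is now rendered inside a dimmed, styled <div> rather than the previous markdown blockquote ('> ' prefix per line). The same f-string appears in both places; purely as an illustration of the shared pattern (this helper is not part of the commit), it could be factored out like so:

    def format_reasoning_block(reasoning_text: str, answer_text: str) -> str:
        # Hypothetical helper, not in the commit: wraps the reasoning in the same
        # dimmed <div> used above, then appends the final answer after a blank line.
        paragraphs = "".join(
            f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_text.split("\n")
        )
        return (
            '<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">'
            + paragraphs
            + "</div>\n\n"
            + answer_text
        )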
@@ -268,7 +273,7 @@
         inputs 是本次问询的输入
         top_p, temperature是chatGPT的内部调优参数
         history 是之前的对话列表注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误
-        chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
+        chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
         additional_fn代表点击的哪个按钮,按钮见functional.py
         """
         from .bridge_all import model_info
@@ -299,7 +304,7 @@
         ) # 刷新界面
         time.sleep(2)

-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -321,7 +326,7 @@
                     endpoint,
                     headers=headers,
                     proxies=None if disable_proxy else proxies,
-                    json=playload,
+                    json=payload,
                     stream=True,
                     timeout=TIMEOUT_SECONDS,
                 )
@@ -367,7 +372,7 @@
                 chunk_decoded = chunk.decode()
                 chatbot[-1] = (
                     chatbot[-1][0],
-                    "[Local Message] {finish_reason},获得以下报错信息:\n"
+                    f"[Local Message] {finish_reason},获得以下报错信息:\n"
                     + chunk_decoded,
                 )
                 yield from update_ui(
@@ -385,7 +390,9 @@
                 if reasoning:
                     gpt_replying_buffer += response_text
                     gpt_reasoning_buffer += reasoning_content
-                    history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+                    history[-1] = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in gpt_reasoning_buffer.split('\n')])}
+</div>\n\n''' + gpt_replying_buffer
                 else:
                     gpt_replying_buffer += response_text
                 # 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出