Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 14:36:48 +00:00
fix: account for the number of return values and the response type (#2129)
@@ -1,16 +1,13 @@
 import json
 import time
 import traceback

 import requests
 from loguru import logger

 # config_private.py放自己的秘密如API和代理网址
 # 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
-from toolbox import (
-    get_conf,
-    update_ui,
-    is_the_upload_folder,
-)
+from toolbox import get_conf, is_the_upload_folder, update_ui

 proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf(
     "proxies", "TIMEOUT_SECONDS", "MAX_RETRY"
@@ -39,27 +36,35 @@ def decode_chunk(chunk):
     用于解读"content"和"finish_reason"的内容(如果支持思维链也会返回"reasoning_content"内容)
     """
     chunk = chunk.decode()
-    respose = ""
+    response = ""
     reasoning_content = ""
     finish_reason = "False"

+    # 考虑返回类型是 text/json 和 text/event-stream 两种
+    if chunk.startswith("data: "):
+        chunk = chunk[6:]
+    else:
+        chunk = chunk
+
     try:
-        chunk = json.loads(chunk[6:])
+        chunk = json.loads(chunk)
     except:
-        respose = ""
+        response = ""
         finish_reason = chunk

     # 错误处理部分
     if "error" in chunk:
-        respose = "API_ERROR"
+        response = "API_ERROR"
         try:
             chunk = json.loads(chunk)
             finish_reason = chunk["error"]["code"]
         except:
             finish_reason = "API_ERROR"
-        return respose, finish_reason
+        return response, reasoning_content, finish_reason

     try:
         if chunk["choices"][0]["delta"]["content"] is not None:
-            respose = chunk["choices"][0]["delta"]["content"]
+            response = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
@@ -71,7 +76,7 @@ def decode_chunk(chunk):
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, reasoning_content, finish_reason
+    return response, reasoning_content, finish_reason


 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
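For readability, here is a minimal, self-contained sketch of the contract decode_chunk follows after this patch: every exit path returns the same three values (response, reasoning_content, finish_reason), and the incoming chunk may be either a plain text/json body or a text/event-stream line prefixed with "data: ". The names below mirror the diff, but the function itself is an illustration, not the repository code.

import json

def decode_chunk_sketch(chunk: bytes):
    # Illustrative restatement of the patched decode_chunk contract; not the
    # repository code. Every exit path returns the same three values.
    text = chunk.decode()
    response, reasoning_content, finish_reason = "", "", "False"
    # The API may reply as text/event-stream ("data: {...}") or as plain
    # text/json ("{...}"); strip the SSE prefix only when it is present.
    if text.startswith("data: "):
        text = text[len("data: "):]
    try:
        data = json.loads(text)
    except ValueError:
        # Unparsable chunk: hand the raw text back as the finish reason.
        return response, reasoning_content, text
    if "error" in data:
        # Error payloads now also yield three values, so callers can always
        # unpack (response, reasoning_content, finish_reason).
        return "API_ERROR", reasoning_content, data["error"].get("code", "API_ERROR")
    choice = data.get("choices", [{}])[0]
    delta = choice.get("delta", {})
    response = delta.get("content") or ""
    reasoning_content = delta.get("reasoning_content") or ""
    finish_reason = choice.get("finish_reason") or finish_reason
    return response, reasoning_content, finish_reason

print(decode_chunk_sketch(b'data: {"choices":[{"delta":{"content":"hi"},"finish_reason":null}]}'))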
@@ -106,7 +111,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
     what_i_ask_now["role"] = "user"
     what_i_ask_now["content"] = input
     messages.append(what_i_ask_now)
-    playload = {
+    payload = {
         "model": model,
         "messages": messages,
         "temperature": temperature,
@@ -114,7 +119,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
         "max_tokens": max_output_token,
     }

-    return headers, playload
+    return headers, payload


 def get_predict_function(
@@ -141,7 +146,7 @@ def get_predict_function(
         history=[],
         sys_prompt="",
         observe_window=None,
-        console_slience=False,
+        console_silence=False,
     ):
         """
         发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
@@ -162,7 +167,7 @@ def get_predict_function(
             raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
         if inputs == "":
             inputs = "你好👋"
-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -182,7 +187,7 @@ def get_predict_function(
                 endpoint,
                 headers=headers,
                 proxies=None if disable_proxy else proxies,
-                json=playload,
+                json=payload,
                 stream=True,
                 timeout=TIMEOUT_SECONDS,
             )
@@ -198,7 +203,7 @@ def get_predict_function(
         result = ""
         finish_reason = ""
         if reasoning:
-            resoning_buffer = ""
+            reasoning_buffer = ""

         stream_response = response.iter_lines()
         while True:
@@ -226,12 +231,12 @@ def get_predict_function(
             if chunk:
                 try:
                     if finish_reason == "stop":
-                        if not console_slience:
+                        if not console_silence:
                             print(f"[response] {result}")
                         break
                     result += response_text
                     if reasoning:
-                        resoning_buffer += reasoning_content
+                        reasoning_buffer += reasoning_content
                     if observe_window is not None:
                         # 观测窗,把已经获取的数据显示出去
                         if len(observe_window) >= 1:
@@ -247,9 +252,9 @@ def get_predict_function(
                 logger.error(error_msg)
                 raise RuntimeError("Json解析不合常规")
         if reasoning:
-            # reasoning 的部分加上框 (>)
-            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + \
-                '\n\n' + result
+            return f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_buffer.split('\n')])}
+</div>\n\n''' + result
         return result

     def predict(
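The three replaced lines above drop the old quote-style rendering ('> ' in front of every reasoning line) in favour of an HTML wrapper. A small standalone sketch of that formatting, using hypothetical names (wrap_reasoning_html, reasoning_text, answer) chosen only for illustration:

def wrap_reasoning_html(reasoning_text: str, answer: str) -> str:
    # Hypothetical helper that mirrors the new formatting: each reasoning line
    # becomes a <p> inside a dimmed <div>, followed by the normal answer.
    paragraphs = ''.join(
        f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_text.split('\n')
    )
    return (
        '<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">\n'
        + paragraphs
        + '\n</div>\n\n'
        + answer
    )

print(wrap_reasoning_html("first step\nsecond step", "final answer"))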
@@ -268,7 +273,7 @@ def get_predict_function(
         inputs 是本次问询的输入
         top_p, temperature是chatGPT的内部调优参数
         history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
-        chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
+        chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
         additional_fn代表点击的哪个按钮,按钮见functional.py
         """
         from .bridge_all import model_info
@@ -299,7 +304,7 @@ def get_predict_function(
             ) # 刷新界面
             time.sleep(2)

-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -321,7 +326,7 @@ def get_predict_function(
                 endpoint,
                 headers=headers,
                 proxies=None if disable_proxy else proxies,
-                json=playload,
+                json=payload,
                 stream=True,
                 timeout=TIMEOUT_SECONDS,
             )
@@ -367,7 +372,7 @@ def get_predict_function(
                 chunk_decoded = chunk.decode()
                 chatbot[-1] = (
                     chatbot[-1][0],
-                    "[Local Message] {finish_reason},获得以下报错信息:\n"
+                    f"[Local Message] {finish_reason},获得以下报错信息:\n"
                     + chunk_decoded,
                 )
                 yield from update_ui(
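The one-character change above matters because the original literal lacked the f prefix, so "{finish_reason}" appeared verbatim in the chat window instead of the actual finish reason. A two-line illustration with a placeholder value:

finish_reason = "length"  # placeholder value, for illustration only
print("[Local Message] {finish_reason}")   # prints the braces literally
print(f"[Local Message] {finish_reason}")  # prints: [Local Message] length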
@@ -385,7 +390,9 @@ def get_predict_function(
             if reasoning:
                 gpt_replying_buffer += response_text
                 gpt_reasoning_buffer += reasoning_content
-                history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+                history[-1] = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in gpt_reasoning_buffer.split('\n')])}
+</div>\n\n''' + gpt_replying_buffer
             else:
                 gpt_replying_buffer += response_text
             # 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
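Finally, the reason the commit title mentions the number of return values: before the patch, the error branch of decode_chunk returned two values while every other branch returned three, so callers unpacking three names could crash with a ValueError exactly when an API error arrived. A toy illustration of that failure mode (old_decode and new_decode are made-up stand-ins, not the repository functions):

def old_decode(is_error: bool):
    # Stand-in for the pre-patch behaviour: only two values on the error path.
    if is_error:
        return "API_ERROR", "invalid_api_key"
    return "hello", "", "stop"

def new_decode(is_error: bool):
    # Stand-in for the patched behaviour: three values on every path.
    if is_error:
        return "API_ERROR", "", "invalid_api_key"
    return "hello", "", "stop"

try:
    response, reasoning, finish = old_decode(is_error=True)
except ValueError as exc:
    print("old behaviour:", exc)   # not enough values to unpack (expected 3, got 2)

response, reasoning, finish = new_decode(is_error=True)
print("new behaviour:", response, finish)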