镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 06:26:47 +00:00
优化自译解功能
这个提交包含在:
47
toolbox.py
47
toolbox.py
@@ -1,14 +1,34 @@
|
||||
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect
|
||||
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re
|
||||
from show_math import convert as convert_math
|
||||
from functools import wraps
|
||||
import re
|
||||
|
||||
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt=''):
|
||||
def get_reduce_token_percent(e):
|
||||
try:
|
||||
# text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
|
||||
pattern = r"(\d+)\s+tokens\b"
|
||||
match = re.findall(pattern, text)
|
||||
eps = 50 # 稍微留一点余地, 确保下次别再超过token
|
||||
max_limit = float(match[0]) - eps
|
||||
current_tokens = float(match[1])
|
||||
ratio = max_limit/current_tokens
|
||||
assert ratio > 0 and ratio < 1
|
||||
return ratio
|
||||
except:
|
||||
return 0.5
|
||||
|
||||
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=False):
|
||||
"""
|
||||
调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
|
||||
i_say: 当前输入
|
||||
i_say_show_user: 显示到对话界面上的当前输入,例如,输入整个文件时,你绝对不想把文件的内容都糊到对话界面上
|
||||
chatbot: 对话界面句柄
|
||||
top_p, temperature: gpt参数
|
||||
history: gpt参数 对话历史
|
||||
sys_prompt: gpt参数 sys_prompt
|
||||
long_connection: 是否采用更稳定的连接方式(推荐)
|
||||
"""
|
||||
import time
|
||||
from predict import predict_no_ui
|
||||
from predict import predict_no_ui, predict_no_ui_long_connection
|
||||
from toolbox import get_conf
|
||||
TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY')
|
||||
# 多线程的时候,需要一个mutable结构在不同线程之间传递信息
|
||||
@@ -18,18 +38,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
|
||||
def mt(i_say, history):
|
||||
while True:
|
||||
try:
|
||||
mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
|
||||
if long_connection:
|
||||
mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
|
||||
else:
|
||||
mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
|
||||
break
|
||||
except ConnectionAbortedError as e:
|
||||
except ConnectionAbortedError as token_exceeded_error:
|
||||
# 尝试计算比例,尽可能多地保留文本
|
||||
p_ratio = get_reduce_token_percent(str(token_exceeded_error))
|
||||
if len(history) > 0:
|
||||
history = [his[len(his)//2:] for his in history if his is not None]
|
||||
history = [his[ int(len(his) *p_ratio): ] for his in history if his is not None]
|
||||
mutable[1] = 'Warning! History conversation is too long, cut into half. '
|
||||
else:
|
||||
i_say = i_say[:len(i_say)//2]
|
||||
i_say = i_say[: int(len(i_say) *p_ratio) ]
|
||||
mutable[1] = 'Warning! Input file is too long, cut into half. '
|
||||
except TimeoutError as e:
|
||||
mutable[0] = '[Local Message] Failed with timeout.'
|
||||
raise TimeoutError
|
||||
except Exception as e:
|
||||
mutable[0] = f'[Local Message] Failed with {str(e)}.'
|
||||
raise RuntimeError(f'[Local Message] Failed with {str(e)}.')
|
||||
# 创建新线程发出http请求
|
||||
thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
|
||||
# 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
|
||||
@@ -56,6 +84,7 @@ def write_results_to_file(history, file_name=None):
|
||||
with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
|
||||
f.write('# chatGPT 分析报告\n')
|
||||
for i, content in enumerate(history):
|
||||
if type(content) != str: content = str(content)
|
||||
if i%2==0: f.write('## ')
|
||||
f.write(content)
|
||||
f.write('\n\n')
|
||||
@@ -269,7 +298,7 @@ def get_conf(*args):
|
||||
# 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
|
||||
API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
|
||||
if API_MATCH:
|
||||
print("您的 API_KEY 是: ", r, "\nAPI_KEY 导入成功")
|
||||
print(f"您的 API_KEY 是: {r[:15]}*** \nAPI_KEY 导入成功")
|
||||
else:
|
||||
assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
|
||||
"(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
|
||||
|
||||
在新工单中引用
屏蔽一个用户