优化自译解功能

这个提交包含在:
Your Name
2023-03-31 22:36:46 +08:00
父节点 ac681d3201
当前提交 65b1d78516
共有 4 个文件被更改,包括 164 次插入78 次删除

查看文件

@@ -1,14 +1,34 @@
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re
from show_math import convert as convert_math
from functools import wraps
import re
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt=''):
def get_reduce_token_percent(e):
try:
# text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
pattern = r"(\d+)\s+tokens\b"
match = re.findall(pattern, text)
eps = 50 # 稍微留一点余地, 确保下次别再超过token
max_limit = float(match[0]) - eps
current_tokens = float(match[1])
ratio = max_limit/current_tokens
assert ratio > 0 and ratio < 1
return ratio
except:
return 0.5
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=False):
"""
调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
i_say: 当前输入
i_say_show_user: 显示到对话界面上的当前输入,例如,输入整个文件时,你绝对不想把文件的内容都糊到对话界面上
chatbot: 对话界面句柄
top_p, temperature: gpt参数
history: gpt参数 对话历史
sys_prompt: gpt参数 sys_prompt
long_connection: 是否采用更稳定的连接方式(推荐)
"""
import time
from predict import predict_no_ui
from predict import predict_no_ui, predict_no_ui_long_connection
from toolbox import get_conf
TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY')
# 多线程的时候,需要一个mutable结构在不同线程之间传递信息
@@ -18,18 +38,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
def mt(i_say, history):
while True:
try:
mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
if long_connection:
mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
else:
mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
break
except ConnectionAbortedError as e:
except ConnectionAbortedError as token_exceeded_error:
# 尝试计算比例,尽可能多地保留文本
p_ratio = get_reduce_token_percent(str(token_exceeded_error))
if len(history) > 0:
history = [his[len(his)//2:] for his in history if his is not None]
history = [his[ int(len(his) *p_ratio): ] for his in history if his is not None]
mutable[1] = 'Warning! History conversation is too long, cut into half. '
else:
i_say = i_say[:len(i_say)//2]
i_say = i_say[: int(len(i_say) *p_ratio) ]
mutable[1] = 'Warning! Input file is too long, cut into half. '
except TimeoutError as e:
mutable[0] = '[Local Message] Failed with timeout.'
raise TimeoutError
except Exception as e:
mutable[0] = f'[Local Message] Failed with {str(e)}.'
raise RuntimeError(f'[Local Message] Failed with {str(e)}.')
# 创建新线程发出http请求
thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
# 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
@@ -56,6 +84,7 @@ def write_results_to_file(history, file_name=None):
with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
f.write('# chatGPT 分析报告\n')
for i, content in enumerate(history):
if type(content) != str: content = str(content)
if i%2==0: f.write('## ')
f.write(content)
f.write('\n\n')
@@ -269,7 +298,7 @@ def get_conf(*args):
# 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
if API_MATCH:
print("您的 API_KEY 是: ", r, "\nAPI_KEY 导入成功")
print(f"您的 API_KEY 是: {r[:15]}*** \nAPI_KEY 导入成功")
else:
assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
"如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值"