镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 06:26:47 +00:00
update by pull
这个提交包含在:
258
toolbox.py
258
toolbox.py
@@ -1,13 +1,10 @@
|
||||
import markdown
|
||||
import mdtex2html
|
||||
import threading
|
||||
import importlib
|
||||
import traceback
|
||||
import inspect
|
||||
import re
|
||||
from latex2mathml.converter import convert as tex2mathml
|
||||
from functools import wraps, lru_cache
|
||||
|
||||
############################### 插件输入输出接驳区 #######################################
|
||||
class ChatBotWithCookies(list):
|
||||
def __init__(self, cookie):
|
||||
@@ -25,9 +22,9 @@ class ChatBotWithCookies(list):
|
||||
|
||||
def ArgsGeneralWrapper(f):
|
||||
"""
|
||||
装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
|
||||
装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
|
||||
"""
|
||||
def decorated(cookies, txt, txt2, top_p, temperature, chatbot, history, system_prompt, txt_pattern, *args):
|
||||
def decorated(cookies, max_length, llm_model, txt, txt2, top_p, temperature, chatbot, history, system_prompt, txt_pattern, *args):
|
||||
txt_passon = txt
|
||||
if txt == "" and txt2 != "": txt_passon = txt2
|
||||
# 引入一个有cookie的chatbot
|
||||
@@ -37,8 +34,9 @@ def ArgsGeneralWrapper(f):
|
||||
})
|
||||
llm_kwargs = {
|
||||
'api_key': cookies['api_key'],
|
||||
'llm_model': cookies['llm_model'],
|
||||
'llm_model': llm_model,
|
||||
'top_p':top_p,
|
||||
'max_length': max_length,
|
||||
'temperature':temperature,
|
||||
}
|
||||
# plugin_kwargs = {
|
||||
@@ -56,129 +54,10 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
|
||||
"""
|
||||
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
|
||||
yield chatbot.get_cookies(), chatbot, history, msg
|
||||
############################### ################## #######################################
|
||||
##########################################################################################
|
||||
|
||||
def get_reduce_token_percent(text):
    """
    Parse an OpenAI "context length exceeded" error message and compute how
    much of the conversation must be kept to fit the model's context window.

    * This function is scheduled for deprecation.

    Args:
        text: The error message, e.g.
            "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"

    Returns:
        (ratio, exceeded): keep-ratio in (0, 1) to scale the text down by, and
        the number of overflowing tokens as a string. Falls back to
        (0.5, '不详') when the message cannot be parsed.
    """
    try:
        # The first number matched is the model's limit, the second the actual usage.
        pattern = r"(\d+)\s+tokens\b"
        match = re.findall(pattern, text)
        EXCEED_ALLO = 500  # leave some headroom, otherwise the reply itself may overflow again
        max_limit = float(match[0]) - EXCEED_ALLO
        current_tokens = float(match[1])
        ratio = max_limit/current_tokens
        assert ratio > 0 and ratio < 1
        return ratio, str(int(current_tokens-max_limit))
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not swallowed;
    # these three are exactly what the parsing above can raise.
    except (IndexError, ValueError, AssertionError):
        return 0.5, '不详'
|
||||
|
||||
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, llm_kwargs, history=[], sys_prompt='', long_connection=True):
    """
    Issue a GPT request in a worker thread while keeping the UI alive with a
    countdown; on token overflow the input/history is truncated and retried.

    * Scheduled for deprecation (replacement:
      request_gpt_model_in_new_thread_with_ui_alive in
      chatgpt_academic/crazy_functions/crazy_utils).

    Args:
        i_say: the actual input sent to the model.
        i_say_show_user: what is shown in the chat UI (e.g. a filename instead
            of pasting a whole file into the conversation).
        chatbot: chat UI handle.
        llm_kwargs: gpt parameters (top_p, temperature, ...).
        history: conversation history passed to the model.
        sys_prompt: system prompt passed to the model.
        long_connection: use the more stable connection mode (deprecated;
            the long-connection path is always used now).

    NOTE(review): the mutable default ``history=[]`` is shared across calls;
    kept as-is because callers may rely on the current signature.
    """
    import time
    from request_llm.bridge_chatgpt import predict_no_ui_long_connection
    from toolbox import get_conf
    TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY')
    # A mutable list is the simplest structure to pass data between threads:
    # slot 0 carries the gpt output, slot 1 carries warning/error text.
    mutable = [None, '']

    # multi-threading worker
    def mt(i_say, history):
        while True:
            try:
                mutable[0] = predict_no_ui_long_connection(
                    inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt)
                break  # success: stop retrying (without this the worker would loop forever)
            except ConnectionAbortedError as token_exceeded_error:
                # Token overflow: estimate how much text can be kept, truncate, retry.
                p_ratio, n_exceed = get_reduce_token_percent(
                    str(token_exceeded_error))
                if len(history) > 0:
                    history = [his[int(len(his) * p_ratio):]
                               for his in history if his is not None]
                else:
                    i_say = i_say[: int(len(i_say) * p_ratio)]
                mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
            except TimeoutError as e:
                mutable[0] = '[Local Message] 请求超时。'
                raise TimeoutError
            except Exception as e:
                mutable[0] = f'[Local Message] 异常:{str(e)}.'
                raise RuntimeError(f'[Local Message] 异常:{str(e)}.')

    # Fire the HTTP request in a new thread ...
    thread_name = threading.Thread(target=mt, args=(i_say, history))
    thread_name.start()
    # ... while this thread keeps refreshing the UI with a countdown and waits
    # for the worker to finish.
    cnt = 0
    while thread_name.is_alive():
        cnt += 1
        chatbot[-1] = (i_say_show_user,
                       f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4)))
        yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
        time.sleep(1)
    # Collect the gpt output from the shared list.
    gpt_say = mutable[0]
    if gpt_say == '[Local Message] Failed with timeout.':
        raise TimeoutError
    return gpt_say
|
||||
|
||||
|
||||
def write_results_to_file(history, file_name=None):
    """
    Write the conversation ``history`` to a Markdown file under ./gpt_log/.
    If no file name is given, one is generated from the current time.

    Args:
        history: alternating list of user inputs (even indices, written as
            '## ' headings) and model replies (odd indices).
        file_name: target file name inside ./gpt_log/; auto-generated if None.

    Returns:
        A message containing the absolute path of the written report.
    """
    import os
    import time
    if file_name is None:
        file_name = 'chatGPT分析报告' + \
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
    os.makedirs('./gpt_log/', exist_ok=True)
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
        f.write('# chatGPT 分析报告\n')
        for i, content in enumerate(history):
            try:  # trigger condition of this bug is unknown; coerce defensively for now
                if not isinstance(content, str):  # isinstance instead of type() comparison
                    content = str(content)
            except Exception:  # narrowed from bare except: str() may raise from a broken __str__
                continue
            if i % 2 == 0:
                # even entries are user inputs -> render as headings
                f.write('## ')
            f.write(content)
            f.write('\n\n')
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
    print(res)
    return res
|
||||
|
||||
|
||||
def regular_txt_to_markdown(text):
    """
    Convert plain text to Markdown: turn every single newline into a paragraph
    break, then collapse runs of three newlines back down to two.
    """
    text = text.replace('\n', '\n\n')
    # Two collapse passes, matching the original's repeated replace calls.
    for _ in range(2):
        text = text.replace('\n\n\n', '\n\n')
    return text
|
||||
|
||||
|
||||
def CatchException(f):
|
||||
"""
|
||||
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
|
||||
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
|
||||
"""
|
||||
@wraps(f)
|
||||
def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
|
||||
@@ -215,9 +94,70 @@ def HotReload(f):
|
||||
return decorated
|
||||
|
||||
|
||||
####################################### 其他小工具 #####################################
|
||||
|
||||
def get_reduce_token_percent(text):
    """
    Extract the token limit and actual token count from an OpenAI overflow
    error message and return the fraction of text that can be kept.

    * This function will be deprecated in the future.

    Returns:
        (ratio, exceeded): keep-ratio in (0, 1) and overflow token count as a
        string; (0.5, '不详') when parsing fails.
    """
    try:
        # e.g. "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
        numbers = re.findall(r"(\d+)\s+tokens\b", text)
        headroom = 500  # keep some slack so the reply itself has room
        limit_after_headroom = float(numbers[0]) - headroom
        used_tokens = float(numbers[1])
        keep_ratio = limit_after_headroom / used_tokens
        assert 0 < keep_ratio < 1
        return keep_ratio, str(int(used_tokens - limit_after_headroom))
    except:
        # Unparsable message: fall back to truncating half the text.
        return 0.5, '不详'
|
||||
|
||||
|
||||
|
||||
def write_results_to_file(history, file_name=None):
    """
    Dump the conversation ``history`` as Markdown into ./gpt_log/.
    A timestamped file name is generated when ``file_name`` is None.

    Returns:
        A message with the absolute path of the report that was written.
    """
    import os
    import time
    if file_name is None:
        stamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        file_name = 'chatGPT分析报告' + stamp + '.md'
    os.makedirs('./gpt_log/', exist_ok=True)
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as report:
        report.write('# chatGPT 分析报告\n')
        for index, content in enumerate(history):
            try:  # trigger condition of this bug is unknown; coerce defensively for now
                if type(content) is not str:
                    content = str(content)
            except:
                continue
            # even entries are user inputs -> render as headings
            prefix = '## ' if index % 2 == 0 else ''
            report.write(prefix + content + '\n\n')
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
    print(res)
    return res
|
||||
|
||||
|
||||
def regular_txt_to_markdown(text):
    """
    Convert plain text into Markdown-style paragraphs: each newline becomes a
    blank-line paragraph break, with triple newlines collapsed back to double.
    """
    result = text.replace('\n', '\n\n')
    # The collapse is applied twice, exactly as the original did.
    result = result.replace('\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
    return result
|
||||
|
||||
|
||||
|
||||
|
||||
def report_execption(chatbot, history, a, b):
|
||||
"""
|
||||
向chatbot中添加错误信息
|
||||
向chatbot中添加错误信息
|
||||
"""
|
||||
chatbot.append((a, b))
|
||||
history.append(a)
|
||||
@@ -226,7 +166,7 @@ def report_execption(chatbot, history, a, b):
|
||||
|
||||
def text_divide_paragraph(text):
|
||||
"""
|
||||
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
|
||||
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
|
||||
"""
|
||||
if '```' in text:
|
||||
# careful input
|
||||
@@ -242,7 +182,7 @@ def text_divide_paragraph(text):
|
||||
|
||||
def markdown_convertion(txt):
|
||||
"""
|
||||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||||
"""
|
||||
pre = '<div class="markdown-body">'
|
||||
suf = '</div>'
|
||||
@@ -334,7 +274,7 @@ def close_up_code_segment_during_stream(gpt_reply):
|
||||
|
||||
def format_io(self, y):
|
||||
"""
|
||||
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
|
||||
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
|
||||
"""
|
||||
if y is None or y == []:
|
||||
return []
|
||||
@@ -350,7 +290,7 @@ def format_io(self, y):
|
||||
|
||||
def find_free_port():
|
||||
"""
|
||||
返回当前系统中可用的未使用端口。
|
||||
返回当前系统中可用的未使用端口。
|
||||
"""
|
||||
import socket
|
||||
from contextlib import closing
|
||||
@@ -429,7 +369,7 @@ def find_recent_files(directory):
|
||||
return recent_files
|
||||
|
||||
|
||||
def on_file_uploaded(files, chatbot, txt):
|
||||
def on_file_uploaded(files, chatbot, txt, txt2, checkboxes):
|
||||
if len(files) == 0:
|
||||
return chatbot, txt
|
||||
import shutil
|
||||
@@ -451,13 +391,18 @@ def on_file_uploaded(files, chatbot, txt):
|
||||
dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
|
||||
moved_files = [fp for fp in glob.glob(
|
||||
'private_upload/**/*', recursive=True)]
|
||||
txt = f'private_upload/{time_tag}'
|
||||
if "底部输入区" in checkboxes:
|
||||
txt = ""
|
||||
txt2 = f'private_upload/{time_tag}'
|
||||
else:
|
||||
txt = f'private_upload/{time_tag}'
|
||||
txt2 = ""
|
||||
moved_files_str = '\t\n\n'.join(moved_files)
|
||||
chatbot.append(['我上传了文件,请查收',
|
||||
f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
|
||||
f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
|
||||
f'\n\n现在您点击任意“红颜色”标识的函数插件时,以上文件将被作为输入参数'+err_msg])
|
||||
return chatbot, txt
|
||||
return chatbot, txt, txt2
|
||||
|
||||
|
||||
def on_report_generated(files, chatbot):
|
||||
@@ -470,9 +415,43 @@ def on_report_generated(files, chatbot):
|
||||
return report_files, chatbot
|
||||
|
||||
def is_openai_api_key(key):
    """
    Check whether ``key`` looks like a valid OpenAI API key.

    A valid key is "sk-" followed by exactly 48 alphanumeric characters.

    Returns:
        bool: True when the key matches the expected format.
    """
    API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
    # Return a plain bool rather than the Match object so callers get a stable
    # truth value (the unreachable duplicate `return API_MATCH` was removed).
    return bool(API_MATCH)
|
||||
|
||||
def is_api2d_key(key):
    """
    Check whether ``key`` looks like a valid api2d key: starts with 'fk' and
    has a total length of 41 characters.

    Returns:
        bool: True when the key matches the expected format.
    """
    # Collapse the redundant `if cond: return True else: return False` into
    # the boolean expression itself.
    return key.startswith('fk') and len(key) == 41
|
||||
|
||||
def is_any_api_key(key):
    """
    Return a truthy value when ``key`` is (or, for a comma-separated list,
    contains) at least one key accepted by is_openai_api_key or is_api2d_key.
    """
    if ',' in key:
        # Comma-separated list: accept as soon as any single entry validates.
        return any(is_any_api_key(candidate) for candidate in key.split(','))
    return is_openai_api_key(key) or is_api2d_key(key)
|
||||
|
||||
|
||||
def select_api_key(keys, llm_model):
    """
    Pick one usable API key for ``llm_model`` from a comma-separated list,
    choosing randomly among the valid candidates (naive load balancing).

    Raises:
        RuntimeError: when no key in ``keys`` is usable for ``llm_model``.
    """
    import random
    key_list = keys.split(',')

    avail_key_list = []
    if llm_model.startswith('gpt-'):
        avail_key_list += [k for k in key_list if is_openai_api_key(k)]
    if llm_model.startswith('api2d-'):
        avail_key_list += [k for k in key_list if is_api2d_key(k)]

    if len(avail_key_list) == 0:
        raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。")

    # Random choice spreads the load across all valid keys.
    api_key = random.choice(avail_key_list)
    return api_key
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def read_single_conf_with_lru_cache(arg):
|
||||
@@ -483,14 +462,13 @@ def read_single_conf_with_lru_cache(arg):
|
||||
r = getattr(importlib.import_module('config'), arg)
|
||||
# 在读取API_KEY时,检查一下是不是忘了改config
|
||||
if arg == 'API_KEY':
|
||||
if is_openai_api_key(r):
|
||||
if is_any_api_key(r):
|
||||
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
|
||||
else:
|
||||
print亮红( "[API_KEY] 正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
|
||||
"(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)")
|
||||
print亮红( "[API_KEY] 正确的 API_KEY 是'sk'开头的51位密钥(OpenAI),或者 'fk'开头的41位密钥,请在config文件中修改API密钥之后再运行。")
|
||||
if arg == 'proxies':
|
||||
if r is None:
|
||||
print亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。')
|
||||
print亮红('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议:检查USE_PROXY选项是否修改。')
|
||||
else:
|
||||
print亮绿('[PROXY] 网络代理状态:已配置。配置信息如下:', r)
|
||||
assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
|
||||
|
||||
在新工单中引用
屏蔽一个用户