Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-08 07:26:48 +00:00
version 3.6
This commit is contained in:
toolbox.py (221)
@@ -7,10 +7,11 @@ import os
 import gradio
 import shutil
 import glob
+import math
 from latex2mathml.converter import convert as tex2mathml
 from functools import wraps, lru_cache
 pj = os.path.join
-
+default_user_name = 'default_user'
 """
 ========================================================================
 第一部分
@@ -60,11 +61,16 @@ def ArgsGeneralWrapper(f):
         txt_passon = txt
         if txt == "" and txt2 != "": txt_passon = txt2
         # 引入一个有cookie的chatbot
+        if request.username is not None:
+            user_name = request.username
+        else:
+            user_name = default_user_name
         cookies.update({
             'top_p':top_p,
             'api_key': cookies['api_key'],
             'llm_model': llm_model,
             'temperature':temperature,
+            'user_name': user_name,
         })
         llm_kwargs = {
             'api_key': cookies['api_key'],
@@ -151,13 +157,13 @@ def CatchException(f):
         except Exception as e:
             from check_proxy import check_proxy
             from toolbox import get_conf
-            proxies, = get_conf('proxies')
+            proxies = get_conf('proxies')
             tb_str = '```\n' + trimmed_format_exc() + '```'
             if len(chatbot_with_cookie) == 0:
                 chatbot_with_cookie.clear()
                 chatbot_with_cookie.append(["插件调度异常", "异常原因"])
             chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0],
-                                       f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
+                                       f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
             yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
     return decorated
 
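For orientation: CatchException is the decorator gpt_academic wraps around plugin generator functions, so an exception raised inside a plugin is rendered into the chat window (traceback plus proxy status) instead of crashing the UI. A minimal, hedged usage sketch; the plugin below is made up, but the signature follows the project's plugin convention:

    # Hypothetical plugin, shown only to illustrate what the decorator above protects.
    @CatchException
    def demo_plugin(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
        chatbot.append([txt, "working ..."])
        yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面
        raise RuntimeError("anything raised here is shown in the chat window by CatchException")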
@@ -186,7 +192,7 @@ def HotReload(f):
 其他小工具:
     - write_history_to_file: 将结果写入markdown文件中
     - regular_txt_to_markdown: 将普通文本转换为Markdown格式的文本。
-    - report_execption: 向chatbot中添加简单的意外错误信息
+    - report_exception: 向chatbot中添加简单的意外错误信息
     - text_divide_paragraph: 将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
     - markdown_convertion: 用多种方式组合,将markdown转化为好看的html
     - format_io: 接管gradio默认的markdown处理方式
@@ -259,7 +265,7 @@ def regular_txt_to_markdown(text):
 
 
 
-def report_execption(chatbot, history, a, b):
+def report_exception(chatbot, history, a, b):
     """
     向chatbot中添加错误信息
     """
@@ -278,9 +284,12 @@ def text_divide_paragraph(text):
 
     if '```' in text:
         # careful input
-        return text
+        return pre + text + suf
+    elif '</div>' in text:
+        # careful input
+        return text
     else:
-        # wtf input
+        # whatever input
         lines = text.split("\n")
         for i, line in enumerate(lines):
             lines[i] = lines[i].replace(" ", "&nbsp;")
@@ -372,6 +381,26 @@ def markdown_convertion(txt):
                 contain_any_eq = True
         return contain_any_eq
 
+    def fix_markdown_indent(txt):
+        # fix markdown indent
+        if (' - ' not in txt) or ('. ' not in txt):
+            return txt # do not need to fix, fast escape
+        # walk through the lines and fix non-standard indentation
+        lines = txt.split("\n")
+        pattern = re.compile(r'^\s+-')
+        activated = False
+        for i, line in enumerate(lines):
+            if line.startswith('- ') or line.startswith('1. '):
+                activated = True
+            if activated and pattern.match(line):
+                stripped_string = line.lstrip()
+                num_spaces = len(line) - len(stripped_string)
+                if (num_spaces % 4) == 3:
+                    num_spaces_should_be = math.ceil(num_spaces/4) * 4
+                    lines[i] = ' ' * num_spaces_should_be + stripped_string
+        return '\n'.join(lines)
+
+    txt = fix_markdown_indent(txt)
     if is_equation(txt):  # 有$标识的公式符号,且没有代码段```的标识
         # convert everything to html format
         split = markdown.markdown(text='---')
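The fix_markdown_indent helper added above is nested inside markdown_convertion (so it is not an importable API); it only touches nested list items whose leading-space count is 3 modulo 4, a pattern LLM output often shows, and rounds that indent up to the next multiple of four. An illustration of the intended effect on a tiny sample:

    # Before: the sub-item is indented by 3 spaces, which strict Markdown renderers reject.
    before = "- item\n   - sub item"
    # After the fix: 3 % 4 == 3, so math.ceil(3/4) * 4 == 4 spaces of indentation.
    after = "- item\n    - sub item"
    # An indent of 7 spaces would likewise become 8; indents of 4, 8, ... are left untouched.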
@@ -513,40 +542,60 @@ def find_recent_files(directory):
 
     return recent_files
 
+
+def file_already_in_downloadzone(file, user_path):
+    try:
+        parent_path = os.path.abspath(user_path)
+        child_path = os.path.abspath(file)
+        if os.path.samefile(os.path.commonpath([parent_path, child_path]), parent_path):
+            return True
+        else:
+            return False
+    except:
+        return False
+
 def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
     # 将文件复制一份到下载区
     import shutil
-    if rename_file is None: rename_file = f'{gen_time_str()}-{os.path.basename(file)}'
-    new_path = pj(get_log_folder(), rename_file)
-    # 如果已经存在,先删除
-    if os.path.exists(new_path) and not os.path.samefile(new_path, file): os.remove(new_path)
-    # 把文件复制过去
-    if not os.path.exists(new_path): shutil.copyfile(file, new_path)
-    # 将文件添加到chatbot cookie中,避免多用户干扰
     if chatbot is not None:
+        user_name = get_user(chatbot)
+    else:
+        user_name = default_user_name
+
+    user_path = get_log_folder(user_name, plugin_name=None)
+    if file_already_in_downloadzone(file, user_path):
+        new_path = file
+    else:
+        user_path = get_log_folder(user_name, plugin_name='downloadzone')
+        if rename_file is None: rename_file = f'{gen_time_str()}-{os.path.basename(file)}'
+        new_path = pj(user_path, rename_file)
+        # 如果已经存在,先删除
+        if os.path.exists(new_path) and not os.path.samefile(new_path, file): os.remove(new_path)
+        # 把文件复制过去
+        if not os.path.exists(new_path): shutil.copyfile(file, new_path)
+    # 将文件添加到chatbot cookie中
+    if chatbot is not None:
         if 'files_to_promote' in chatbot._cookies: current = chatbot._cookies['files_to_promote']
         else: current = []
         chatbot._cookies.update({'files_to_promote': [new_path] + current})
     return new_path
 
+
 def disable_auto_promotion(chatbot):
     chatbot._cookies.update({'files_to_promote': []})
     return
 
-def is_the_upload_folder(string):
-    PATH_PRIVATE_UPLOAD, = get_conf('PATH_PRIVATE_UPLOAD')
-    pattern = r'^PATH_PRIVATE_UPLOAD/[A-Za-z0-9_-]+/\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
-    pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
-    if re.match(pattern, string): return True
-    else: return False
-
-def del_outdated_uploads(outdate_time_seconds):
-    PATH_PRIVATE_UPLOAD, = get_conf('PATH_PRIVATE_UPLOAD')
+def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
+    if target_path_base is None:
+        user_upload_dir = get_conf('PATH_PRIVATE_UPLOAD')
+    else:
+        user_upload_dir = target_path_base
     current_time = time.time()
     one_hour_ago = current_time - outdate_time_seconds
-    # Get a list of all subdirectories in the PATH_PRIVATE_UPLOAD folder
+    # Get a list of all subdirectories in the user_upload_dir folder
     # Remove subdirectories that are older than one hour
-    for subdirectory in glob.glob(f'{PATH_PRIVATE_UPLOAD}/*/*'):
+    for subdirectory in glob.glob(f'{user_upload_dir}/*'):
         subdirectory_time = os.path.getmtime(subdirectory)
         if subdirectory_time < one_hour_ago:
             try: shutil.rmtree(subdirectory)
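With the rewrite above, promoted files land in a per-user download area, and a file that already sits inside the user's log folder is left where it is. A hedged sketch of a plugin-side call (the file name is made up; the target path assumes the default PATH_LOGGING from the project config):

    # 'chatbot' is the ChatBotWithCookies instance handed to the plugin by gpt_academic.
    new_path = promote_file_to_downloadzone('gpt_log/report.md', chatbot=chatbot)
    # If the file was not already under the user's log folder, it is copied to roughly
    #   <PATH_LOGGING>/<user_name>/downloadzone/<timestamp>-report.md
    # and queued in chatbot._cookies['files_to_promote'] so the UI can offer it for download.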
@@ -559,17 +608,16 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
     """
     if len(files) == 0:
         return chatbot, txt
 
-    # 移除过时的旧文件从而节省空间&保护隐私
-    outdate_time_seconds = 60
-    del_outdated_uploads(outdate_time_seconds)
-
     # 创建工作路径
-    user_name = "default" if not request.username else request.username
+    user_name = default_user_name if not request.username else request.username
     time_tag = gen_time_str()
-    PATH_PRIVATE_UPLOAD, = get_conf('PATH_PRIVATE_UPLOAD')
-    target_path_base = pj(PATH_PRIVATE_UPLOAD, user_name, time_tag)
+    target_path_base = get_upload_folder(user_name, tag=time_tag)
     os.makedirs(target_path_base, exist_ok=True)
 
+    # 移除过时的旧文件从而节省空间&保护隐私
+    outdate_time_seconds = 3600 # 一小时
+    del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
+
     # 逐个文件转移到目标路径
     upload_msg = ''
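After this change every upload batch is keyed by user name and timestamp, and stale uploads are purged only inside that user's own upload root rather than across the whole PATH_PRIVATE_UPLOAD tree. A small sketch (the user name is illustrative and 'private_upload' assumes the project's default PATH_PRIVATE_UPLOAD):

    user_dir = get_upload_folder('alice')                      # e.g. private_upload/alice
    batch_dir = get_upload_folder('alice', tag=gen_time_str()) # e.g. private_upload/alice/2023-11-12-08-30-00
    del_outdated_uploads(3600, user_dir)                       # remove upload batches older than one hour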
@@ -604,13 +652,14 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
 
 
 def on_report_generated(cookies, files, chatbot):
-    from toolbox import find_recent_files
-    PATH_LOGGING, = get_conf('PATH_LOGGING')
+    # from toolbox import find_recent_files
+    # PATH_LOGGING = get_conf('PATH_LOGGING')
     if 'files_to_promote' in cookies:
         report_files = cookies['files_to_promote']
         cookies.pop('files_to_promote')
     else:
-        report_files = find_recent_files(PATH_LOGGING)
+        report_files = []
+        # report_files = find_recent_files(PATH_LOGGING)
     if len(report_files) == 0:
         return cookies, None, chatbot
     # files.extend(report_files)
@@ -621,13 +670,34 @@ def on_report_generated(cookies, files, chatbot):
 
 def load_chat_cookies():
     API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
+    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
+
     # deal with azure openai key
     if is_any_api_key(AZURE_API_KEY):
         if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
         else: API_KEY = AZURE_API_KEY
-    return {'api_key': API_KEY, 'llm_model': LLM_MODEL}
+    if len(AZURE_CFG_ARRAY) > 0:
+        for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
+            if not azure_model_name.startswith('azure'):
+                raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
+            AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
+            if is_any_api_key(AZURE_API_KEY_):
+                if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
+                else: API_KEY = AZURE_API_KEY_
+
+    customize_fn_overwrite_ = {}
+    for k in range(NUM_CUSTOM_BASIC_BTN):
+        customize_fn_overwrite_.update({
+            "自定义按钮" + str(k+1):{
+                "Title": r"",
+                "Prefix": r"请在自定义菜单中定义提示词前缀.",
+                "Suffix": r"请在自定义菜单中定义提示词后缀",
+            }
+        })
+    return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
 
 def is_openai_api_key(key):
-    CUSTOM_API_KEY_PATTERN, = get_conf('CUSTOM_API_KEY_PATTERN')
+    CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
     if len(CUSTOM_API_KEY_PATTERN) != 0:
         API_MATCH_ORIGINAL = re.match(CUSTOM_API_KEY_PATTERN, key)
     else:
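The new loop expects every key of AZURE_CFG_ARRAY to start with "azure" and every entry to carry at least an AZURE_API_KEY, which is folded into the comma-separated API_KEY pool; NUM_CUSTOM_BASIC_BTN then pre-creates placeholder prompts for the custom buttons. A hedged config sketch, reduced to the fields this function actually reads (real Azure entries need the extra endpoint fields described in the project wiki):

    AZURE_CFG_ARRAY = {
        "azure-gpt-4": {        # the key must start with "azure"
            "AZURE_API_KEY": "cccccccccccccccccccccccccccccccc",   # placeholder value
        },
    }
    NUM_CUSTOM_BASIC_BTN = 4    # e.g. four 自定义按钮 placeholders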
@@ -762,6 +832,11 @@ def read_single_conf_with_lru_cache(arg):
     r = getattr(importlib.import_module('config'), arg)
 
     # 在读取API_KEY时,检查一下是不是忘了改config
+    if arg == 'API_URL_REDIRECT':
+        oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的,请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
+        if oai_rd and not oai_rd.endswith('/completions'):
+            print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
+            time.sleep(5)
     if arg == 'API_KEY':
         print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
         print亮蓝(f"[API_KEY] 您既可以在config.py中修改api-key(s),也可以在问题输入区输入临时的api-key(s),然后回车键提交后即可生效。")
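The added guard only warns when the redirected chat-completions endpoint does not end in /completions, the most common way this setting is filled in wrong. A hedged example of the shape it accepts (the proxy host is a placeholder):

    API_URL_REDIRECT = {
        "https://api.openai.com/v1/chat/completions": "https://your-proxy.example.com/v1/chat/completions",
    }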
@@ -786,6 +861,7 @@ def get_conf(*args):
     for arg in args:
         r = read_single_conf_with_lru_cache(arg)
         res.append(r)
+    if len(res) == 1: return res[0]
     return res
 
 
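This one-line change is what lets every `xxx, = get_conf('xxx')` unpack elsewhere in this diff drop its trailing comma. In short:

    proxies = get_conf('proxies')                           # a single key now returns the value itself
    LLM_MODEL, API_KEY = get_conf('LLM_MODEL', 'API_KEY')   # several keys still return a list to unpack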
@@ -857,7 +933,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
     直到历史记录的标记数量降低到阈值以下。
     """
     import numpy as np
-    from request_llm.bridge_all import model_info
+    from request_llms.bridge_all import model_info
     def get_token_num(txt):
         return len(tokenizer.encode(txt, disallowed_special=()))
     input_token_num = get_token_num(inputs)
@@ -946,12 +1022,35 @@ def gen_time_str():
     import time
     return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
 
-def get_log_folder(user='default', plugin_name='shared'):
-    PATH_LOGGING, = get_conf('PATH_LOGGING')
-    _dir = pj(PATH_LOGGING, user, plugin_name)
+def get_log_folder(user=default_user_name, plugin_name='shared'):
+    if user is None: user = default_user_name
+    PATH_LOGGING = get_conf('PATH_LOGGING')
+    if plugin_name is None:
+        _dir = pj(PATH_LOGGING, user)
+    else:
+        _dir = pj(PATH_LOGGING, user, plugin_name)
     if not os.path.exists(_dir): os.makedirs(_dir)
     return _dir
 
+def get_upload_folder(user=default_user_name, tag=None):
+    PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
+    if user is None: user = default_user_name
+    if tag is None or len(tag)==0:
+        target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
+    else:
+        target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
+    return target_path_base
+
+def is_the_upload_folder(string):
+    PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
+    pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
+    pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
+    if re.match(pattern, string): return True
+    else: return False
+
+def get_user(chatbotwithcookies):
+    return chatbotwithcookies._cookies.get('user_name', default_user_name)
+
 class ProxyNetworkActivate():
     """
     这段代码定义了一个名为TempProxy的空上下文管理器, 用于给一小段代码上代理
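These helpers centralize the per-user directory layout that the rest of the commit switches to. A hedged sketch of how they compose, assuming the project defaults PATH_LOGGING="gpt_log" and PATH_PRIVATE_UPLOAD="private_upload" and an illustrative user name:

    get_log_folder('alice')                                # -> gpt_log/alice/shared (created if missing)
    get_log_folder('alice', plugin_name=None)              # -> gpt_log/alice
    get_upload_folder('alice', tag='2023-11-12-08-30-00')  # -> private_upload/alice/2023-11-12-08-30-00
    is_the_upload_folder('private_upload/alice/2023-11-12-08-30-00')   # True: matches the regex above
    user = get_user(chatbot)                               # falls back to default_user_name without a cookie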
@@ -964,13 +1063,13 @@ class ProxyNetworkActivate():
         else:
             # 给定了task, 我们检查一下
             from toolbox import get_conf
-            WHEN_TO_USE_PROXY, = get_conf('WHEN_TO_USE_PROXY')
+            WHEN_TO_USE_PROXY = get_conf('WHEN_TO_USE_PROXY')
             self.valid = (task in WHEN_TO_USE_PROXY)
 
     def __enter__(self):
         if not self.valid: return self
         from toolbox import get_conf
-        proxies, = get_conf('proxies')
+        proxies = get_conf('proxies')
         if 'no_proxy' in os.environ: os.environ.pop('no_proxy')
         if proxies is not None:
             if 'http' in proxies: os.environ['HTTP_PROXY'] = proxies['http']
@@ -1012,7 +1111,7 @@ def Singleton(cls):
 """
 ========================================================================
 第四部分
-接驳虚空终端:
+接驳void-terminal:
     - set_conf: 在运行过程中动态地修改配置
     - set_multi_conf: 在运行过程中动态地修改多个配置
     - get_plugin_handle: 获取插件的句柄
@@ -1027,7 +1126,7 @@ def set_conf(key, value):
     read_single_conf_with_lru_cache.cache_clear()
     get_conf.cache_clear()
     os.environ[key] = str(value)
-    altered, = get_conf(key)
+    altered = get_conf(key)
     return altered
 
 def set_multi_conf(dic):
@@ -1048,20 +1147,17 @@ def get_plugin_handle(plugin_name):
 def get_chat_handle():
     """
     """
-    from request_llm.bridge_all import predict_no_ui_long_connection
+    from request_llms.bridge_all import predict_no_ui_long_connection
     return predict_no_ui_long_connection
 
 def get_plugin_default_kwargs():
     """
     """
-    from toolbox import get_conf, ChatBotWithCookies
-
-    WEB_PORT, LLM_MODEL, API_KEY = \
-        get_conf('WEB_PORT', 'LLM_MODEL', 'API_KEY')
-
+    from toolbox import ChatBotWithCookies
+    cookies = load_chat_cookies()
     llm_kwargs = {
-        'api_key': API_KEY,
-        'llm_model': LLM_MODEL,
+        'api_key': cookies['api_key'],
+        'llm_model': cookies['llm_model'],
         'top_p':1.0,
         'max_length': None,
         'temperature':1.0,
@@ -1076,25 +1172,21 @@ def get_plugin_default_kwargs():
         "chatbot_with_cookie": chatbot,
         "history": [],
         "system_prompt": "You are a good AI.",
-        "web_port": WEB_PORT
+        "web_port": None
     }
     return DEFAULT_FN_GROUPS_kwargs
 
 def get_chat_default_kwargs():
     """
     """
-    from toolbox import get_conf
-
-    LLM_MODEL, API_KEY = get_conf('LLM_MODEL', 'API_KEY')
-
+    cookies = load_chat_cookies()
     llm_kwargs = {
-        'api_key': API_KEY,
-        'llm_model': LLM_MODEL,
+        'api_key': cookies['api_key'],
+        'llm_model': cookies['llm_model'],
         'top_p':1.0,
         'max_length': None,
         'temperature':1.0,
     }
 
     default_chat_kwargs = {
         "inputs": "Hello there, are you ready?",
         "llm_kwargs": llm_kwargs,
@@ -1106,3 +1198,12 @@ def get_chat_default_kwargs():
 
     return default_chat_kwargs
 
+def get_max_token(llm_kwargs):
+    from request_llms.bridge_all import model_info
+    return model_info[llm_kwargs['llm_model']]['max_token']
+
+def check_packages(packages=[]):
+    import importlib.util
+    for p in packages:
+        spam_spec = importlib.util.find_spec(p)
+        if spam_spec is None: raise ModuleNotFoundError
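The two trailing helpers are small plugin conveniences: get_max_token reads the model's context limit from the request_llms registry, and check_packages fails fast when an optional dependency is missing. A hedged usage sketch (the package names are examples used by other plugins in the project):

    try:
        check_packages(["pdfminer", "bs4"])   # raises ModuleNotFoundError if an extra is absent
    except ModuleNotFoundError:
        print("please install the optional dependencies first")

    max_tokens = get_max_token(llm_kwargs)    # e.g. 4096 for 'gpt-3.5-turbo' in model_info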