适配 google gemini 优化为从用户input中提取文件 (#1419)

适配 google gemini 优化为从用户input中提取文件
这个提交包含在:
XIao
2023-12-31 17:13:50 +08:00
提交者 qingxu fu
父节点 a96f842b3a
当前提交 a7c960dcb0
共有 5 个文件被更改,包括 472 次插入和 95 次删除

查看文件

@@ -11,8 +11,10 @@ import glob
import math
from latex2mathml.converter import convert as tex2mathml
from functools import wraps, lru_cache
pj = os.path.join
default_user_name = 'default_user'
"""
========================================================================
第一部分
@@ -26,6 +28,7 @@ default_user_name = 'default_user'
========================================================================
"""
class ChatBotWithCookies(list):
def __init__(self, cookie):
"""
@@ -67,18 +70,18 @@ def ArgsGeneralWrapper(f):
else:
user_name = default_user_name
cookies.update({
'top_p':top_p,
'top_p': top_p,
'api_key': cookies['api_key'],
'llm_model': llm_model,
'temperature':temperature,
'temperature': temperature,
'user_name': user_name,
})
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': llm_model,
'top_p':top_p,
'top_p': top_p,
'max_length': max_length,
'temperature':temperature,
'temperature': temperature,
'client_ip': request.client.host,
'most_recent_uploaded': cookies.get('most_recent_uploaded')
}
@@ -87,7 +90,7 @@ def ArgsGeneralWrapper(f):
}
chatbot_with_cookie = ChatBotWithCookies(cookies)
chatbot_with_cookie.write_list(chatbot)
if cookies.get('lock_plugin', None) is None:
# 正常状态
if len(args) == 0: # 插件通道
@@ -103,8 +106,10 @@ def ArgsGeneralWrapper(f):
final_cookies = chatbot_with_cookie.get_cookies()
# len(args) != 0 代表“提交”键对话通道,或者基础功能通道
if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
chatbot_with_cookie.append(["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
chatbot_with_cookie.append(
["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档")
return decorated
@@ -129,6 +134,7 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
yield cookies, chatbot_gr, history, msg
def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面
"""
刷新用户界面
@@ -147,6 +153,7 @@ def trimmed_format_exc():
replace_path = "."
return str.replace(current_path, replace_path)
def CatchException(f):
"""
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
@@ -164,9 +171,9 @@ def CatchException(f):
if len(chatbot_with_cookie) == 0:
chatbot_with_cookie.clear()
chatbot_with_cookie.append(["插件调度异常", "异常原因"])
chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0],
f"[Local Message] 插件调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
chatbot_with_cookie[-1] = (chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n")
yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
return decorated
@@ -209,6 +216,7 @@ def HotReload(f):
========================================================================
"""
def get_reduce_token_percent(text):
"""
* 此函数未来将被弃用
@@ -220,9 +228,9 @@ def get_reduce_token_percent(text):
EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
max_limit = float(match[0]) - EXCEED_ALLO
current_tokens = float(match[1])
ratio = max_limit/current_tokens
ratio = max_limit / current_tokens
assert ratio > 0 and ratio < 1
return ratio, str(int(current_tokens-max_limit))
return ratio, str(int(current_tokens - max_limit))
except:
return 0.5, '不详'
@@ -242,7 +250,7 @@ def write_history_to_file(history, file_basename=None, file_fullname=None, auto_
with open(file_fullname, 'w', encoding='utf8') as f:
f.write('# GPT-Academic Report\n')
for i, content in enumerate(history):
try:
try:
if type(content) != str: content = str(content)
except:
continue
@@ -268,8 +276,6 @@ def regular_txt_to_markdown(text):
return text
def report_exception(chatbot, history, a, b):
"""
向chatbot中添加错误信息
@@ -286,7 +292,7 @@ def text_divide_paragraph(text):
suf = '</div>'
if text.startswith(pre) and text.endswith(suf):
return text
if '```' in text:
# careful input
return text
@@ -312,7 +318,7 @@ def markdown_convertion(txt):
if txt.startswith(pre) and txt.endswith(suf):
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
return txt # 已经被转化过,不需要再次转化
markdown_extension_configs = {
'mdx_math': {
'enable_dollar_delimiter': True,
@@ -352,7 +358,8 @@ def markdown_convertion(txt):
"""
解决一个mdx_math的bug单$包裹begin命令时多余<script>
"""
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">')
content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
'<script type="math/tex; mode=display">')
content = content.replace('</script>\n</script>', '</script>')
return content
@@ -363,16 +370,16 @@ def markdown_convertion(txt):
if '```' in txt and '```reference' not in txt: return False
if '$' not in txt and '\\[' not in txt: return False
mathpatterns = {
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\]
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False}, #  $...$
r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True}, # $$...$$
r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False}, # \[...\]
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
}
matches = []
for pattern, property in mathpatterns.items():
flags = re.ASCII|re.DOTALL if property['allow_multi_lines'] else re.ASCII
flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
matches.extend(re.findall(pattern, txt, flags))
if len(matches) == 0: return False
contain_any_eq = False
@@ -380,16 +387,16 @@ def markdown_convertion(txt):
for match in matches:
if len(match) != 3: return False
eq_canidate = match[1]
if illegal_pattern.search(eq_canidate):
if illegal_pattern.search(eq_canidate):
return False
else:
else:
contain_any_eq = True
return contain_any_eq
def fix_markdown_indent(txt):
# fix markdown indent
if (' - ' not in txt) or ('. ' not in txt):
return txt # do not need to fix, fast escape
if (' - ' not in txt) or ('. ' not in txt):
return txt # do not need to fix, fast escape
# walk through the lines and fix non-standard indentation
lines = txt.split("\n")
pattern = re.compile(r'^\s+-')
@@ -401,7 +408,7 @@ def markdown_convertion(txt):
stripped_string = line.lstrip()
num_spaces = len(line) - len(stripped_string)
if (num_spaces % 4) == 3:
num_spaces_should_be = math.ceil(num_spaces/4) * 4
num_spaces_should_be = math.ceil(num_spaces / 4) * 4
lines[i] = ' ' * num_spaces_should_be + stripped_string
return '\n'.join(lines)
@@ -409,7 +416,8 @@ def markdown_convertion(txt):
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
# convert everything to html format
split = markdown.markdown(text='---')
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'], extension_configs=markdown_extension_configs)
convert_stage_1 = markdown.markdown(text=txt, extensions=['sane_lists', 'tables', 'mdx_math', 'fenced_code'],
extension_configs=markdown_extension_configs)
convert_stage_1 = markdown_bug_hunt(convert_stage_1)
# 1. convert to easy-to-copy tex (do not render math)
convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL)
@@ -441,8 +449,7 @@ def close_up_code_segment_during_stream(gpt_reply):
segments = gpt_reply.split('```')
n_mark = len(segments) - 1
if n_mark % 2 == 1:
# print('输出代码片段中!')
return gpt_reply+'\n```'
return gpt_reply + '\n```' # 输出代码片段中!
else:
return gpt_reply
@@ -533,7 +540,7 @@ def find_recent_files(directory):
current_time = time.time()
one_minute_ago = current_time - 60
recent_files = []
if not os.path.exists(directory):
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
for filename in os.listdir(directory):
file_path = pj(directory, filename)
@@ -559,6 +566,7 @@ def file_already_in_downloadzone(file, user_path):
except:
return False
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
# 将文件复制一份到下载区
import shutil
@@ -581,8 +589,10 @@ def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
if not os.path.exists(new_path): shutil.copyfile(file, new_path)
# 将文件添加到chatbot cookie中
if chatbot is not None:
if 'files_to_promote' in chatbot._cookies: current = chatbot._cookies['files_to_promote']
else: current = []
if 'files_to_promote' in chatbot._cookies:
current = chatbot._cookies['files_to_promote']
else:
current = []
if new_path not in current: # 避免把同一个文件添加多次
chatbot._cookies.update({'files_to_promote': [new_path] + current})
return new_path
@@ -605,8 +615,10 @@ def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
for subdirectory in glob.glob(f'{user_upload_dir}/*'):
subdirectory_time = os.path.getmtime(subdirectory)
if subdirectory_time < one_hour_ago:
try: shutil.rmtree(subdirectory)
except: pass
try:
shutil.rmtree(subdirectory)
except:
pass
return
@@ -679,9 +691,9 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
time_tag = gen_time_str()
target_path_base = get_upload_folder(user_name, tag=time_tag)
os.makedirs(target_path_base, exist_ok=True)
# 移除过时的旧文件从而节省空间&保护隐私
outdate_time_seconds = 3600 # 一小时
outdate_time_seconds = 3600 # 一小时
del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
# 逐个文件转移到目标路径
@@ -690,21 +702,20 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
file_origin_name = os.path.basename(file.orig_name)
this_file_path = pj(target_path_base, file_origin_name)
shutil.move(file.name, this_file_path)
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path+'.extract')
if "浮动输入区" in checkboxes:
txt, txt2 = "", target_path_base
else:
txt, txt2 = target_path_base, ""
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
# 整理文件集合 输出消息
moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
chatbot.append(['我上传了文件,请查收',
chatbot.append(['我上传了文件,请查收',
f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数'+upload_msg])
f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数' + upload_msg])
txt, txt2 = target_path_base, ""
if "浮动输入区" in checkboxes:
txt, txt2 = txt2, txt
# 记录近期文件
cookies.update({
'most_recent_uploaded': {
@@ -732,34 +743,40 @@ def on_report_generated(cookies, files, chatbot):
chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}'])
return cookies, report_files, chatbot
def load_chat_cookies():
    """Build the initial cookie dict for a chat session from the configuration.

    Reads API_KEY, LLM_MODEL, AZURE_API_KEY, AZURE_CFG_ARRAY and
    NUM_CUSTOM_BASIC_BTN via get_conf. Every Azure key accepted by
    is_any_api_key is appended (comma separated) to API_KEY, or replaces it
    when API_KEY itself is not accepted by is_any_api_key.

    Returns:
        A dict with the keys 'api_key', 'llm_model' and 'customize_fn_overwrite'.

    Raises:
        ValueError: if a model name in AZURE_CFG_ARRAY does not start with 'azure'.
    """
    API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')

    def _with_extra_key(current, extra):
        # Append to the existing key list only when that list is itself valid;
        # otherwise the extra key becomes the whole list.
        if is_any_api_key(current):
            return current + ',' + extra
        return extra

    # Fold the stand-alone Azure key into the key list.
    if is_any_api_key(AZURE_API_KEY):
        API_KEY = _with_extra_key(API_KEY, AZURE_API_KEY)

    # Fold in the key of every model configured in AZURE_CFG_ARRAY.
    if len(AZURE_CFG_ARRAY) > 0:
        for model_name, model_cfg in AZURE_CFG_ARRAY.items():
            if not model_name.startswith('azure'):
                raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
            model_key = model_cfg["AZURE_API_KEY"]
            if is_any_api_key(model_key):
                API_KEY = _with_extra_key(API_KEY, model_key)

    # One placeholder entry per custom basic button, numbered from 1.
    customize_fn_overwrite_ = {
        "自定义按钮" + str(index + 1): {
            "Title": r"",
            "Prefix": r"请在自定义菜单中定义提示词前缀.",
            "Suffix": r"请在自定义菜单中定义提示词后缀",
        }
        for index in range(NUM_CUSTOM_BASIC_BTN)
    }
    return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
def is_openai_api_key(key):
CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
if len(CUSTOM_API_KEY_PATTERN) != 0:
@@ -768,14 +785,17 @@ def is_openai_api_key(key):
API_MATCH_ORIGINAL = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
return bool(API_MATCH_ORIGINAL)
def is_azure_api_key(key):
    """Return True when `key` is exactly 32 ASCII letters or digits.

    Args:
        key: candidate API key string.

    Returns:
        bool: True if the whole string matches the 32-character shape.
    """
    # fullmatch instead of match + '$': '$' also matches just before a trailing
    # newline, which would let "<32 chars>\n" pass as a valid key.
    return re.fullmatch(r"[a-zA-Z0-9]{32}", key) is not None
def is_api2d_key(key):
    """Return True when `key` has the shape ``fk`` + 6 alphanumerics + ``-`` + 32 alphanumerics.

    Args:
        key: candidate API key string.

    Returns:
        bool: True if the whole string matches that shape.
    """
    # fullmatch instead of match + '$': '$' also matches just before a trailing
    # newline, which would let a key followed by "\n" pass validation.
    return re.fullmatch(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}", key) is not None
def is_any_api_key(key):
if ',' in key:
keys = key.split(',')
@@ -785,24 +805,26 @@ def is_any_api_key(key):
else:
return is_openai_api_key(key) or is_api2d_key(key) or is_azure_api_key(key)
def what_keys(keys):
    """Summarise how many keys of each known kind a comma-separated key string holds.

    Args:
        keys: one or more API keys joined by ','.

    Returns:
        A human-readable message with the OpenAI, Azure and API2D key counts.
    """
    candidates = keys.split(',')
    # (label, predicate) pairs, checked in the same order as before.
    detectors = (
        ('OpenAI Key', is_openai_api_key),
        ('API2D Key', is_api2d_key),
        ('Azure Key', is_azure_api_key),
    )
    avail_key_list = {'OpenAI Key': 0, "Azure Key": 0, "API2D Key": 0}
    for label, looks_like in detectors:
        avail_key_list[label] = sum(1 for candidate in candidates if looks_like(candidate))

    return f"检测到: OpenAI Key {avail_key_list['OpenAI Key']} 个, Azure Key {avail_key_list['Azure Key']} 个, API2D Key {avail_key_list['API2D Key']}"
def select_api_key(keys, llm_model):
import random
avail_key_list = []
@@ -826,6 +848,7 @@ def select_api_key(keys, llm_model):
api_key = random.choice(avail_key_list) # 随机负载均衡
return api_key
def read_env_variable(arg, default_value):
"""
环境变量可以是 `GPT_ACADEMIC_CONFIG`(优先),也可以直接是`CONFIG`
@@ -843,10 +866,10 @@ def read_env_variable(arg, default_value):
set GPT_ACADEMIC_AUTHENTICATION=[("username", "password"), ("username2", "password2")]
"""
from colorful import print亮红, print亮绿
arg_with_prefix = "GPT_ACADEMIC_" + arg
if arg_with_prefix in os.environ:
arg_with_prefix = "GPT_ACADEMIC_" + arg
if arg_with_prefix in os.environ:
env_arg = os.environ[arg_with_prefix]
elif arg in os.environ:
elif arg in os.environ:
env_arg = os.environ[arg]
else:
raise KeyError
@@ -856,7 +879,7 @@ def read_env_variable(arg, default_value):
env_arg = env_arg.strip()
if env_arg == 'True': r = True
elif env_arg == 'False': r = False
else: print('enter True or False, but have:', env_arg); r = default_value
else: print('Enter True or False, but have:', env_arg); r = default_value
elif isinstance(default_value, int):
r = int(env_arg)
elif isinstance(default_value, float):
@@ -880,13 +903,14 @@ def read_env_variable(arg, default_value):
print亮绿(f"[ENV_VAR] 成功读取环境变量{arg}")
return r
@lru_cache(maxsize=128)
def read_single_conf_with_lru_cache(arg):
from colorful import print亮红, print亮绿, print亮蓝
try:
# 优先级1. 获取环境变量作为配置
default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
r = read_env_variable(arg, default_ref)
default_ref = getattr(importlib.import_module('config'), arg) # 读取默认值作为数据类型转换的参考
r = read_env_variable(arg, default_ref)
except:
try:
# 优先级2. 获取config_private中的配置
@@ -899,7 +923,7 @@ def read_single_conf_with_lru_cache(arg):
if arg == 'API_URL_REDIRECT':
oai_rd = r.get("https://api.openai.com/v1/chat/completions", None) # API_URL_REDIRECT填写格式是错误的,请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`
if oai_rd and not oai_rd.endswith('/completions'):
print亮红( "\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
print亮红("\n\n[API_URL_REDIRECT] API_URL_REDIRECT填错了。请阅读`https://github.com/binary-husky/gpt_academic/wiki/项目配置说明`。如果您确信自己没填错,无视此消息即可。")
time.sleep(5)
if arg == 'API_KEY':
print亮蓝(f"[API_KEY] 本项目现已支持OpenAI和Azure的api-key。也支持同时填写多个api-key,如API_KEY=\"openai-key1,openai-key2,azure-key3\"")
@@ -907,9 +931,9 @@ def read_single_conf_with_lru_cache(arg):
if is_any_api_key(r):
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
else:
print亮红( "[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
print亮红("[API_KEY] 您的 API_KEY 不满足任何一种已知的密钥格式,请在config文件中修改API密钥之后再运行。")
if arg == 'proxies':
if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用
if not read_single_conf_with_lru_cache('USE_PROXY'): r = None # 检查USE_PROXY,防止proxies单独起作用
if r is None:
print亮红('[PROXY] 网络代理状态未配置。无代理状态下很可能无法访问OpenAI家族的模型。建议检查USE_PROXY选项是否修改。')
else:
@@ -953,17 +977,20 @@ class DummyWith():
在上下文执行开始的情况下,__enter__()方法会在代码块被执行前被调用,
而在上下文执行结束时,__exit__()方法则会被调用。
"""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return
def run_gradio_in_subpath(demo, auth, port, custom_path):
"""
把gradio的运行地址更改到指定的二次路径上
"""
def is_path_legal(path: str)->bool:
def is_path_legal(path: str) -> bool:
'''
check path for sub url
path: path to check
@@ -988,7 +1015,7 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
app = FastAPI()
if custom_path != "/":
@app.get("/")
def read_main():
def read_main():
return {"message": f"Gradio is running at: {custom_path}"}
app = gr.mount_gradio_app(app, demo, path=custom_path)
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
@@ -999,13 +1026,13 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
reduce the length of history by clipping.
this function search for the longest entries to clip, little by little,
until the number of token of history is reduced under threshold.
通过裁剪来缩短历史记录的长度。
通过裁剪来缩短历史记录的长度。
此函数逐渐地搜索最长的条目进行剪辑,
直到历史记录的标记数量降低到阈值以下。
"""
import numpy as np
from request_llms.bridge_all import model_info
def get_token_num(txt):
def get_token_num(txt):
return len(tokenizer.encode(txt, disallowed_special=()))
input_token_num = get_token_num(inputs)
@@ -1039,14 +1066,15 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
while n_token > max_token_limit:
where = np.argmax(everything_token)
encoded = tokenizer.encode(everything[where], disallowed_special=())
clipped_encoded = encoded[:len(encoded)-delta]
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
clipped_encoded = encoded[:len(encoded) - delta]
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
everything_token[where] = get_token_num(everything[where])
n_token = get_token_num('\n'.join(everything))
history = everything[1:]
return history
"""
========================================================================
第三部分
@@ -1058,6 +1086,7 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
========================================================================
"""
def zip_folder(source_folder, dest_folder, zip_name):
import zipfile
import os
@@ -1089,15 +1118,18 @@ def zip_folder(source_folder, dest_folder, zip_name):
print(f"Zip file created at {zip_file}")
def zip_result(folder):
    """Zip `folder` into the log folder and return the path of the archive.

    The archive is named ``<timestamp>-result.zip``, where the timestamp comes
    from gen_time_str().
    """
    archive_name = f'{gen_time_str()}-result.zip'
    zip_folder(folder, get_log_folder(), archive_name)
    return pj(get_log_folder(), archive_name)
def gen_time_str():
    """Return the current local time formatted as ``YYYY-mm-dd-HH-MM-SS``."""
    import time
    now = time.localtime()
    return time.strftime("%Y-%m-%d-%H-%M-%S", now)
def get_log_folder(user=default_user_name, plugin_name='shared'):
if user is None: user = default_user_name
PATH_LOGGING = get_conf('PATH_LOGGING')
@@ -1108,29 +1140,36 @@ def get_log_folder(user=default_user_name, plugin_name='shared'):
if not os.path.exists(_dir): os.makedirs(_dir)
return _dir
def get_upload_folder(user=default_user_name, tag=None):
    """Return the upload folder path for `user`, optionally narrowed by `tag`.

    Args:
        user: user name; None falls back to default_user_name.
        tag: optional sub-folder name; None or an empty value means the
            user's top-level upload folder.

    Returns:
        The joined path under the configured PATH_PRIVATE_UPLOAD.
    """
    upload_root = get_conf('PATH_PRIVATE_UPLOAD')
    if user is None:
        user = default_user_name
    if tag is not None and len(tag) != 0:
        return pj(upload_root, user, tag)
    return pj(upload_root, user)
def is_the_upload_folder(string):
    """Tell whether `string` is a per-upload folder under the private upload root.

    The accepted layout is the configured PATH_PRIVATE_UPLOAD, a path separator
    (forward slash or backslash), a segment of letters, digits, '_' or '-',
    another separator, and a final segment of digits shaped 4-2-2-2-2-2.

    Args:
        string: the path to check.

    Returns:
        bool: True if the path matches that layout.
    """
    upload_root = get_conf('PATH_PRIVATE_UPLOAD')
    # The configured root is literal text, not a regex. Escape it so that
    # characters such as '.' or a Windows '\' are not read as pattern syntax.
    pattern = (
        '^' + re.escape(upload_root)
        + r'[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
    )
    return bool(re.match(pattern, string))
def get_user(chatbotwithcookies):
    """Return the 'user_name' stored in the chatbot's cookies, or default_user_name if absent."""
    session_cookies = chatbotwithcookies._cookies
    return session_cookies.get('user_name', default_user_name)
class ProxyNetworkActivate():
"""
这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
"""
def __init__(self, task=None) -> None:
self.task = task
if not task:
@@ -1158,32 +1197,36 @@ class ProxyNetworkActivate():
if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
return
def objdump(obj, file='objdump.tmp'):
    """Pickle `obj` into `file`, creating the file or truncating an existing one.

    Args:
        obj: any picklable object.
        file: destination path.
    """
    import pickle
    with open(file, 'wb+') as sink:
        pickle.dump(obj, sink)
def objload(file='objdump.tmp'):
    """Load and return the pickled object stored in `file`.

    Returns None when `file` does not exist.

    NOTE(review): pickle.load executes code embedded in the file; only use this
    on files the application wrote itself.
    """
    import pickle
    import os
    if os.path.exists(file):
        with open(file, 'rb') as source:
            return pickle.load(source)
    return None
def Singleton(cls):
    """Class decorator that makes `cls` produce a single shared instance.

    The first call constructs the instance with the given arguments; later
    calls return that same instance and ignore their arguments.
    """
    created = {}

    def _singleton(*args, **kargs):
        if cls in created:
            return created[cls]
        created[cls] = cls(*args, **kargs)
        return created[cls]

    return _singleton
"""
========================================================================
第四部分
@@ -1197,6 +1240,7 @@ def Singleton(cls):
========================================================================
"""
def set_conf(key, value):
from toolbox import read_single_conf_with_lru_cache, get_conf
read_single_conf_with_lru_cache.cache_clear()
@@ -1205,10 +1249,12 @@ def set_conf(key, value):
altered = get_conf(key)
return altered
def set_multi_conf(dic):
    """Apply set_conf to every key/value pair in `dic`, in iteration order."""
    for conf_key, conf_value in dic.items():
        set_conf(conf_key, conf_value)
def get_plugin_handle(plugin_name):
"""
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
@@ -1220,12 +1266,14 @@ def get_plugin_handle(plugin_name):
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
return f_hot_reload
def get_chat_handle():
"""
"""
from request_llms.bridge_all import predict_no_ui_long_connection
return predict_no_ui_long_connection
def get_plugin_default_kwargs():
"""
"""
@@ -1234,9 +1282,9 @@ def get_plugin_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
'top_p':1.0,
'top_p': 1.0,
'max_length': None,
'temperature':1.0,
'temperature': 1.0,
}
chatbot = ChatBotWithCookies(llm_kwargs)
@@ -1247,11 +1295,12 @@ def get_plugin_default_kwargs():
"plugin_kwargs": {},
"chatbot_with_cookie": chatbot,
"history": [],
"system_prompt": "You are a good AI.",
"system_prompt": "You are a good AI.",
"web_port": None
}
return DEFAULT_FN_GROUPS_kwargs
def get_chat_default_kwargs():
"""
"""
@@ -1259,9 +1308,9 @@ def get_chat_default_kwargs():
llm_kwargs = {
'api_key': cookies['api_key'],
'llm_model': cookies['llm_model'],
'top_p':1.0,
'top_p': 1.0,
'max_length': None,
'temperature':1.0,
'temperature': 1.0,
}
default_chat_kwargs = {
"inputs": "Hello there, are you ready?",
@@ -1284,15 +1333,15 @@ def get_pictures_list(path):
def have_any_recent_upload_image_files(chatbot):
_5min = 5 * 60
if chatbot is None: return False, None # chatbot is None
if chatbot is None: return False, None # chatbot is None
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
if time.time() - most_recent_uploaded["time"] < _5min:
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
path = most_recent_uploaded['path']
file_manifest = get_pictures_list(path)
if len(file_manifest) == 0: return False, None
return True, file_manifest # most_recent_uploaded is new
return True, file_manifest # most_recent_uploaded is new
else:
return False, None # most_recent_uploaded is too old
@@ -1307,6 +1356,7 @@ def get_max_token(llm_kwargs):
from request_llms.bridge_all import model_info
return model_info[llm_kwargs['llm_model']]['max_token']
def check_packages(packages=[]):
import importlib.util
for p in packages: