diff --git a/check_proxy.py b/check_proxy.py
index b5ee17d3..6124a6ef 100644
--- a/check_proxy.py
+++ b/check_proxy.py
@@ -1,24 +1,36 @@
from loguru import logger
def check_proxy(proxies, return_ip=False):
+ """
+ 检查代理配置并返回结果。
+
+ Args:
+ proxies (dict): 包含http和https代理配置的字典。
+ return_ip (bool, optional): 是否返回代理的IP地址。默认为False。
+
+ Returns:
+ str or None: 检查的结果信息或代理的IP地址(如果`return_ip`为True)。
+ """
import requests
proxies_https = proxies['https'] if proxies is not None else '无'
ip = None
try:
- response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
+ response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4) # ⭐ 执行GET请求以获取代理信息
data = response.json()
if 'country_name' in data:
country = data['country_name']
result = f"代理配置 {proxies_https}, 代理所在地:{country}"
- if 'ip' in data: ip = data['ip']
+ if 'ip' in data:
+ ip = data['ip']
elif 'error' in data:
- alternative, ip = _check_with_backup_source(proxies)
+ alternative, ip = _check_with_backup_source(proxies) # ⭐ 调用备用方法检查代理配置
if alternative is None:
result = f"代理配置 {proxies_https}, 代理所在地:未知,IP查询频率受限"
else:
result = f"代理配置 {proxies_https}, 代理所在地:{alternative}"
else:
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
+
if not return_ip:
logger.warning(result)
return result
@@ -33,17 +45,33 @@ def check_proxy(proxies, return_ip=False):
return ip
def _check_with_backup_source(proxies):
+ """
+ 通过备份源检查代理,并获取相应信息。
+
+ Args:
+ proxies (dict): 包含代理信息的字典。
+
+ Returns:
+ tuple: 代理信息(geo)和IP地址(ip)的元组。
+ """
import random, string, requests
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
try:
- res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json()
+ res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json() # ⭐ 执行代理检查和备份源请求
return res_json['dns']['geo'], res_json['dns']['ip']
except:
return None, None
def backup_and_download(current_version, remote_version):
"""
- 一键更新协议:备份和下载
+ 一键更新协议:备份当前版本,下载远程版本并解压缩。
+
+ Args:
+ current_version (str): 当前版本号。
+ remote_version (str): 远程版本号。
+
+ Returns:
+ str: 新版本目录的路径。
"""
from toolbox import get_conf
import shutil
@@ -60,7 +88,7 @@ def backup_and_download(current_version, remote_version):
proxies = get_conf('proxies')
try: r = requests.get('https://github.com/binary-husky/chatgpt_academic/archive/refs/heads/master.zip', proxies=proxies, stream=True)
except: r = requests.get('https://public.agent-matrix.com/publish/master.zip', proxies=proxies, stream=True)
- zip_file_path = backup_dir+'/master.zip'
+ zip_file_path = backup_dir+'/master.zip' # ⭐ 保存备份文件的路径
with open(zip_file_path, 'wb+') as f:
f.write(r.content)
dst_path = new_version_dir
@@ -76,6 +104,17 @@ def backup_and_download(current_version, remote_version):
def patch_and_restart(path):
"""
一键更新协议:覆盖和重启
+
+ Args:
+ path (str): 新版本代码所在的路径
+
+ 注意事项:
+ 如果您的程序没有使用config_private.py私密配置文件,则会将config.py重命名为config_private.py以避免配置丢失。
+
+ 更新流程:
+ - 复制最新版本代码到当前目录
+ - 更新pip包依赖
+ - 如果更新失败,则提示手动安装依赖库并重启
"""
from distutils import dir_util
import shutil
@@ -84,32 +123,43 @@ def patch_and_restart(path):
import time
import glob
from shared_utils.colorful import log亮黄, log亮绿, log亮红
- # if not using config_private, move origin config.py as config_private.py
+
if not os.path.exists('config_private.py'):
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
'另外您可以随时在history子文件夹下找回旧版的程序。')
shutil.copyfile('config.py', 'config_private.py')
+
path_new_version = glob.glob(path + '/*-master')[0]
- dir_util.copy_tree(path_new_version, './')
+ dir_util.copy_tree(path_new_version, './') # ⭐ 将最新版本代码复制到当前目录
+
log亮绿('代码已经更新,即将更新pip包依赖……')
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
+
try:
import subprocess
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
except:
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
+
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
log亮绿(' ------------------------------ -----------------------------------')
+
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
- os.execl(sys.executable, sys.executable, *sys.argv)
+ os.execl(sys.executable, sys.executable, *sys.argv) # 重启程序
def get_current_version():
+ """
+ 获取当前的版本号。
+
+ Returns:
+ str: 当前的版本号。如果无法获取版本号,则返回空字符串。
+ """
import json
try:
with open('./version', 'r', encoding='utf8') as f:
- current_version = json.loads(f.read())['version']
+ current_version = json.loads(f.read())['version'] # ⭐ 从读取的json数据中提取版本号
except:
current_version = ""
return current_version
@@ -118,6 +168,12 @@ def get_current_version():
def auto_update(raise_error=False):
"""
一键更新协议:查询版本和用户意见
+
+ Args:
+ raise_error (bool, optional): 是否在出错时抛出错误。默认为 False。
+
+ Returns:
+ None
"""
try:
from toolbox import get_conf
@@ -137,13 +193,13 @@ def auto_update(raise_error=False):
current_version = json.loads(current_version)['version']
if (remote_version - current_version) >= 0.01-1e-5:
from shared_utils.colorful import log亮黄
- log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
+ log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}') # ⭐ 在控制台打印新版本信息
logger.info('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
if user_instruction in ['Y', 'y']:
- path = backup_and_download(current_version, remote_version)
+ path = backup_and_download(current_version, remote_version) # ⭐ 备份并下载文件
try:
- patch_and_restart(path)
+ patch_and_restart(path) # ⭐ 执行覆盖并重启操作
except:
msg = '更新失败。'
if raise_error:
@@ -163,6 +219,9 @@ def auto_update(raise_error=False):
logger.info(msg)
def warm_up_modules():
+ """
+ 预热模块,加载特定模块并执行预热操作。
+ """
logger.info('正在执行一些模块的预热 ...')
from toolbox import ProxyNetworkActivate
from request_llms.bridge_all import model_info
@@ -173,6 +232,16 @@ def warm_up_modules():
enc.encode("模块预热", disallowed_special=())
def warm_up_vectordb():
+ """
+ 执行一些模块的预热操作。
+
+ 本函数主要用于执行一些模块的预热操作,确保在后续的流程中能够顺利运行。
+
+ ⭐ 关键作用:预热模块
+
+ Returns:
+ None
+ """
logger.info('正在执行一些模块的预热 ...')
from toolbox import ProxyNetworkActivate
with ProxyNetworkActivate("Warmup_Modules"):
@@ -185,4 +254,4 @@ if __name__ == '__main__':
os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
from toolbox import get_conf
proxies = get_conf('proxies')
- check_proxy(proxies)
+ check_proxy(proxies)
\ No newline at end of file
diff --git a/crazy_functional.py b/crazy_functional.py
index de07c1bb..92bc2842 100644
--- a/crazy_functional.py
+++ b/crazy_functional.py
@@ -49,6 +49,7 @@ def get_crazy_functions():
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
from crazy_functions.SourceCode_Comment import 注释Python项目
+ from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
function_plugins = {
"虚空终端": {
@@ -71,6 +72,7 @@ def get_crazy_functions():
"AsButton": False,
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
"Function": HotReload(注释Python项目),
+ "Class": SourceCodeComment_Wrap,
},
"载入对话历史存档(先上传存档或输入路径)": {
"Group": "对话",
diff --git a/crazy_functions/Latex_Function.py b/crazy_functions/Latex_Function.py
index af020775..51b03283 100644
--- a/crazy_functions/Latex_Function.py
+++ b/crazy_functions/Latex_Function.py
@@ -3,7 +3,7 @@ from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip
from functools import partial
from loguru import logger
-import glob, os, requests, time, json, tarfile
+import glob, os, requests, time, json, tarfile, threading
pj = os.path.join
ARXIV_CACHE_DIR = get_conf("ARXIV_CACHE_DIR")
@@ -138,25 +138,43 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
- url_tar = url_.replace('/abs/', '/e-print/')
- translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
- os.makedirs(translation_dir, exist_ok=True)
-
- # <-------------- download arxiv source file ------------->
+ translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
dst = pj(translation_dir, arxiv_id + '.tar')
- if os.path.exists(dst):
- yield from update_ui_lastest_msg("调用缓存", chatbot=chatbot, history=history) # 刷新界面
+ os.makedirs(translation_dir, exist_ok=True)
+ # <-------------- download arxiv source file ------------->
+
+ def fix_url_and_download():
+ # for url_tar in [url_.replace('/abs/', '/e-print/'), url_.replace('/abs/', '/src/')]:
+ for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
+ proxies = get_conf('proxies')
+ r = requests.get(url_tar, proxies=proxies)
+ if r.status_code == 200:
+ with open(dst, 'wb+') as f:
+ f.write(r.content)
+ return True
+ return False
+
+ if os.path.exists(dst) and allow_cache:
+ yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
+ success = True
else:
- yield from update_ui_lastest_msg("开始下载", chatbot=chatbot, history=history) # 刷新界面
- proxies = get_conf('proxies')
- r = requests.get(url_tar, proxies=proxies)
- with open(dst, 'wb+') as f:
- f.write(r.content)
+ yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
+ success = fix_url_and_download()
+ yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
+
+
+ if not success:
+ yield from update_ui_lastest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
+ raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
+
# <-------------- extract file ------------->
- yield from update_ui_lastest_msg("下载完成", chatbot=chatbot, history=history) # 刷新界面
from toolbox import extract_archive
- extract_archive(file_path=dst, dest_dir=extract_dst)
+ try:
+ extract_archive(file_path=dst, dest_dir=extract_dst)
+ except tarfile.ReadError:
+ os.remove(dst)
+ raise tarfile.ReadError(f"论文下载失败")
return extract_dst, arxiv_id
@@ -320,11 +338,17 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
# <-------------- more requirements ------------->
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
more_req = plugin_kwargs.get("advanced_arg", "")
- no_cache = more_req.startswith("--no-cache")
- if no_cache: more_req.lstrip("--no-cache")
+
+ no_cache = ("--no-cache" in more_req)
+ if no_cache: more_req = more_req.replace("--no-cache", "").strip()
+
+ allow_gptac_cloud_io = ("--allow-cloudio" in more_req) # 从云端下载翻译结果,以及上传翻译结果到云端
+ if allow_gptac_cloud_io: more_req = more_req.replace("--allow-cloudio", "").strip()
+
allow_cache = not no_cache
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
+
# <-------------- check deps ------------->
try:
import glob, os, time, subprocess
@@ -351,6 +375,20 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
+ # #################################################################
+ if allow_gptac_cloud_io and arxiv_id:
+ # 访问 GPTAC学术云,查询云端是否存在该论文的翻译版本
+ from crazy_functions.latex_fns.latex_actions import check_gptac_cloud
+ success, downloaded = check_gptac_cloud(arxiv_id, chatbot)
+ if success:
+ chatbot.append([
+ f"检测到GPTAC云端存在翻译版本, 如果不满意翻译结果, 请禁用云端分享, 然后重新执行。",
+ None
+ ])
+ yield from update_ui(chatbot=chatbot, history=history)
+ return
+ #################################################################
+
if os.path.exists(txt):
project_folder = txt
else:
@@ -388,14 +426,21 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
# <-------------- zip PDF ------------->
zip_res = zip_result(project_folder)
if success:
+ if allow_gptac_cloud_io and arxiv_id:
+ # 如果用户允许,我们将翻译好的arxiv论文PDF上传到GPTAC学术云
+ from crazy_functions.latex_fns.latex_actions import upload_to_gptac_cloud_if_user_allow
+ threading.Thread(target=upload_to_gptac_cloud_if_user_allow,
+ args=(chatbot, arxiv_id), daemon=True).start()
+
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
- yield from update_ui(chatbot=chatbot, history=history);
+ yield from update_ui(chatbot=chatbot, history=history)
time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
+
else:
chatbot.append((f"失败了",
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体(见Github wiki) ...'))
- yield from update_ui(chatbot=chatbot, history=history);
+ yield from update_ui(chatbot=chatbot, history=history)
time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
diff --git a/crazy_functions/Latex_Function_Wrap.py b/crazy_functions/Latex_Function_Wrap.py
index 5d7b1f31..cef56965 100644
--- a/crazy_functions/Latex_Function_Wrap.py
+++ b/crazy_functions/Latex_Function_Wrap.py
@@ -30,6 +30,8 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
"allow_cache":
ArgProperty(title="是否允许从缓存中调取结果", options=["允许缓存", "从头执行"], default_value="允许缓存", description="无", type="dropdown").model_dump_json(),
+ "allow_cloudio":
+ ArgProperty(title="是否允许从GPTAC学术云下载(或者上传)翻译结果(仅针对Arxiv论文)", options=["允许", "禁止"], default_value="禁止", description="共享文献,互助互利", type="dropdown").model_dump_json(),
}
return gui_definition
@@ -38,9 +40,14 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
执行插件
"""
allow_cache = plugin_kwargs["allow_cache"]
+ allow_cloudio = plugin_kwargs["allow_cloudio"]
advanced_arg = plugin_kwargs["advanced_arg"]
if allow_cache == "从头执行": plugin_kwargs["advanced_arg"] = "--no-cache " + plugin_kwargs["advanced_arg"]
+
+ # 从云端下载翻译结果,以及上传翻译结果到云端;人人为我,我为人人。
+ if allow_cloudio == "允许": plugin_kwargs["advanced_arg"] = "--allow-cloudio " + plugin_kwargs["advanced_arg"]
+
yield from Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
diff --git a/crazy_functions/Markdown_Translate.py b/crazy_functions/Markdown_Translate.py
index 858d13da..45d4c712 100644
--- a/crazy_functions/Markdown_Translate.py
+++ b/crazy_functions/Markdown_Translate.py
@@ -65,7 +65,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
pfg.file_contents.append(file_content)
# <-------- 拆分过长的Markdown文件 ---------->
- pfg.run_file_split(max_token_limit=2048)
+ pfg.run_file_split(max_token_limit=1024)
n_split = len(pfg.sp_file_contents)
# <-------- 多线程翻译开始 ---------->
diff --git a/crazy_functions/SourceCode_Comment.py b/crazy_functions/SourceCode_Comment.py
index 20390800..9d9969ab 100644
--- a/crazy_functions/SourceCode_Comment.py
+++ b/crazy_functions/SourceCode_Comment.py
@@ -6,7 +6,10 @@ from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_ver
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
from crazy_functions.diagram_fns.file_tree import FileNode
+from crazy_functions.agent_fns.watchdog import WatchDog
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
+from loguru import logger
+
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
@@ -24,12 +27,13 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
file_tree_struct.add_file(file_path, file_path)
# <第一步,逐个文件分析,多线程>
+ lang = "" if not plugin_kwargs["use_chinese"] else " (you must use Chinese)"
for index, fp in enumerate(file_manifest):
# 读取文件
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
file_content = f.read()
prefix = ""
- i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence, the code is:\n```{file_content}```'
+ i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence{lang}, the code is:\n```{file_content}```'
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
# 装载请求内容
MAX_TOKEN_SINGLE_FILE = 2560
@@ -37,7 +41,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
- sys_prompt_array.append("You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear.")
+ sys_prompt_array.append(f"You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear{lang}.")
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array = inputs_array,
@@ -50,10 +54,20 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
)
# <第二步,逐个文件分析,生成带注释文件>
+ tasks = ["" for _ in range(len(file_manifest))]
+ def bark_fn(tasks):
+ for i in range(len(tasks)): tasks[i] = "watchdog is dead"
+ wd = WatchDog(timeout=10, bark_fn=lambda: bark_fn(tasks), interval=3, msg="ThreadWatcher timeout")
+ wd.begin_watch()
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
- def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct):
- pcc = PythonCodeComment(llm_kwargs, language='English')
+ def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct, index):
+ language = 'Chinese' if plugin_kwargs["use_chinese"] else 'English'
+ def observe_window_update(x):
+ if tasks[index] == "watchdog is dead":
+ raise TimeoutError("ThreadWatcher: watchdog is dead")
+ tasks[index] = x
+ pcc = PythonCodeComment(llm_kwargs, plugin_kwargs, language=language, observe_window_update=observe_window_update)
pcc.read_file(path=fp, brief=gpt_say)
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
file_tree_struct.manifest[fp].revised_path = revised_path
@@ -65,7 +79,8 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
html_template = f.read()
warp = lambda x: "```python\n\n" + x + "\n\n```"
- from themes.theme import advanced_css
+ from themes.theme import load_dynamic_theme
+ _, advanced_css, _, _ = load_dynamic_theme("Default")
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
@@ -73,17 +88,21 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
file_tree_struct.manifest[fp].compare_html = compare_html_path
with open(compare_html_path, 'w', encoding='utf-8') as f:
f.write(html_template)
- # print('done 1')
+ tasks[index] = ""
chatbot.append([None, f"正在处理:"])
futures = []
+ index = 0
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
- future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct)
+ future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct, index)
+ index += 1
futures.append(future)
+ # <第三步,等待任务完成>
cnt = 0
while True:
cnt += 1
+ wd.feed()
time.sleep(3)
worker_done = [h.done() for h in futures]
remain = len(worker_done) - sum(worker_done)
@@ -92,14 +111,18 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
preview_html_list = []
for done, fp in zip(worker_done, file_manifest):
if not done: continue
- preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
+ if hasattr(file_tree_struct.manifest[fp], 'compare_html'):
+ preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
+ else:
+ logger.error(f"文件: {fp} 的注释结果未能成功")
file_links = generate_file_link(preview_html_list)
yield from update_ui_lastest_msg(
- f"剩余源文件数量: {remain}.\n\n" +
- f"已完成的文件: {sum(worker_done)}.\n\n" +
+            f"当前任务: <br/>{'<br/>'.join(tasks)}.<br/>" +
+            f"剩余源文件数量: {remain}.<br/>" +
+            f"已完成的文件: {sum(worker_done)}.<br/>" +
file_links +
- "\n\n" +
+            "<br/>" +
''.join(['.']*(cnt % 10 + 1)
), chatbot=chatbot, history=history, delay=0)
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
@@ -120,6 +143,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
@CatchException
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
history = [] # 清空历史,以免输入溢出
+ plugin_kwargs["use_chinese"] = plugin_kwargs.get("use_chinese", False)
import glob, os
if os.path.exists(txt):
project_folder = txt
diff --git a/crazy_functions/SourceCode_Comment_Wrap.py b/crazy_functions/SourceCode_Comment_Wrap.py
new file mode 100644
index 00000000..b7425526
--- /dev/null
+++ b/crazy_functions/SourceCode_Comment_Wrap.py
@@ -0,0 +1,36 @@
+
+from toolbox import get_conf, update_ui
+from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
+from crazy_functions.SourceCode_Comment import 注释Python项目
+
+class SourceCodeComment_Wrap(GptAcademicPluginTemplate):
+ def __init__(self):
+ """
+ 请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
+ """
+ pass
+
+ def define_arg_selection_menu(self):
+ """
+ 定义插件的二级选项菜单
+ """
+ gui_definition = {
+ "main_input":
+ ArgProperty(title="路径", description="程序路径(上传文件后自动填写)", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
+ "use_chinese":
+ ArgProperty(title="注释语言", options=["英文", "中文"], default_value="英文", description="无", type="dropdown").model_dump_json(),
+ # "use_emoji":
+ # ArgProperty(title="在注释中使用emoji", options=["禁止", "允许"], default_value="禁止", description="无", type="dropdown").model_dump_json(),
+ }
+ return gui_definition
+
+ def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+ """
+ 执行插件
+ """
+ if plugin_kwargs["use_chinese"] == "中文":
+ plugin_kwargs["use_chinese"] = True
+ else:
+ plugin_kwargs["use_chinese"] = False
+
+ yield from 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
diff --git a/crazy_functions/agent_fns/python_comment_agent.py b/crazy_functions/agent_fns/python_comment_agent.py
index dd4b6ce8..9f19b17e 100644
--- a/crazy_functions/agent_fns/python_comment_agent.py
+++ b/crazy_functions/agent_fns/python_comment_agent.py
@@ -68,6 +68,7 @@ Be aware:
1. You must NOT modify the indent of code.
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
+4. Besides adding a docstring, use the ⭐ symbol to annotate the most core and important line of code within the function, explaining its role.
------------------ Example ------------------
INPUT:
@@ -116,10 +117,66 @@ def zip_result(folder):
'''
+revise_funtion_prompt_chinese = '''
+您需要阅读以下代码,并根据以下说明修订源代码({FILE_BASENAME}):
+1. 如果源代码中包含函数的话, 你应该分析给定函数实现了什么功能
+2. 如果源代码中包含函数的话, 你需要为函数添加docstring, docstring必须使用中文
+
+请注意:
+1. 你不得修改代码的缩进
+2. 你无权更改或翻译代码中的非注释部分,也不允许添加空行
+3. 使用 {LANG} 添加注释和文档字符串。不要翻译代码中已有的中文
+4. 除了添加docstring之外, 使用⭐符号给该函数中最核心、最重要的一行代码添加注释,并说明其作用
+
+------------------ 示例 ------------------
+INPUT:
+```
+L0000 |
+L0001 |def zip_result(folder):
+L0002 | t = gen_time_str()
+L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
+L0004 | return os.path.join(get_log_folder(), f"result.zip")
+L0005 |
+L0006 |
+```
+
+OUTPUT:
+
+
+该函数用于压缩指定文件夹,并返回生成的`zip`文件的路径。
+
+
+```
+def zip_result(folder):
+ """
+ 该函数将指定的文件夹压缩成ZIP文件, 并将其存储在日志文件夹中。
+
+ 输入参数:
+ folder (str): 需要压缩的文件夹的路径。
+ 返回值:
+ str: 日志文件夹中创建的ZIP文件的路径。
+ """
+ t = gen_time_str()
+ zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ 执行文件夹的压缩
+ return os.path.join(get_log_folder(), f"result.zip")
+```
+
+------------------ End of Example ------------------
+
+
+------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
+```
+{THE_CODE}
+```
+{INDENT_REMINDER}
+{BRIEF_REMINDER}
+{HINT_REMINDER}
+'''
+
class PythonCodeComment():
- def __init__(self, llm_kwargs, language) -> None:
+ def __init__(self, llm_kwargs, plugin_kwargs, language, observe_window_update) -> None:
self.original_content = ""
self.full_context = []
self.full_context_with_line_no = []
@@ -127,7 +184,13 @@ class PythonCodeComment():
self.page_limit = 100 # 100 lines of code each page
self.ignore_limit = 20
self.llm_kwargs = llm_kwargs
+ self.plugin_kwargs = plugin_kwargs
self.language = language
+ self.observe_window_update = observe_window_update
+ if self.language == "chinese":
+ self.core_prompt = revise_funtion_prompt_chinese
+ else:
+ self.core_prompt = revise_funtion_prompt
self.path = None
self.file_basename = None
self.file_brief = ""
@@ -258,7 +321,7 @@ class PythonCodeComment():
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
self.llm_kwargs['temperature'] = 0
result = predict_no_ui_long_connection(
- inputs=revise_funtion_prompt.format(
+ inputs=self.core_prompt.format(
LANG=self.language,
FILE_BASENAME=self.file_basename,
THE_CODE=code,
@@ -348,6 +411,7 @@ class PythonCodeComment():
try:
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
next_batch, line_no_start, line_no_end = self.get_next_batch()
+ self.observe_window_update(f"正在处理{self.file_basename} - {line_no_start}/{len(self.full_context)}\n")
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
hint = None
diff --git a/crazy_functions/ast_fns/comment_remove.py b/crazy_functions/ast_fns/comment_remove.py
index 1c482afd..b37c90e0 100644
--- a/crazy_functions/ast_fns/comment_remove.py
+++ b/crazy_functions/ast_fns/comment_remove.py
@@ -1,39 +1,47 @@
-import ast
+import token
+import tokenize
+import copy
+import io
-class CommentRemover(ast.NodeTransformer):
- def visit_FunctionDef(self, node):
- # 移除函数的文档字符串
- if (node.body and isinstance(node.body[0], ast.Expr) and
- isinstance(node.body[0].value, ast.Str)):
- node.body = node.body[1:]
- self.generic_visit(node)
- return node
- def visit_ClassDef(self, node):
- # 移除类的文档字符串
- if (node.body and isinstance(node.body[0], ast.Expr) and
- isinstance(node.body[0].value, ast.Str)):
- node.body = node.body[1:]
- self.generic_visit(node)
- return node
+def remove_python_comments(input_source: str) -> str:
+ source_flag = copy.copy(input_source)
+ source = io.StringIO(input_source)
+ ls = input_source.split('\n')
+ prev_toktype = token.INDENT
+ readline = source.readline
- def visit_Module(self, node):
- # 移除模块的文档字符串
- if (node.body and isinstance(node.body[0], ast.Expr) and
- isinstance(node.body[0].value, ast.Str)):
- node.body = node.body[1:]
- self.generic_visit(node)
- return node
-
+ def get_char_index(lineno, col):
+ # find the index of the char in the source code
+ if lineno == 1:
+ return len('\n'.join(ls[:(lineno-1)])) + col
+ else:
+ return len('\n'.join(ls[:(lineno-1)])) + col + 1
+
+ def replace_char_between(start_lineno, start_col, end_lineno, end_col, source, replace_char, ls):
+ # replace char between start_lineno, start_col and end_lineno, end_col with replace_char, but keep '\n' and ' '
+ b = get_char_index(start_lineno, start_col)
+ e = get_char_index(end_lineno, end_col)
+ for i in range(b, e):
+ if source[i] == '\n':
+ source = source[:i] + '\n' + source[i+1:]
+ elif source[i] == ' ':
+ source = source[:i] + ' ' + source[i+1:]
+ else:
+ source = source[:i] + replace_char + source[i+1:]
+ return source
+
+ tokgen = tokenize.generate_tokens(readline)
+ for toktype, ttext, (slineno, scol), (elineno, ecol), ltext in tokgen:
+ if toktype == token.STRING and (prev_toktype == token.INDENT):
+ source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
+ elif toktype == token.STRING and (prev_toktype == token.NEWLINE):
+ source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
+ elif toktype == tokenize.COMMENT:
+ source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
+ prev_toktype = toktype
+ return source_flag
-def remove_python_comments(source_code):
- # 解析源代码为 AST
- tree = ast.parse(source_code)
- # 移除注释
- transformer = CommentRemover()
- tree = transformer.visit(tree)
- # 将处理后的 AST 转换回源代码
- return ast.unparse(tree)
# 示例使用
if __name__ == "__main__":
diff --git a/crazy_functions/latex_fns/latex_actions.py b/crazy_functions/latex_fns/latex_actions.py
index 4293f0d0..b7dee4ec 100644
--- a/crazy_functions/latex_fns/latex_actions.py
+++ b/crazy_functions/latex_fns/latex_actions.py
@@ -3,7 +3,7 @@ import re
import shutil
import numpy as np
from loguru import logger
-from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
+from toolbox import update_ui, update_ui_lastest_msg, get_log_folder, gen_time_str
from toolbox import get_conf, promote_file_to_downloadzone
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
@@ -468,3 +468,70 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
except:
from toolbox import trimmed_format_exc
logger.error('writing html result failed:', trimmed_format_exc())
+
+
+def upload_to_gptac_cloud_if_user_allow(chatbot, arxiv_id):
+ try:
+ # 如果用户允许,我们将arxiv论文PDF上传到GPTAC学术云
+ from toolbox import map_file_to_sha256
+ # 检查是否顺利,如果没有生成预期的文件,则跳过
+ is_result_good = False
+ for file_path in chatbot._cookies.get("files_to_promote", []):
+ if file_path.endswith('translate_zh.pdf'):
+ is_result_good = True
+ if not is_result_good:
+ return
+ # 上传文件
+ for file_path in chatbot._cookies.get("files_to_promote", []):
+ align_name = None
+ # normalized name
+ for name in ['translate_zh.pdf', 'comparison.pdf']:
+ if file_path.endswith(name): align_name = name
+ # if match any align name
+ if align_name:
+ logger.info(f'Uploading to GPTAC cloud as the user has set `allow_cloud_io`: {file_path}')
+ with open(file_path, 'rb') as f:
+ import requests
+ url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_upload'
+ files = {'file': (align_name, f, 'application/octet-stream')}
+ data = {
+ 'arxiv_id': arxiv_id,
+ 'file_hash': map_file_to_sha256(file_path),
+ 'language': 'zh',
+ 'trans_prompt': 'to_be_implemented',
+ 'llm_model': 'to_be_implemented',
+ 'llm_model_param': 'to_be_implemented',
+ }
+ resp = requests.post(url=url, files=files, data=data, timeout=30)
+ logger.info(f'Uploading terminate ({resp.status_code})`: {file_path}')
+ except:
+ # 如果上传失败,不会中断程序,因为这是次要功能
+ pass
+
+def check_gptac_cloud(arxiv_id, chatbot):
+ import requests
+ success = False
+ downloaded = []
+ try:
+ for pdf_target in ['translate_zh.pdf', 'comparison.pdf']:
+ url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_exist'
+ data = {
+ 'arxiv_id': arxiv_id,
+ 'name': pdf_target,
+ }
+ resp = requests.post(url=url, data=data)
+ cache_hit_result = resp.text.strip('"')
+ if cache_hit_result.startswith("http"):
+ url = cache_hit_result
+ logger.info(f'Downloading from GPTAC cloud: {url}')
+ resp = requests.get(url=url, timeout=30)
+ target = os.path.join(get_log_folder(plugin_name='gptac_cloud'), gen_time_str(), pdf_target)
+ os.makedirs(os.path.dirname(target), exist_ok=True)
+ with open(target, 'wb') as f:
+ f.write(resp.content)
+ new_path = promote_file_to_downloadzone(target, chatbot=chatbot)
+ success = True
+ downloaded.append(new_path)
+ except:
+ pass
+ return success, downloaded
diff --git a/crazy_functions/latex_fns/latex_pickle_io.py b/crazy_functions/latex_fns/latex_pickle_io.py
index 7b93ea87..d951bf58 100644
--- a/crazy_functions/latex_fns/latex_pickle_io.py
+++ b/crazy_functions/latex_fns/latex_pickle_io.py
@@ -6,12 +6,16 @@ class SafeUnpickler(pickle.Unpickler):
def get_safe_classes(self):
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
+ from numpy.core.multiarray import scalar
+ from numpy import dtype
# 定义允许的安全类
safe_classes = {
# 在这里添加其他安全的类
'LatexPaperFileGroup': LatexPaperFileGroup,
'LatexPaperSplit': LatexPaperSplit,
'LinkedListNode': LinkedListNode,
+ 'scalar': scalar,
+ 'dtype': dtype,
}
return safe_classes
@@ -22,8 +26,6 @@ class SafeUnpickler(pickle.Unpickler):
for class_name in self.safe_classes.keys():
if (class_name in f'{module}.{name}'):
match_class_name = class_name
- if module == 'numpy' or module.startswith('numpy.'):
- return super().find_class(module, name)
if match_class_name is not None:
return self.safe_classes[match_class_name]
# 如果尝试加载未授权的类,则抛出异常
diff --git a/crazy_functions/latex_fns/latex_toolbox.py b/crazy_functions/latex_fns/latex_toolbox.py
index a49ffc4e..3a42243a 100644
--- a/crazy_functions/latex_fns/latex_toolbox.py
+++ b/crazy_functions/latex_fns/latex_toolbox.py
@@ -697,15 +697,6 @@ def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
),
0,
)
- if "/Annots" in page1:
- page1_annot_id = [annot.idnum for annot in page1["/Annots"]]
- else:
- page1_annot_id = []
-
- if "/Annots" in page2:
- page2_annot_id = [annot.idnum for annot in page2["/Annots"]]
- else:
- page2_annot_id = []
if "/Annots" in new_page:
annotations = new_page["/Annots"]
for i, annot in enumerate(annotations):
@@ -720,114 +711,148 @@ def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
if "/S" in action and action["/S"] == "/GoTo":
# 内部链接:跳转到文档中的某个页面
dest = action.get("/D") # 目标页或目标位置
- if dest and annot.idnum in page2_annot_id:
- # 获取原始文件中跳转信息,包括跳转页面
- destination = pdf2_reader.named_destinations[
- dest
- ]
- page_number = (
- pdf2_reader.get_destination_page_number(
- destination
- )
- )
- # 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
- # “/D”:[10,'/XYZ',100,100,0]
- annot_obj["/A"].update(
- {
- NameObject("/D"): ArrayObject(
- [
- NumberObject(page_number),
- destination.dest_array[1],
- FloatObject(
- destination.dest_array[2]
- + int(
- page1.mediaBox.getWidth()
- )
- ),
- destination.dest_array[3],
- destination.dest_array[4],
- ]
- ) # 确保键和值是 PdfObject
- }
- )
- rect = annot_obj.get("/Rect")
- # 更新点击坐标
- rect = ArrayObject(
- [
- FloatObject(
- rect[0]
- + int(page1.mediaBox.getWidth())
- ),
- rect[1],
- FloatObject(
- rect[2]
- + int(page1.mediaBox.getWidth())
- ),
- rect[3],
+ # if dest and annot.idnum in page2_annot_id:
+ # if dest in pdf2_reader.named_destinations:
+ if dest and page2.annotations:
+ if annot in page2.annotations:
+ # 获取原始文件中跳转信息,包括跳转页面
+ destination = pdf2_reader.named_destinations[
+ dest
]
- )
- annot_obj.update(
- {
- NameObject(
- "/Rect"
- ): rect # 确保键和值是 PdfObject
- }
- )
- if dest and annot.idnum in page1_annot_id:
- # 获取原始文件中跳转信息,包括跳转页面
- destination = pdf1_reader.named_destinations[
- dest
- ]
- page_number = (
- pdf1_reader.get_destination_page_number(
- destination
+ page_number = (
+ pdf2_reader.get_destination_page_number(
+ destination
+ )
)
- )
- # 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
- # “/D”:[10,'/XYZ',100,100,0]
- annot_obj["/A"].update(
- {
- NameObject("/D"): ArrayObject(
- [
- NumberObject(page_number),
- destination.dest_array[1],
- FloatObject(
- destination.dest_array[2]
- ),
- destination.dest_array[3],
- destination.dest_array[4],
- ]
- ) # 确保键和值是 PdfObject
- }
- )
- rect = annot_obj.get("/Rect")
- rect = ArrayObject(
- [
- FloatObject(rect[0]),
- rect[1],
- FloatObject(rect[2]),
- rect[3],
+                                            # 更新跳转信息,跳转到对应的页面,指定坐标 (100, 150),缩放比例为 100%
+ # “/D”:[10,'/XYZ',100,100,0]
+ if destination.dest_array[1] == "/XYZ":
+ annot_obj["/A"].update(
+ {
+ NameObject("/D"): ArrayObject(
+ [
+ NumberObject(page_number),
+ destination.dest_array[1],
+ FloatObject(
+ destination.dest_array[
+ 2
+ ]
+ + int(
+ page1.mediaBox.getWidth()
+ )
+ ),
+ destination.dest_array[3],
+ destination.dest_array[4],
+ ]
+ ) # 确保键和值是 PdfObject
+ }
+ )
+ else:
+ annot_obj["/A"].update(
+ {
+ NameObject("/D"): ArrayObject(
+ [
+ NumberObject(page_number),
+ destination.dest_array[1],
+ ]
+ ) # 确保键和值是 PdfObject
+ }
+ )
+
+ rect = annot_obj.get("/Rect")
+ # 更新点击坐标
+ rect = ArrayObject(
+ [
+ FloatObject(
+ rect[0]
+ + int(page1.mediaBox.getWidth())
+ ),
+ rect[1],
+ FloatObject(
+ rect[2]
+ + int(page1.mediaBox.getWidth())
+ ),
+ rect[3],
+ ]
+ )
+ annot_obj.update(
+ {
+ NameObject(
+ "/Rect"
+ ): rect # 确保键和值是 PdfObject
+ }
+ )
+ # if dest and annot.idnum in page1_annot_id:
+ # if dest in pdf1_reader.named_destinations:
+ if dest and page1.annotations:
+ if annot in page1.annotations:
+ # 获取原始文件中跳转信息,包括跳转页面
+ destination = pdf1_reader.named_destinations[
+ dest
]
- )
- annot_obj.update(
- {
- NameObject(
- "/Rect"
- ): rect # 确保键和值是 PdfObject
- }
- )
+ page_number = (
+ pdf1_reader.get_destination_page_number(
+ destination
+ )
+ )
+ # 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
+                                            # 更新跳转信息,跳转到对应的页面,指定坐标 (100, 150),缩放比例为 100%
+ if destination.dest_array[1] == "/XYZ":
+ annot_obj["/A"].update(
+ {
+ NameObject("/D"): ArrayObject(
+ [
+ NumberObject(page_number),
+ destination.dest_array[1],
+ FloatObject(
+ destination.dest_array[
+ 2
+ ]
+ ),
+ destination.dest_array[3],
+ destination.dest_array[4],
+ ]
+ ) # 确保键和值是 PdfObject
+ }
+ )
+ else:
+ annot_obj["/A"].update(
+ {
+ NameObject("/D"): ArrayObject(
+ [
+ NumberObject(page_number),
+ destination.dest_array[1],
+ ]
+ ) # 确保键和值是 PdfObject
+ }
+ )
+
+ rect = annot_obj.get("/Rect")
+ rect = ArrayObject(
+ [
+ FloatObject(rect[0]),
+ rect[1],
+ FloatObject(rect[2]),
+ rect[3],
+ ]
+ )
+ annot_obj.update(
+ {
+ NameObject(
+ "/Rect"
+ ): rect # 确保键和值是 PdfObject
+ }
+ )
elif "/S" in action and action["/S"] == "/URI":
# 外部链接:跳转到某个URI
uri = action.get("/URI")
-
output_writer.addPage(new_page)
# Save the merged PDF file
with open(output_path, "wb") as output_file:
output_writer.write(output_file)
-
def _merge_pdfs_legacy(pdf1_path, pdf2_path, output_path):
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
diff --git a/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py b/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py
index d64aa91c..97c62fbf 100644
--- a/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py
+++ b/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py
@@ -4,7 +4,9 @@ from toolbox import promote_file_to_downloadzone, extract_archive
from toolbox import generate_file_link, zip_folder
from crazy_functions.crazy_utils import get_files_from_everything
from shared_utils.colorful import *
+from loguru import logger
import os
+import time
def refresh_key(doc2x_api_key):
import requests, json
@@ -22,105 +24,140 @@ def refresh_key(doc2x_api_key):
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
return doc2x_api_key
+
+
def 解析PDF_DOC2X_转Latex(pdf_file_path):
+ zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format='tex')
+ return unzipped_folder
+
+
+def 解析PDF_DOC2X(pdf_file_path, format='tex'):
+ """
+ format: 'tex', 'md', 'docx'
+ """
import requests, json, os
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
+ markdown_dir = get_log_folder(plugin_name="pdf_ocr")
doc2x_api_key = DOC2X_API_KEY
- if doc2x_api_key.startswith('sk-'):
- url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
- else:
- doc2x_api_key = refresh_key(doc2x_api_key)
- url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
+
+ # < ------ 第1步:上传 ------ >
+ logger.info("Doc2x 第1步:上传")
+ with open(pdf_file_path, 'rb') as file:
+ res = requests.post(
+ "https://v2.doc2x.noedgeai.com/api/v2/parse/pdf",
+ headers={"Authorization": "Bearer " + doc2x_api_key},
+ data=file
+ )
+ # res_json = []
+ if res.status_code == 200:
+ res_json = res.json()
+ else:
+ raise RuntimeError(f"Doc2x return an error: {res.json()}")
+ uuid = res_json['data']['uid']
+
+ # < ------ 第2步:轮询等待 ------ >
+ logger.info("Doc2x 第2步:轮询等待")
+ params = {'uid': uuid}
+ while True:
+ res = requests.get(
+ 'https://v2.doc2x.noedgeai.com/api/v2/parse/status',
+ headers={"Authorization": "Bearer " + doc2x_api_key},
+ params=params
+ )
+ res_json = res.json()
+ if res_json['data']['status'] == "success":
+ break
+ elif res_json['data']['status'] == "processing":
+ time.sleep(3)
+ logger.info(f"Doc2x is processing at {res_json['data']['progress']}%")
+ elif res_json['data']['status'] == "failed":
+ raise RuntimeError(f"Doc2x return an error: {res_json}")
+
+
+ # < ------ 第3步:提交转化 ------ >
+ logger.info("Doc2x 第3步:提交转化")
+ data = {
+ "uid": uuid,
+ "to": format,
+ "formula_mode": "dollar",
+ "filename": "output"
+ }
res = requests.post(
- url,
- files={"file": open(pdf_file_path, "rb")},
- data={"ocr": "1"},
- headers={"Authorization": "Bearer " + doc2x_api_key}
+ 'https://v2.doc2x.noedgeai.com/api/v2/convert/parse',
+ headers={"Authorization": "Bearer " + doc2x_api_key},
+ json=data
)
- res_json = []
if res.status_code == 200:
- decoded = res.content.decode("utf-8")
- for z_decoded in decoded.split('\n'):
- if len(z_decoded) == 0: continue
- assert z_decoded.startswith("data: ")
- z_decoded = z_decoded[len("data: "):]
- decoded_json = json.loads(z_decoded)
- res_json.append(decoded_json)
+ res_json = res.json()
else:
- raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
+ raise RuntimeError(f"Doc2x return an error: {res.json()}")
- uuid = res_json[0]['uuid']
- to = "latex" # latex, md, docx
- url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
- res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
- latex_zip_path = os.path.join(latex_dir, gen_time_str() + '.zip')
- latex_unzip_path = os.path.join(latex_dir, gen_time_str())
- if res.status_code == 200:
- with open(latex_zip_path, "wb") as f: f.write(res.content)
- else:
- raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
+ # < ------ 第4步:等待结果 ------ >
+ logger.info("Doc2x 第4步:等待结果")
+ params = {'uid': uuid}
+ while True:
+ res = requests.get(
+ 'https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result',
+ headers={"Authorization": "Bearer " + doc2x_api_key},
+ params=params
+ )
+ res_json = res.json()
+ if res_json['data']['status'] == "success":
+ break
+ elif res_json['data']['status'] == "processing":
+ time.sleep(3)
+ logger.info(f"Doc2x still processing")
+ elif res_json['data']['status'] == "failed":
+ raise RuntimeError(f"Doc2x return an error: {res_json}")
+
+ # < ------ 第5步:最后的处理 ------ >
+ logger.info("Doc2x 第5步:最后的处理")
+
+ if format=='tex':
+ target_path = latex_dir
+ if format=='md':
+ target_path = markdown_dir
+ os.makedirs(target_path, exist_ok=True)
+
+ max_attempt = 3
+ # < ------ 下载 ------ >
+ for attempt in range(max_attempt):
+ try:
+ result_url = res_json['data']['url']
+ res = requests.get(result_url)
+ zip_path = os.path.join(target_path, gen_time_str() + '.zip')
+ unzip_path = os.path.join(target_path, gen_time_str())
+ if res.status_code == 200:
+ with open(zip_path, "wb") as f: f.write(res.content)
+ else:
+ raise RuntimeError(f"Doc2x return an error: {res.json()}")
+ except Exception as e:
+ if attempt < max_attempt - 1:
+ logger.error(f"Failed to download latex file, retrying... {e}")
+ time.sleep(3)
+ continue
+ else:
+ raise e
+
+ # < ------ 解压 ------ >
import zipfile
- with zipfile.ZipFile(latex_zip_path, 'r') as zip_ref:
- zip_ref.extractall(latex_unzip_path)
-
-
- return latex_unzip_path
-
-
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(unzip_path)
+ return zip_path, unzip_path
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
-
def pdf2markdown(filepath):
- import requests, json, os
- markdown_dir = get_log_folder(plugin_name="pdf_ocr")
- doc2x_api_key = DOC2X_API_KEY
- if doc2x_api_key.startswith('sk-'):
- url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
- else:
- doc2x_api_key = refresh_key(doc2x_api_key)
- url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
-
- chatbot.append((None, "加载PDF文件,发送至DOC2X解析..."))
+ chatbot.append((None, f"Doc2x 解析中"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
- res = requests.post(
- url,
- files={"file": open(filepath, "rb")},
- data={"ocr": "1"},
- headers={"Authorization": "Bearer " + doc2x_api_key}
- )
- res_json = []
- if res.status_code == 200:
- decoded = res.content.decode("utf-8")
- for z_decoded in decoded.split('\n'):
- if len(z_decoded) == 0: continue
- assert z_decoded.startswith("data: ")
- z_decoded = z_decoded[len("data: "):]
- decoded_json = json.loads(z_decoded)
- res_json.append(decoded_json)
- if 'limit exceeded' in decoded_json.get('status', ''):
- raise RuntimeError("Doc2x API 页数受限,请联系 Doc2x 方面,并更换新的 API 秘钥。")
- else:
- raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
- uuid = res_json[0]['uuid']
- to = "md" # latex, md, docx
- url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
+ md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format='md')
- chatbot.append((None, f"读取解析: {url} ..."))
- yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
-
- res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
- md_zip_path = os.path.join(markdown_dir, gen_time_str() + '.zip')
- if res.status_code == 200:
- with open(md_zip_path, "wb") as f: f.write(res.content)
- else:
- raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
diff --git a/docker-compose.yml b/docker-compose.yml
index 06a35600..cd72e3af 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -180,6 +180,7 @@ version: '3'
services:
gpt_academic_with_latex:
image: ghcr.io/binary-husky/gpt_academic_with_latex:master # (Auto Built by Dockerfile: docs/GithubAction+NoLocal+Latex)
+ # 对于ARM64设备,请将以上镜像名称替换为 ghcr.io/binary-husky/gpt_academic_with_latex_arm:master
environment:
# 请查阅 `config.py` 以查看所有的配置信息
API_KEY: ' sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx '
diff --git a/docs/GithubAction+NoLocal+Latex b/docs/GithubAction+NoLocal+Latex
index 533c6e35..71d51796 100644
--- a/docs/GithubAction+NoLocal+Latex
+++ b/docs/GithubAction+NoLocal+Latex
@@ -1,4 +1,4 @@
-# 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
+# 此Dockerfile适用于"无本地模型"的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
# - 1 修改 `config.py`
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
@@ -7,15 +7,28 @@ FROM menghuan1918/ubuntu_uv_ctex:latest
ENV DEBIAN_FRONTEND=noninteractive
SHELL ["/bin/bash", "-c"]
WORKDIR /gpt
-COPY . .
-RUN /root/.cargo/bin/uv venv --seed \
- && source .venv/bin/activate \
- && /root/.cargo/bin/uv pip install openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
- && /root/.cargo/bin/uv pip install -r requirements.txt \
- && /root/.cargo/bin/uv clean
+
+# 先复制依赖文件
+COPY requirements.txt .
+
+# 安装依赖
+RUN pip install --break-system-packages openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
+ && pip install --break-system-packages -r requirements.txt \
+ && if [ "$(uname -m)" = "x86_64" ]; then \
+ pip install --break-system-packages nougat-ocr; \
+ fi \
+ && pip cache purge \
+ && rm -rf /root/.cache/pip/*
+
+# 创建非root用户
+RUN useradd -m gptuser && chown -R gptuser /gpt
+USER gptuser
+
+# 最后才复制代码文件,这样代码更新时只需重建最后几层,可以大幅减少docker pull所需的大小
+COPY --chown=gptuser:gptuser . .
# 可选步骤,用于预热模块
-RUN .venv/bin/python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
+RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
# 启动
-CMD [".venv/bin/python3", "-u", "main.py"]
+CMD ["python3", "-u", "main.py"]
diff --git a/request_llms/bridge_all.py b/request_llms/bridge_all.py
index 2355485f..3eaa96ac 100644
--- a/request_llms/bridge_all.py
+++ b/request_llms/bridge_all.py
@@ -385,6 +385,14 @@ model_info = {
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
+ "glm-4-plus":{
+ "fn_with_ui": zhipu_ui,
+ "fn_without_ui": zhipu_noui,
+ "endpoint": None,
+ "max_token": 10124 * 8,
+ "tokenizer": tokenizer_gpt35,
+ "token_cnt": get_token_num_gpt35,
+ },
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
"api2d-gpt-4": {
@@ -1285,4 +1293,3 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot,
# 更新一下llm_kwargs的参数,否则会出现参数不匹配的问题
yield from method(inputs, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, stream, additional_fn)
-
diff --git a/request_llms/bridge_chatgpt.py b/request_llms/bridge_chatgpt.py
index d4cf1ef5..9e719a4e 100644
--- a/request_llms/bridge_chatgpt.py
+++ b/request_llms/bridge_chatgpt.py
@@ -202,10 +202,13 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
if (time.time()-observe_window[1]) > watch_dog_patience:
raise RuntimeError("用户取消了程序。")
else: raise RuntimeError("意外Json结构:"+delta)
- if json_data and json_data['finish_reason'] == 'content_filter':
- raise RuntimeError("由于提问含不合规内容被Azure过滤。")
- if json_data and json_data['finish_reason'] == 'length':
+
+ finish_reason = json_data.get('finish_reason', None) if json_data else None
+ if finish_reason == 'content_filter':
+ raise RuntimeError("由于提问含不合规内容被过滤。")
+ if finish_reason == 'length':
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
+
return result
@@ -536,4 +539,3 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
return headers,payload
-
diff --git a/shared_utils/fastapi_server.py b/shared_utils/fastapi_server.py
index 6c9b1d1c..2993c987 100644
--- a/shared_utils/fastapi_server.py
+++ b/shared_utils/fastapi_server.py
@@ -138,7 +138,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
app_block.is_sagemaker = False
gradio_app = App.create_app(app_block)
-
+ for route in list(gradio_app.router.routes):
+ if route.path == "/proxy={url_path:path}":
+ gradio_app.router.routes.remove(route)
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
if len(AUTHENTICATION) > 0:
dependencies = []
@@ -154,9 +156,13 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
async def file(path_or_url: str, request: fastapi.Request):
- if len(AUTHENTICATION) > 0:
- if not _authorize_user(path_or_url, request, gradio_app):
- return "越权访问!"
+ if not _authorize_user(path_or_url, request, gradio_app):
+ return "越权访问!"
+ stripped = path_or_url.lstrip().lower()
+ if stripped.startswith("https://") or stripped.startswith("http://"):
+ return "账户密码授权模式下, 禁止链接!"
+ if '../' in stripped:
+ return "非法路径!"
return await endpoint(path_or_url, request)
from fastapi import Request, status
@@ -167,6 +173,26 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
response.delete_cookie('access-token')
response.delete_cookie('access-token-unsecure')
return response
+ else:
+ dependencies = []
+ endpoint = None
+ for route in list(gradio_app.router.routes):
+ if route.path == "/file/{path:path}":
+ gradio_app.router.routes.remove(route)
+ if route.path == "/file={path_or_url:path}":
+ dependencies = route.dependencies
+ endpoint = route.endpoint
+ gradio_app.router.routes.remove(route)
+ @gradio_app.get("/file/{path:path}", dependencies=dependencies)
+ @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
+ @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
+ async def file(path_or_url: str, request: fastapi.Request):
+ stripped = path_or_url.lstrip().lower()
+ if stripped.startswith("https://") or stripped.startswith("http://"):
+ return "账户密码授权模式下, 禁止链接!"
+ if '../' in stripped:
+ return "非法路径!"
+ return await endpoint(path_or_url, request)
# --- --- enable TTS (text-to-speech) functionality --- ---
TTS_TYPE = get_conf("TTS_TYPE")
diff --git a/tests/test_doc2x.py b/tests/test_doc2x.py
new file mode 100644
index 00000000..9d02c4b7
--- /dev/null
+++ b/tests/test_doc2x.py
@@ -0,0 +1,7 @@
+import init_test
+
+from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
+
+# 解析PDF_DOC2X_转Latex("gpt_log/arxiv_cache_old/2410.10819/workfolder/merge.pdf")
+# 解析PDF_DOC2X_转Latex("gpt_log/arxiv_cache_ooo/2410.07095/workfolder/merge.pdf")
+解析PDF_DOC2X_转Latex("2410.11190v2.pdf")