镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 14:36:48 +00:00
比较提交
90 次代码提交
boyin_essa
...
boyin_rag
| 作者 | SHA1 | 提交日期 | |
|---|---|---|---|
|
|
2f946f3e6c | ||
|
|
51ea7f3b5e | ||
|
|
795a6a9333 | ||
|
|
3beb22a347 | ||
|
|
b3aef6b393 | ||
|
|
cf51d4b205 | ||
|
|
bd9c88e896 | ||
|
|
27958b9030 | ||
|
|
9b9d77eded | ||
|
|
50dbff3a14 | ||
|
|
e1dc600030 | ||
|
|
6557c3822a | ||
|
|
81ab9f91a4 | ||
|
|
241c9641bb | ||
|
|
b2d6536974 | ||
|
|
12be7c16e9 | ||
|
|
724940a9d8 | ||
|
|
ea4cd95645 | ||
|
|
f8b60870e9 | ||
|
|
cbef9a908c | ||
|
|
21626a44d5 | ||
|
|
dd902e9519 | ||
|
|
68aa846a89 | ||
|
|
b8617921f4 | ||
|
|
c6687646e4 | ||
|
|
bfa72fb4cf | ||
|
|
61676d0536 | ||
|
|
df2ef7940c | ||
|
|
0afd27deca | ||
|
|
91f5e6b8f7 | ||
|
|
c10f2b45e5 | ||
|
|
7e2ede2d12 | ||
|
|
ec10e2a3ac | ||
|
|
4f0851f703 | ||
|
|
2821f27756 | ||
|
|
7474d43433 | ||
|
|
83489f9acf | ||
|
|
36e50d490d | ||
|
|
9172337695 | ||
|
|
180550b8f0 | ||
|
|
7497dcb852 | ||
|
|
5dab7b2290 | ||
|
|
23ef2ffb22 | ||
|
|
848d0f65c7 | ||
|
|
f0b0364f74 | ||
|
|
89dc6c7265 | ||
|
|
69f3755682 | ||
|
|
4727113243 | ||
|
|
21111d3bd0 | ||
|
|
be9aead04a | ||
|
|
701018f48c | ||
|
|
8733c4e1e9 | ||
|
|
8498ddf6bf | ||
|
|
3c3293818d | ||
|
|
310122f5a7 | ||
|
|
9adc0ade71 | ||
|
|
bbcdd9aa71 | ||
|
|
cdfe38d296 | ||
|
|
159f628dfe | ||
|
|
5888d038aa | ||
|
|
ee8213e936 | ||
|
|
a57dcbcaeb | ||
|
|
b812392a9d | ||
|
|
fce4fa1ec7 | ||
|
|
d13f1e270c | ||
|
|
85cf3d08eb | ||
|
|
584e747565 | ||
|
|
02ba653c19 | ||
|
|
2d12b5b27d | ||
|
|
a4bcd262f9 | ||
|
|
748e31102f | ||
|
|
97eef45ab7 | ||
|
|
0c0e2acb9b | ||
|
|
9fba8e0142 | ||
|
|
7d7867fb64 | ||
|
|
7ea791d83a | ||
|
|
f9dbaa39fb | ||
|
|
bbc2288c5b | ||
|
|
64ab916838 | ||
|
|
8fe559da9f | ||
|
|
09fd22091a | ||
|
|
df717f8bba | ||
|
|
e296719b23 | ||
|
|
2f343179a2 | ||
|
|
4d9604f2e9 | ||
|
|
bbf9e9f868 | ||
|
|
aa1f967dd7 | ||
|
|
0d082327c8 | ||
|
|
80acd9c875 | ||
|
|
17cd4f8210 |
2
.github/workflows/build-with-latex-arm.yml
vendored
2
.github/workflows/build-with-latex-arm.yml
vendored
@@ -46,6 +46,6 @@ jobs:
|
|||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
platforms: linux/arm64
|
platforms: linux/arm64
|
||||||
file: docs/GithubAction+NoLocal+Latex+Arm
|
file: docs/GithubAction+NoLocal+Latex
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -160,4 +160,5 @@ test.*
|
|||||||
temp.*
|
temp.*
|
||||||
objdump*
|
objdump*
|
||||||
*.min.*.js
|
*.min.*.js
|
||||||
TODO
|
TODO
|
||||||
|
*.cursorrules
|
||||||
|
|||||||
@@ -1,24 +1,36 @@
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
def check_proxy(proxies, return_ip=False):
|
def check_proxy(proxies, return_ip=False):
|
||||||
|
"""
|
||||||
|
检查代理配置并返回结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxies (dict): 包含http和https代理配置的字典。
|
||||||
|
return_ip (bool, optional): 是否返回代理的IP地址。默认为False。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or None: 检查的结果信息或代理的IP地址(如果`return_ip`为True)。
|
||||||
|
"""
|
||||||
import requests
|
import requests
|
||||||
proxies_https = proxies['https'] if proxies is not None else '无'
|
proxies_https = proxies['https'] if proxies is not None else '无'
|
||||||
ip = None
|
ip = None
|
||||||
try:
|
try:
|
||||||
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
|
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4) # ⭐ 执行GET请求以获取代理信息
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if 'country_name' in data:
|
if 'country_name' in data:
|
||||||
country = data['country_name']
|
country = data['country_name']
|
||||||
result = f"代理配置 {proxies_https}, 代理所在地:{country}"
|
result = f"代理配置 {proxies_https}, 代理所在地:{country}"
|
||||||
if 'ip' in data: ip = data['ip']
|
if 'ip' in data:
|
||||||
|
ip = data['ip']
|
||||||
elif 'error' in data:
|
elif 'error' in data:
|
||||||
alternative, ip = _check_with_backup_source(proxies)
|
alternative, ip = _check_with_backup_source(proxies) # ⭐ 调用备用方法检查代理配置
|
||||||
if alternative is None:
|
if alternative is None:
|
||||||
result = f"代理配置 {proxies_https}, 代理所在地:未知,IP查询频率受限"
|
result = f"代理配置 {proxies_https}, 代理所在地:未知,IP查询频率受限"
|
||||||
else:
|
else:
|
||||||
result = f"代理配置 {proxies_https}, 代理所在地:{alternative}"
|
result = f"代理配置 {proxies_https}, 代理所在地:{alternative}"
|
||||||
else:
|
else:
|
||||||
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
|
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
|
||||||
|
|
||||||
if not return_ip:
|
if not return_ip:
|
||||||
logger.warning(result)
|
logger.warning(result)
|
||||||
return result
|
return result
|
||||||
@@ -33,17 +45,33 @@ def check_proxy(proxies, return_ip=False):
|
|||||||
return ip
|
return ip
|
||||||
|
|
||||||
def _check_with_backup_source(proxies):
|
def _check_with_backup_source(proxies):
|
||||||
|
"""
|
||||||
|
通过备份源检查代理,并获取相应信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxies (dict): 包含代理信息的字典。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: 代理信息(geo)和IP地址(ip)的元组。
|
||||||
|
"""
|
||||||
import random, string, requests
|
import random, string, requests
|
||||||
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
|
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
|
||||||
try:
|
try:
|
||||||
res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json()
|
res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json() # ⭐ 执行代理检查和备份源请求
|
||||||
return res_json['dns']['geo'], res_json['dns']['ip']
|
return res_json['dns']['geo'], res_json['dns']['ip']
|
||||||
except:
|
except:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def backup_and_download(current_version, remote_version):
|
def backup_and_download(current_version, remote_version):
|
||||||
"""
|
"""
|
||||||
一键更新协议:备份和下载
|
一键更新协议:备份当前版本,下载远程版本并解压缩。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_version (str): 当前版本号。
|
||||||
|
remote_version (str): 远程版本号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 新版本目录的路径。
|
||||||
"""
|
"""
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
import shutil
|
import shutil
|
||||||
@@ -60,7 +88,7 @@ def backup_and_download(current_version, remote_version):
|
|||||||
proxies = get_conf('proxies')
|
proxies = get_conf('proxies')
|
||||||
try: r = requests.get('https://github.com/binary-husky/chatgpt_academic/archive/refs/heads/master.zip', proxies=proxies, stream=True)
|
try: r = requests.get('https://github.com/binary-husky/chatgpt_academic/archive/refs/heads/master.zip', proxies=proxies, stream=True)
|
||||||
except: r = requests.get('https://public.agent-matrix.com/publish/master.zip', proxies=proxies, stream=True)
|
except: r = requests.get('https://public.agent-matrix.com/publish/master.zip', proxies=proxies, stream=True)
|
||||||
zip_file_path = backup_dir+'/master.zip'
|
zip_file_path = backup_dir+'/master.zip' # ⭐ 保存备份文件的路径
|
||||||
with open(zip_file_path, 'wb+') as f:
|
with open(zip_file_path, 'wb+') as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
dst_path = new_version_dir
|
dst_path = new_version_dir
|
||||||
@@ -76,6 +104,17 @@ def backup_and_download(current_version, remote_version):
|
|||||||
def patch_and_restart(path):
|
def patch_and_restart(path):
|
||||||
"""
|
"""
|
||||||
一键更新协议:覆盖和重启
|
一键更新协议:覆盖和重启
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): 新版本代码所在的路径
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
如果您的程序没有使用config_private.py私密配置文件,则会将config.py重命名为config_private.py以避免配置丢失。
|
||||||
|
|
||||||
|
更新流程:
|
||||||
|
- 复制最新版本代码到当前目录
|
||||||
|
- 更新pip包依赖
|
||||||
|
- 如果更新失败,则提示手动安装依赖库并重启
|
||||||
"""
|
"""
|
||||||
from distutils import dir_util
|
from distutils import dir_util
|
||||||
import shutil
|
import shutil
|
||||||
@@ -84,32 +123,43 @@ def patch_and_restart(path):
|
|||||||
import time
|
import time
|
||||||
import glob
|
import glob
|
||||||
from shared_utils.colorful import log亮黄, log亮绿, log亮红
|
from shared_utils.colorful import log亮黄, log亮绿, log亮红
|
||||||
# if not using config_private, move origin config.py as config_private.py
|
|
||||||
if not os.path.exists('config_private.py'):
|
if not os.path.exists('config_private.py'):
|
||||||
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
||||||
'另外您可以随时在history子文件夹下找回旧版的程序。')
|
'另外您可以随时在history子文件夹下找回旧版的程序。')
|
||||||
shutil.copyfile('config.py', 'config_private.py')
|
shutil.copyfile('config.py', 'config_private.py')
|
||||||
|
|
||||||
path_new_version = glob.glob(path + '/*-master')[0]
|
path_new_version = glob.glob(path + '/*-master')[0]
|
||||||
dir_util.copy_tree(path_new_version, './')
|
dir_util.copy_tree(path_new_version, './') # ⭐ 将最新版本代码复制到当前目录
|
||||||
|
|
||||||
log亮绿('代码已经更新,即将更新pip包依赖……')
|
log亮绿('代码已经更新,即将更新pip包依赖……')
|
||||||
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
|
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import subprocess
|
import subprocess
|
||||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
|
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
|
||||||
except:
|
except:
|
||||||
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||||
|
|
||||||
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
||||||
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||||
log亮绿(' ------------------------------ -----------------------------------')
|
log亮绿(' ------------------------------ -----------------------------------')
|
||||||
|
|
||||||
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
|
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
|
||||||
os.execl(sys.executable, sys.executable, *sys.argv)
|
os.execl(sys.executable, sys.executable, *sys.argv) # 重启程序
|
||||||
|
|
||||||
|
|
||||||
def get_current_version():
|
def get_current_version():
|
||||||
|
"""
|
||||||
|
获取当前的版本号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 当前的版本号。如果无法获取版本号,则返回空字符串。
|
||||||
|
"""
|
||||||
import json
|
import json
|
||||||
try:
|
try:
|
||||||
with open('./version', 'r', encoding='utf8') as f:
|
with open('./version', 'r', encoding='utf8') as f:
|
||||||
current_version = json.loads(f.read())['version']
|
current_version = json.loads(f.read())['version'] # ⭐ 从读取的json数据中提取版本号
|
||||||
except:
|
except:
|
||||||
current_version = ""
|
current_version = ""
|
||||||
return current_version
|
return current_version
|
||||||
@@ -118,6 +168,12 @@ def get_current_version():
|
|||||||
def auto_update(raise_error=False):
|
def auto_update(raise_error=False):
|
||||||
"""
|
"""
|
||||||
一键更新协议:查询版本和用户意见
|
一键更新协议:查询版本和用户意见
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raise_error (bool, optional): 是否在出错时抛出错误。默认为 False。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
@@ -137,13 +193,13 @@ def auto_update(raise_error=False):
|
|||||||
current_version = json.loads(current_version)['version']
|
current_version = json.loads(current_version)['version']
|
||||||
if (remote_version - current_version) >= 0.01-1e-5:
|
if (remote_version - current_version) >= 0.01-1e-5:
|
||||||
from shared_utils.colorful import log亮黄
|
from shared_utils.colorful import log亮黄
|
||||||
log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
|
log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}') # ⭐ 在控制台打印新版本信息
|
||||||
logger.info('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
logger.info('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
||||||
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
|
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
|
||||||
if user_instruction in ['Y', 'y']:
|
if user_instruction in ['Y', 'y']:
|
||||||
path = backup_and_download(current_version, remote_version)
|
path = backup_and_download(current_version, remote_version) # ⭐ 备份并下载文件
|
||||||
try:
|
try:
|
||||||
patch_and_restart(path)
|
patch_and_restart(path) # ⭐ 执行覆盖并重启操作
|
||||||
except:
|
except:
|
||||||
msg = '更新失败。'
|
msg = '更新失败。'
|
||||||
if raise_error:
|
if raise_error:
|
||||||
@@ -163,6 +219,9 @@ def auto_update(raise_error=False):
|
|||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
def warm_up_modules():
|
def warm_up_modules():
|
||||||
|
"""
|
||||||
|
预热模块,加载特定模块并执行预热操作。
|
||||||
|
"""
|
||||||
logger.info('正在执行一些模块的预热 ...')
|
logger.info('正在执行一些模块的预热 ...')
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
@@ -173,6 +232,16 @@ def warm_up_modules():
|
|||||||
enc.encode("模块预热", disallowed_special=())
|
enc.encode("模块预热", disallowed_special=())
|
||||||
|
|
||||||
def warm_up_vectordb():
|
def warm_up_vectordb():
|
||||||
|
"""
|
||||||
|
执行一些模块的预热操作。
|
||||||
|
|
||||||
|
本函数主要用于执行一些模块的预热操作,确保在后续的流程中能够顺利运行。
|
||||||
|
|
||||||
|
⭐ 关键作用:预热模块
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
logger.info('正在执行一些模块的预热 ...')
|
logger.info('正在执行一些模块的预热 ...')
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
with ProxyNetworkActivate("Warmup_Modules"):
|
with ProxyNetworkActivate("Warmup_Modules"):
|
||||||
@@ -185,4 +254,4 @@ if __name__ == '__main__':
|
|||||||
os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
|
os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
proxies = get_conf('proxies')
|
proxies = get_conf('proxies')
|
||||||
check_proxy(proxies)
|
check_proxy(proxies)
|
||||||
@@ -15,13 +15,13 @@ def get_crazy_functions():
|
|||||||
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
|
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
|
||||||
|
from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
|
||||||
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
||||||
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
||||||
from crazy_functions.Latex全文润色 import Latex英文润色
|
from crazy_functions.Latex全文润色 import Latex英文润色
|
||||||
from crazy_functions.询问多个大语言模型 import 同时问询
|
from crazy_functions.询问多个大语言模型 import 同时问询
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
||||||
from crazy_functions.总结word文档 import 总结word文档
|
|
||||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||||
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
||||||
from crazy_functions.Conversation_To_File import 对话历史存档
|
from crazy_functions.Conversation_To_File import 对话历史存档
|
||||||
@@ -31,6 +31,8 @@ def get_crazy_functions():
|
|||||||
from crazy_functions.Markdown_Translate import Markdown英译中
|
from crazy_functions.Markdown_Translate import Markdown英译中
|
||||||
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
||||||
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
||||||
|
from crazy_functions.批量文件询问 import 批量文件询问
|
||||||
|
|
||||||
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
|
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
|
||||||
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
|
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
|
||||||
from crazy_functions.Latex全文润色 import Latex中文润色
|
from crazy_functions.Latex全文润色 import Latex中文润色
|
||||||
@@ -49,6 +51,7 @@ def get_crazy_functions():
|
|||||||
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
||||||
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
|
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
|
||||||
from crazy_functions.SourceCode_Comment import 注释Python项目
|
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||||
|
from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
|
||||||
|
|
||||||
function_plugins = {
|
function_plugins = {
|
||||||
"虚空终端": {
|
"虚空终端": {
|
||||||
@@ -58,33 +61,6 @@ def get_crazy_functions():
|
|||||||
"Info": "使用自然语言实现您的想法",
|
"Info": "使用自然语言实现您的想法",
|
||||||
"Function": HotReload(虚空终端),
|
"Function": HotReload(虚空终端),
|
||||||
},
|
},
|
||||||
"解析整个Python项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": True,
|
|
||||||
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Python项目),
|
|
||||||
},
|
|
||||||
"注释Python项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
|
|
||||||
"Function": HotReload(注释Python项目),
|
|
||||||
},
|
|
||||||
"载入对话历史存档(先上传存档或输入路径)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info": "载入对话历史存档 | 输入参数为路径",
|
|
||||||
"Function": HotReload(载入对话历史存档),
|
|
||||||
},
|
|
||||||
"删除所有本地对话历史记录(谨慎操作)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info": "删除所有本地对话历史记录,谨慎操作 | 不需要输入参数",
|
|
||||||
"Function": HotReload(删除所有本地对话历史记录),
|
|
||||||
},
|
|
||||||
"清除所有缓存文件(谨慎操作)": {
|
"清除所有缓存文件(谨慎操作)": {
|
||||||
"Group": "对话",
|
"Group": "对话",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
@@ -92,14 +68,6 @@ def get_crazy_functions():
|
|||||||
"Info": "清除所有缓存文件,谨慎操作 | 不需要输入参数",
|
"Info": "清除所有缓存文件,谨慎操作 | 不需要输入参数",
|
||||||
"Function": HotReload(清除缓存),
|
"Function": HotReload(清除缓存),
|
||||||
},
|
},
|
||||||
"生成多种Mermaid图表(从当前对话或路径(.pdf/.md/.docx)中生产图表)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info" : "基于当前对话或文件生成多种Mermaid图表,图表类型由模型判断",
|
|
||||||
"Function": None,
|
|
||||||
"Class": Mermaid_Gen
|
|
||||||
},
|
|
||||||
"Arxiv论文翻译": {
|
"Arxiv论文翻译": {
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
@@ -108,91 +76,25 @@ def get_crazy_functions():
|
|||||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||||
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
||||||
},
|
},
|
||||||
"批量总结Word文档": {
|
"批量文件询问": {
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": False,
|
"AsButton": False,
|
||||||
"Info": "批量总结word文档 | 输入参数为路径",
|
"AdvancedArgs": True,
|
||||||
"Function": HotReload(总结word文档),
|
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
|
||||||
|
"ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
|
||||||
|
r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“请总结下面的内容:”,此时,文档内容将添加在“:”后 ",
|
||||||
|
"Function": HotReload(批量文件询问),
|
||||||
},
|
},
|
||||||
"解析整个Matlab项目": {
|
"Arxiv论文对话": {
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info": "解析一个Matlab项目的所有源文件(.m) | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Matlab项目),
|
|
||||||
},
|
|
||||||
"解析整个C++项目头文件": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个C++项目的所有头文件(.h/.hpp) | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个C项目的头文件),
|
|
||||||
},
|
|
||||||
"解析整个C++项目(.cpp/.hpp/.c/.h)": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个C++项目的所有源文件(.cpp/.hpp/.c/.h)| 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个C项目),
|
|
||||||
},
|
|
||||||
"解析整个Go项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个Go项目的所有源文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Golang项目),
|
|
||||||
},
|
|
||||||
"解析整个Rust项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个Rust项目的所有源文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Rust项目),
|
|
||||||
},
|
|
||||||
"解析整个Java项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个Java项目的所有源文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Java项目),
|
|
||||||
},
|
|
||||||
"解析整个前端项目(js,ts,css等)": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个前端项目的所有源文件(js,ts,css等) | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个前端项目),
|
|
||||||
},
|
|
||||||
"解析整个Lua项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个Lua项目的所有源文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个Lua项目),
|
|
||||||
},
|
|
||||||
"解析整个CSharp项目": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False, # 加入下拉菜单中
|
|
||||||
"Info": "解析一个CSharp项目的所有源文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析一个CSharp项目),
|
|
||||||
},
|
|
||||||
"解析Jupyter Notebook文件": {
|
|
||||||
"Group": "编程",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"Info": "解析Jupyter Notebook文件 | 输入参数为路径",
|
|
||||||
"Function": HotReload(解析ipynb文件),
|
|
||||||
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
|
||||||
"ArgsReminder": "若输入0,则不解析notebook中的Markdown块", # 高级参数输入区的显示提示
|
|
||||||
},
|
|
||||||
"读Tex论文写摘要": {
|
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": False,
|
"AsButton": False,
|
||||||
"Info": "读取Tex论文并写摘要 | 输入参数为路径",
|
"AdvancedArgs": True,
|
||||||
"Function": HotReload(读文章写摘要),
|
"Info": "在输入区中输入论文ID,在高级参数区中输入问题",
|
||||||
|
"ArgsReminder": r"1、请在输入区中输入arxiv ID。 "
|
||||||
|
r"2、请在下方高级参数区中输入你的问题,示例:“这篇文章的方法是什么,请用中文回答我” ",
|
||||||
|
"Function": HotReload(Arxiv论文对话),
|
||||||
},
|
},
|
||||||
"翻译README或MD": {
|
"翻译README或MD": {
|
||||||
"Group": "编程",
|
"Group": "编程",
|
||||||
|
|||||||
573
crazy_functions/Arxiv_论文对话.py
普通文件
573
crazy_functions/Arxiv_论文对话.py
普通文件
@@ -0,0 +1,573 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock as ThreadLock
|
||||||
|
from typing import Generator
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
|
||||||
|
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||||
|
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
|
||||||
|
|
||||||
|
# 全局常量配置
|
||||||
|
MAX_HISTORY_ROUND = 5 # 最大历史对话轮数
|
||||||
|
MAX_CONTEXT_TOKEN_LIMIT = 4096 # 上下文最大token数
|
||||||
|
REMEMBER_PREVIEW = 1000 # 记忆预览长度
|
||||||
|
VECTOR_STORE_TYPE = "Simple" # 向量存储类型:Simple或Milvus
|
||||||
|
MAX_CONCURRENT_PAPERS = 20 # 最大并行处理论文数
|
||||||
|
MAX_WORKERS = 3 # 最大工作线程数
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessingTask:
|
||||||
|
"""论文处理任务数据类"""
|
||||||
|
arxiv_id: str
|
||||||
|
status: str = "pending" # pending, processing, completed, failed
|
||||||
|
error: Optional[str] = None
|
||||||
|
fragments: List[Fragment] = None
|
||||||
|
start_time: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivRagWorker:
|
||||||
|
def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
|
||||||
|
"""初始化ArxivRagWorker"""
|
||||||
|
self.user_name = user_name
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
|
||||||
|
self.fragments = None
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化基础目录
|
||||||
|
self.base_dir = Path(get_log_folder( plugin_name='arxiv_rag_cache'))
|
||||||
|
self._setup_directories()
|
||||||
|
|
||||||
|
# 初始化处理状态
|
||||||
|
|
||||||
|
# 线程安全的计数器和集合
|
||||||
|
self._processing_lock = ThreadLock()
|
||||||
|
self._processed_fragments = set()
|
||||||
|
self._processed_count = 0
|
||||||
|
# 优化的线程池配置
|
||||||
|
cpu_count = os.cpu_count() or 1
|
||||||
|
self.thread_pool = ThreadPoolExecutor(
|
||||||
|
max_workers=min(32, cpu_count * 4),
|
||||||
|
thread_name_prefix="arxiv_worker"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 批处理配置
|
||||||
|
self._batch_size = min(20, cpu_count * 2) # 动态设置批大小
|
||||||
|
self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
|
||||||
|
self._semaphore = None
|
||||||
|
self._loop = None
|
||||||
|
|
||||||
|
# 初始化处理队列
|
||||||
|
self.processing_queue = {}
|
||||||
|
|
||||||
|
# 初始化工作组件
|
||||||
|
self._init_workers()
|
||||||
|
|
||||||
|
def _setup_directories(self):
|
||||||
|
"""设置工作目录"""
|
||||||
|
|
||||||
|
if self.arxiv_id:
|
||||||
|
self.checkpoint_dir = self.base_dir / self.arxiv_id
|
||||||
|
self.vector_store_dir = self.checkpoint_dir / "vector_store"
|
||||||
|
self.fragment_store_dir = self.checkpoint_dir / "fragments"
|
||||||
|
else:
|
||||||
|
self.checkpoint_dir = self.base_dir
|
||||||
|
self.vector_store_dir = self.base_dir / "vector_store"
|
||||||
|
self.fragment_store_dir = self.base_dir / "fragments"
|
||||||
|
|
||||||
|
self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
|
||||||
|
self.loading = self.paper_path.exists()
|
||||||
|
# 创建必要的目录
|
||||||
|
for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
|
||||||
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"Created directory: {directory}")
|
||||||
|
|
||||||
|
def _init_workers(self):
|
||||||
|
"""初始化工作组件"""
|
||||||
|
try:
|
||||||
|
self.rag_worker = LlamaIndexRagWorker(
|
||||||
|
user_name=self.user_name,
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
checkpoint_dir=str(self.vector_store_dir),
|
||||||
|
auto_load_checkpoint=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.arxiv_splitter = ArxivSplitter(
|
||||||
|
root_dir=str(self.checkpoint_dir / "arxiv_cache")
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error initializing workers: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _ensure_loop(self):
|
||||||
|
"""确保存在事件循环"""
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
return self._loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def semaphore(self):
|
||||||
|
"""延迟创建semaphore"""
|
||||||
|
if self._semaphore is None:
|
||||||
|
self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
|
||||||
|
return self._semaphore
|
||||||
|
|
||||||
|
async def _process_fragments(self, fragments: List[Fragment]) -> None:
|
||||||
|
"""优化的并行处理论文片段"""
|
||||||
|
if not fragments:
|
||||||
|
logger.warning("No fragments to process")
|
||||||
|
return
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
total_fragments = len(fragments)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 处理论文概述
|
||||||
|
overview = self._create_overview(fragments[0])
|
||||||
|
overview_success = self._safe_add_to_vector_store_sync(overview['text'])
|
||||||
|
if not overview_success:
|
||||||
|
raise RuntimeError("Failed to add overview to vector store")
|
||||||
|
|
||||||
|
# 2. 并行处理片段
|
||||||
|
successful_fragments = await self._parallel_process_fragments(fragments)
|
||||||
|
|
||||||
|
# 3. 保存处理结果
|
||||||
|
if successful_fragments > 0:
|
||||||
|
await self._save_results(fragments, overview['arxiv_id'], successful_fragments)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in fragment processing: {str(e)}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self._log_processing_stats(start_time, total_fragments)
|
||||||
|
|
||||||
|
def _create_overview(self, first_fragment: Fragment) -> Dict:
|
||||||
|
"""创建论文概述"""
|
||||||
|
return {
|
||||||
|
'arxiv_id': first_fragment.arxiv_id,
|
||||||
|
'text': (
|
||||||
|
f"Paper Title: {first_fragment.title}\n"
|
||||||
|
f"ArXiv ID: {first_fragment.arxiv_id}\n"
|
||||||
|
f"Abstract: {first_fragment.abstract}\n"
|
||||||
|
f"Table of contents:{first_fragment.catalogs}\n"
|
||||||
|
f"Type: OVERVIEW"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
|
||||||
|
"""并行处理所有片段"""
|
||||||
|
successful_count = 0
|
||||||
|
loop = self._ensure_loop()
|
||||||
|
|
||||||
|
for i in range(0, len(fragments), self._batch_size):
|
||||||
|
batch = fragments[i:i + self._batch_size]
|
||||||
|
batch_futures = []
|
||||||
|
|
||||||
|
for j, fragment in enumerate(batch):
|
||||||
|
if not self._is_fragment_processed(fragment, i + j):
|
||||||
|
future = loop.run_in_executor(
|
||||||
|
self.thread_pool,
|
||||||
|
self._process_single_fragment_sync,
|
||||||
|
fragment,
|
||||||
|
i + j
|
||||||
|
)
|
||||||
|
batch_futures.append(future)
|
||||||
|
|
||||||
|
if batch_futures:
|
||||||
|
results = await asyncio.gather(*batch_futures, return_exceptions=True)
|
||||||
|
successful_count += sum(1 for r in results if isinstance(r, bool) and r)
|
||||||
|
|
||||||
|
return successful_count
|
||||||
|
|
||||||
|
def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
|
||||||
|
"""检查片段是否已处理"""
|
||||||
|
fragment_id = f"{fragment.arxiv_id}_{index}"
|
||||||
|
with self._processing_lock:
|
||||||
|
return fragment_id in self._processed_fragments
|
||||||
|
|
||||||
|
def _safe_add_to_vector_store_sync(self, text: str) -> bool:
|
||||||
|
"""线程安全的向量存储添加"""
|
||||||
|
with self._processing_lock:
|
||||||
|
try:
|
||||||
|
self.rag_worker.add_text_to_vector_store(text)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding to vector store: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
|
||||||
|
"""处理单个片段"""
|
||||||
|
fragment_id = f"{fragment.arxiv_id}_{index}"
|
||||||
|
try:
|
||||||
|
text = self._build_fragment_text(fragment)
|
||||||
|
if self._safe_add_to_vector_store_sync(text):
|
||||||
|
with self._processing_lock:
|
||||||
|
self._processed_fragments.add(fragment_id)
|
||||||
|
self._processed_count += 1
|
||||||
|
logger.info(f"Successfully processed fragment {index}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing fragment {index}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _build_fragment_text(self, fragment: Fragment) -> str:
|
||||||
|
"""构建片段文本"""
|
||||||
|
return "".join([
|
||||||
|
f"Paper Title: {fragment.title}\n",
|
||||||
|
f"Section: {fragment.current_section}\n",
|
||||||
|
f"Content: {fragment.content}\n",
|
||||||
|
f"Bibliography: {fragment.bibliography}\n",
|
||||||
|
"Type: FRAGMENT"
|
||||||
|
])
|
||||||
|
|
||||||
|
async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
|
||||||
|
"""保存处理结果"""
|
||||||
|
if successful_count > 0:
|
||||||
|
loop = self._ensure_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
self.thread_pool,
|
||||||
|
save_fragments_to_file,
|
||||||
|
fragments,
|
||||||
|
str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
|
||||||
|
"""记录处理统计信息"""
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
|
||||||
|
logger.info(
|
||||||
|
f"Processed {self._processed_count}/{total_fragments} fragments "
|
||||||
|
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process_paper(self, fragments: List[Fragment]) -> bool:
|
||||||
|
"""处理论文主函数"""
|
||||||
|
try:
|
||||||
|
|
||||||
|
if self.paper_path.exists():
|
||||||
|
logger.info(f"Paper {self.arxiv_id} already processed")
|
||||||
|
return True
|
||||||
|
|
||||||
|
task = self._create_processing_task(self.arxiv_id)
|
||||||
|
try:
|
||||||
|
async with self.semaphore:
|
||||||
|
await self._process_fragments(fragments)
|
||||||
|
self._complete_task(task, fragments, self.paper_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._fail_task(task, str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
|
||||||
|
"""创建处理任务"""
|
||||||
|
task = ProcessingTask(arxiv_id=arxiv_id)
|
||||||
|
with self._processing_lock:
|
||||||
|
self.processing_queue[arxiv_id] = task
|
||||||
|
task.status = "processing"
|
||||||
|
return task
|
||||||
|
|
||||||
|
def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
|
||||||
|
"""完成任务处理"""
|
||||||
|
with self._processing_lock:
|
||||||
|
task.status = "completed"
|
||||||
|
task.fragments = fragments
|
||||||
|
paper_path.touch()
|
||||||
|
logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")
|
||||||
|
|
||||||
|
def _fail_task(self, task: ProcessingTask, error: str) -> None:
|
||||||
|
"""任务失败处理"""
|
||||||
|
with self._processing_lock:
|
||||||
|
task.status = "failed"
|
||||||
|
task.error = error
|
||||||
|
|
||||||
|
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||||
|
"""规范化ArXiv ID"""
|
||||||
|
if not input_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
input_str = input_str.strip().lower()
|
||||||
|
if 'arxiv.org/' in input_str:
|
||||||
|
if '/pdf/' in input_str:
|
||||||
|
arxiv_id = input_str.split('/pdf/')[-1]
|
||||||
|
else:
|
||||||
|
arxiv_id = input_str.split('/abs/')[-1]
|
||||||
|
return arxiv_id.split('v')[0].strip()
|
||||||
|
return input_str.split('v')[0].strip()
|
||||||
|
|
||||||
|
async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
|
||||||
|
"""等待论文处理完成"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
with self._processing_lock:
|
||||||
|
task = self.processing_queue.get(arxiv_id)
|
||||||
|
if not task:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if task.status == "completed":
|
||||||
|
return True
|
||||||
|
if task.status == "failed":
|
||||||
|
return False
|
||||||
|
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
logger.error(f"Processing paper {arxiv_id} timed out")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def retrieve_and_generate(self, query: str) -> str:
|
||||||
|
"""检索相关内容并生成提示词"""
|
||||||
|
try:
|
||||||
|
nodes = self.rag_worker.retrieve_from_store_with_query(query)
|
||||||
|
return self.rag_worker.build_prompt(query=query, nodes=nodes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in retrieve and generate: {str(e)}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def remember_qa(self, question: str, answer: str) -> None:
|
||||||
|
"""记忆问答对"""
|
||||||
|
try:
|
||||||
|
self.rag_worker.remember_qa(question, answer)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error remembering QA: {str(e)}")
|
||||||
|
|
||||||
|
async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None:
|
||||||
|
"""自动分析论文的关键问题"""
|
||||||
|
key_questions = [
|
||||||
|
"What is the main research question or problem addressed in this paper?",
|
||||||
|
"What methods or approaches did the authors use to investigate the problem?",
|
||||||
|
"What are the key findings or results presented in the paper?",
|
||||||
|
"How do the findings of this paper contribute to the broader field or topic of study?",
|
||||||
|
"What are the limitations of this study, and what future research directions do the authors suggest?"
|
||||||
|
]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for question in key_questions:
|
||||||
|
try:
|
||||||
|
prompt = self.retrieve_and_generate(question)
|
||||||
|
if prompt:
|
||||||
|
response = await request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=prompt,
|
||||||
|
inputs_show_user=question,
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=history,
|
||||||
|
sys_prompt=system_prompt
|
||||||
|
)
|
||||||
|
results.append(f"Q: {question}\nA: {response}\n")
|
||||||
|
self.remember_qa(question, response)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in auto analysis: {str(e)}")
|
||||||
|
|
||||||
|
# 合并所有结果
|
||||||
|
summary = "\n\n".join(results)
|
||||||
|
chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。")
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||||
|
history: List, system_prompt: str, web_port: str) -> Generator:
|
||||||
|
"""
|
||||||
|
Arxiv论文对话主函数
|
||||||
|
Args:
|
||||||
|
txt: arxiv ID/URL
|
||||||
|
llm_kwargs: LLM配置参数
|
||||||
|
plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令
|
||||||
|
chatbot: 对话历史
|
||||||
|
history: 聊天历史
|
||||||
|
system_prompt: 系统提示词
|
||||||
|
web_port: Web端口
|
||||||
|
"""
|
||||||
|
# 初始化时,提示用户需要 arxiv ID/URL
|
||||||
|
from toolbox import promote_file_to_downloadzone
|
||||||
|
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
|
||||||
|
chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
user_name = chatbot.get_user()
|
||||||
|
arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
|
||||||
|
arxiv_id = arxiv_worker.arxiv_id
|
||||||
|
|
||||||
|
# 处理新论文的情况
|
||||||
|
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
|
||||||
|
chatbot.append((txt, "正在处理论文,请稍等..."))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
|
||||||
|
for file in fragment_output_files:
|
||||||
|
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||||
|
chatbot.append(["论文文字内容已保存至下载区,接下来将进行论文编码,请耐心等待三分钟,论文的文字内容为:", formatted_content])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
try:
|
||||||
|
# 创建新的事件循环
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 设置超时时间为5分钟
|
||||||
|
success = loop.run_until_complete(
|
||||||
|
asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
success = loop.run_until_complete(
|
||||||
|
asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
|
||||||
|
else:
|
||||||
|
chatbot[-1] = (txt, "论文处理超时,请重试。")
|
||||||
|
else:
|
||||||
|
chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
chatbot[-1] = (txt, "论文处理超时,请重试。")
|
||||||
|
success = False
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in main process: {str(e)}")
|
||||||
|
chatbot[-1] = (txt, f"处理过程中发生错误: {str(e)}")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
# 处理用户询问的情况
|
||||||
|
# 获取用户询问指令
|
||||||
|
user_query = plugin_kwargs.get("advanced_arg",
|
||||||
|
"What is the main research question or problem addressed in this paper?")
|
||||||
|
if len(history)<2:
|
||||||
|
fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
|
||||||
|
for file in fragment_output_files:
|
||||||
|
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||||
|
chatbot.append(["论文文字内容已保存至下载区,论文的文字内容为:", formatted_content])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
if not user_query:
|
||||||
|
user_query = "What is the main research question or problem addressed in this paper?"
|
||||||
|
# chatbot.append((txt, "请提供您的问题。"))
|
||||||
|
# yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
# return
|
||||||
|
|
||||||
|
# 处理历史对话长度
|
||||||
|
if len(history) > MAX_HISTORY_ROUND * 2:
|
||||||
|
history = history[-(MAX_HISTORY_ROUND * 2):]
|
||||||
|
|
||||||
|
# 处理询问指令
|
||||||
|
query_clip, history, flags = input_clipping(
|
||||||
|
user_query,
|
||||||
|
history,
|
||||||
|
max_token_limit=MAX_CONTEXT_TOKEN_LIMIT,
|
||||||
|
return_clip_flags=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if flags["original_input_len"] != flags["clipped_input_len"]:
|
||||||
|
yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
|
||||||
|
if len(user_query) > REMEMBER_PREVIEW:
|
||||||
|
HALF = REMEMBER_PREVIEW // 2
|
||||||
|
query_to_remember = user_query[
|
||||||
|
:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[
|
||||||
|
-HALF:]
|
||||||
|
else:
|
||||||
|
query_to_remember = query_clip
|
||||||
|
else:
|
||||||
|
query_to_remember = query_clip
|
||||||
|
|
||||||
|
chatbot.append((user_query, "正在思考中..."))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
# 生成提示词
|
||||||
|
prompt = arxiv_worker.retrieve_and_generate(query_clip)
|
||||||
|
if not prompt:
|
||||||
|
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取回答
|
||||||
|
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=prompt,
|
||||||
|
inputs_show_user=query_clip,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=history,
|
||||||
|
sys_prompt=system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记忆问答对
|
||||||
|
# worker.remember_qa(query_to_remember, response)
|
||||||
|
history.extend([user_query, response])
|
||||||
|
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
llm_kwargs = {
|
||||||
|
'api_key': os.getenv("one_api_key"),
|
||||||
|
'client_ip': '127.0.0.1',
|
||||||
|
'embed_model': 'text-embedding-3-small',
|
||||||
|
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
|
||||||
|
'max_length': 4096,
|
||||||
|
'most_recent_uploaded': None,
|
||||||
|
'temperature': 1,
|
||||||
|
'top_p': 1
|
||||||
|
}
|
||||||
|
plugin_kwargs = {}
|
||||||
|
chatbot = []
|
||||||
|
history = []
|
||||||
|
system_prompt = "You are a helpful assistant."
|
||||||
|
web_port = "8080"
|
||||||
|
|
||||||
|
# 测试论文导入
|
||||||
|
arxiv_url = "https://arxiv.org/abs/2312.12345"
|
||||||
|
for response in Arxiv论文对话(
|
||||||
|
arxiv_url, llm_kwargs, plugin_kwargs,
|
||||||
|
chatbot, history, system_prompt, web_port
|
||||||
|
):
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
# 测试问答
|
||||||
|
question = "这篇论文的主要贡献是什么?"
|
||||||
|
for response in Arxiv论文对话(
|
||||||
|
question, llm_kwargs, plugin_kwargs,
|
||||||
|
chatbot, history, system_prompt, web_port
|
||||||
|
):
|
||||||
|
print(response)
|
||||||
@@ -152,8 +152,6 @@ class Conversation_To_File_Wrap(GptAcademicPluginTemplate):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def hide_cwd(str):
|
def hide_cwd(str):
|
||||||
import os
|
import os
|
||||||
current_path = os.getcwd()
|
current_path = os.getcwd()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
import glob, os, requests, time, json, tarfile
|
import glob, os, requests, time, json, tarfile, threading
|
||||||
|
|
||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
ARXIV_CACHE_DIR = get_conf("ARXIV_CACHE_DIR")
|
ARXIV_CACHE_DIR = get_conf("ARXIV_CACHE_DIR")
|
||||||
@@ -338,11 +338,17 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
|||||||
# <-------------- more requirements ------------->
|
# <-------------- more requirements ------------->
|
||||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||||
more_req = plugin_kwargs.get("advanced_arg", "")
|
more_req = plugin_kwargs.get("advanced_arg", "")
|
||||||
no_cache = more_req.startswith("--no-cache")
|
|
||||||
if no_cache: more_req.lstrip("--no-cache")
|
no_cache = ("--no-cache" in more_req)
|
||||||
|
if no_cache: more_req = more_req.replace("--no-cache", "").strip()
|
||||||
|
|
||||||
|
allow_gptac_cloud_io = ("--allow-cloudio" in more_req) # 从云端下载翻译结果,以及上传翻译结果到云端
|
||||||
|
if allow_gptac_cloud_io: more_req = more_req.replace("--allow-cloudio", "").strip()
|
||||||
|
|
||||||
allow_cache = not no_cache
|
allow_cache = not no_cache
|
||||||
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
|
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
|
||||||
|
|
||||||
|
|
||||||
# <-------------- check deps ------------->
|
# <-------------- check deps ------------->
|
||||||
try:
|
try:
|
||||||
import glob, os, time, subprocess
|
import glob, os, time, subprocess
|
||||||
@@ -369,6 +375,20 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
|||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# #################################################################
|
||||||
|
if allow_gptac_cloud_io and arxiv_id:
|
||||||
|
# 访问 GPTAC学术云,查询云端是否存在该论文的翻译版本
|
||||||
|
from crazy_functions.latex_fns.latex_actions import check_gptac_cloud
|
||||||
|
success, downloaded = check_gptac_cloud(arxiv_id, chatbot)
|
||||||
|
if success:
|
||||||
|
chatbot.append([
|
||||||
|
f"检测到GPTAC云端存在翻译版本, 如果不满意翻译结果, 请禁用云端分享, 然后重新执行。",
|
||||||
|
None
|
||||||
|
])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
#################################################################
|
||||||
|
|
||||||
if os.path.exists(txt):
|
if os.path.exists(txt):
|
||||||
project_folder = txt
|
project_folder = txt
|
||||||
else:
|
else:
|
||||||
@@ -406,14 +426,21 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
|||||||
# <-------------- zip PDF ------------->
|
# <-------------- zip PDF ------------->
|
||||||
zip_res = zip_result(project_folder)
|
zip_res = zip_result(project_folder)
|
||||||
if success:
|
if success:
|
||||||
|
if allow_gptac_cloud_io and arxiv_id:
|
||||||
|
# 如果用户允许,我们将翻译好的arxiv论文PDF上传到GPTAC学术云
|
||||||
|
from crazy_functions.latex_fns.latex_actions import upload_to_gptac_cloud_if_user_allow
|
||||||
|
threading.Thread(target=upload_to_gptac_cloud_if_user_allow,
|
||||||
|
args=(chatbot, arxiv_id), daemon=True).start()
|
||||||
|
|
||||||
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
|
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
|
||||||
yield from update_ui(chatbot=chatbot, history=history);
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
time.sleep(1) # 刷新界面
|
time.sleep(1) # 刷新界面
|
||||||
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
chatbot.append((f"失败了",
|
chatbot.append((f"失败了",
|
||||||
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体(见Github wiki) ...'))
|
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体(见Github wiki) ...'))
|
||||||
yield from update_ui(chatbot=chatbot, history=history);
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
time.sleep(1) # 刷新界面
|
time.sleep(1) # 刷新界面
|
||||||
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
|
|||||||
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
|
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
|
||||||
"allow_cache":
|
"allow_cache":
|
||||||
ArgProperty(title="是否允许从缓存中调取结果", options=["允许缓存", "从头执行"], default_value="允许缓存", description="无", type="dropdown").model_dump_json(),
|
ArgProperty(title="是否允许从缓存中调取结果", options=["允许缓存", "从头执行"], default_value="允许缓存", description="无", type="dropdown").model_dump_json(),
|
||||||
|
"allow_cloudio":
|
||||||
|
ArgProperty(title="是否允许从GPTAC学术云下载(或者上传)翻译结果(仅针对Arxiv论文)", options=["允许", "禁止"], default_value="禁止", description="共享文献,互助互利", type="dropdown").model_dump_json(),
|
||||||
}
|
}
|
||||||
return gui_definition
|
return gui_definition
|
||||||
|
|
||||||
@@ -38,9 +40,14 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
|
|||||||
执行插件
|
执行插件
|
||||||
"""
|
"""
|
||||||
allow_cache = plugin_kwargs["allow_cache"]
|
allow_cache = plugin_kwargs["allow_cache"]
|
||||||
|
allow_cloudio = plugin_kwargs["allow_cloudio"]
|
||||||
advanced_arg = plugin_kwargs["advanced_arg"]
|
advanced_arg = plugin_kwargs["advanced_arg"]
|
||||||
|
|
||||||
if allow_cache == "从头执行": plugin_kwargs["advanced_arg"] = "--no-cache " + plugin_kwargs["advanced_arg"]
|
if allow_cache == "从头执行": plugin_kwargs["advanced_arg"] = "--no-cache " + plugin_kwargs["advanced_arg"]
|
||||||
|
|
||||||
|
# 从云端下载翻译结果,以及上传翻译结果到云端;人人为我,我为人人。
|
||||||
|
if allow_cloudio == "允许": plugin_kwargs["advanced_arg"] = "--allow-cloudio " + plugin_kwargs["advanced_arg"]
|
||||||
|
|
||||||
yield from Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
yield from Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
pfg.file_contents.append(file_content)
|
pfg.file_contents.append(file_content)
|
||||||
|
|
||||||
# <-------- 拆分过长的Markdown文件 ---------->
|
# <-------- 拆分过长的Markdown文件 ---------->
|
||||||
pfg.run_file_split(max_token_limit=2048)
|
pfg.run_file_split(max_token_limit=1024)
|
||||||
n_split = len(pfg.sp_file_contents)
|
n_split = len(pfg.sp_file_contents)
|
||||||
|
|
||||||
# <-------- 多线程翻译开始 ---------->
|
# <-------- 多线程翻译开始 ---------->
|
||||||
|
|||||||
@@ -1,3 +1,9 @@
|
|||||||
|
import os,glob
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from shared_utils.fastapi_server import validate_path_safety
|
||||||
|
|
||||||
|
from toolbox import report_exception
|
||||||
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||||
from crazy_functions.crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
@@ -7,6 +13,37 @@ MAX_HISTORY_ROUND = 5
|
|||||||
MAX_CONTEXT_TOKEN_LIMIT = 4096
|
MAX_CONTEXT_TOKEN_LIMIT = 4096
|
||||||
REMEMBER_PREVIEW = 1000
|
REMEMBER_PREVIEW = 1000
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker):
|
||||||
|
"""
|
||||||
|
Handles document uploads by extracting text and adding it to the vector store.
|
||||||
|
"""
|
||||||
|
from llama_index.core import Document
|
||||||
|
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
|
||||||
|
user_name = chatbot.get_user()
|
||||||
|
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
|
||||||
|
|
||||||
|
for file_path in files:
|
||||||
|
try:
|
||||||
|
validate_path_safety(file_path, user_name)
|
||||||
|
text = extract_text(file_path)
|
||||||
|
if text is None:
|
||||||
|
chatbot.append(
|
||||||
|
[f"上传文件: {os.path.basename(file_path)}", f"文件解析失败,无法提取文本内容,请更换文件。失败原因可能为:1.文档格式过于复杂;2. 不支持的文件格式,支持的文件格式后缀有:" + ", ".join(supports_format)])
|
||||||
|
else:
|
||||||
|
chatbot.append(
|
||||||
|
[f"上传文件: {os.path.basename(file_path)}", f"上传文件前50个字符为:{text[:50]}。"])
|
||||||
|
document = Document(text=text, metadata={"source": file_path})
|
||||||
|
rag_worker.add_documents_to_vector_store([document])
|
||||||
|
chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"])
|
||||||
|
except Exception as e:
|
||||||
|
report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e))
|
||||||
|
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Main Q&A function with document upload support
|
||||||
@CatchException
|
@CatchException
|
||||||
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
|
||||||
@@ -27,24 +64,43 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
|
|||||||
rag_worker = RAG_WORKER_REGISTER[user_name]
|
rag_worker = RAG_WORKER_REGISTER[user_name]
|
||||||
else:
|
else:
|
||||||
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
|
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
|
||||||
user_name,
|
user_name,
|
||||||
llm_kwargs,
|
llm_kwargs,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
auto_load_checkpoint=True)
|
auto_load_checkpoint=True
|
||||||
|
)
|
||||||
|
|
||||||
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
|
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
|
||||||
tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
|
tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
|
||||||
if txt == "清空向量数据库":
|
|
||||||
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
|
# 2. Handle special commands
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
if os.path.exists(txt) and os.path.isdir(txt):
|
||||||
rag_worker.purge()
|
project_folder = txt
|
||||||
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
|
validate_path_safety(project_folder, chatbot.get_user())
|
||||||
|
# Extract file paths from the user input
|
||||||
|
# Assuming the user inputs file paths separated by commas after the command
|
||||||
|
file_paths = [f for f in glob.glob(f'{project_folder}/**/*', recursive=True)]
|
||||||
|
chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker)
|
||||||
return
|
return
|
||||||
|
|
||||||
chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
|
elif txt == "清空向量数据库":
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
rag_worker.purge_vector_store()
|
||||||
|
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
# 2. clip history to reduce token consumption
|
else:
|
||||||
# 2-1. reduce chat round
|
report_exception(chatbot, history, a=f"上传文件路径错误: {txt}", b="请检查并提供正确路径。")
|
||||||
|
|
||||||
|
# 3. Normal Q&A processing
|
||||||
|
chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
# 4. Clip history to reduce token consumption
|
||||||
txt_origin = txt
|
txt_origin = txt
|
||||||
|
|
||||||
if len(history) > MAX_HISTORY_ROUND * 2:
|
if len(history) > MAX_HISTORY_ROUND * 2:
|
||||||
@@ -52,41 +108,47 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
|
|||||||
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
|
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
|
||||||
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
|
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
|
||||||
|
|
||||||
# 2-2. if input is clipped, add input to vector store before retrieve
|
# 5. If input is clipped, add input to vector store before retrieve
|
||||||
if input_is_clipped_flag:
|
if input_is_clipped_flag:
|
||||||
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
|
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
|
||||||
# save input to vector store
|
# Save input to vector store
|
||||||
rag_worker.add_text_to_vector_store(txt_origin)
|
rag_worker.add_text_to_vector_store(txt_origin)
|
||||||
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
|
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
|
||||||
|
|
||||||
if len(txt_origin) > REMEMBER_PREVIEW:
|
if len(txt_origin) > REMEMBER_PREVIEW:
|
||||||
HALF = REMEMBER_PREVIEW//2
|
HALF = REMEMBER_PREVIEW // 2
|
||||||
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
|
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
|
||||||
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
|
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
|
||||||
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
|
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
|
||||||
else:
|
|
||||||
pass
|
|
||||||
i_say = txt_clip
|
|
||||||
else:
|
else:
|
||||||
i_say_to_remember = i_say = txt_clip
|
i_say_to_remember = i_say = txt_clip
|
||||||
else:
|
else:
|
||||||
i_say_to_remember = i_say = txt_clip
|
i_say_to_remember = i_say = txt_clip
|
||||||
|
|
||||||
# 3. we search vector store and build prompts
|
# 6. Search vector store and build prompts
|
||||||
nodes = rag_worker.retrieve_from_store_with_query(i_say)
|
nodes = rag_worker.retrieve_from_store_with_query(i_say)
|
||||||
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
|
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
|
||||||
|
# 7. Query language model
|
||||||
|
if len(chatbot) != 0:
|
||||||
|
chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
|
||||||
|
|
||||||
# 4. it is time to query llms
|
|
||||||
if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
|
|
||||||
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
inputs=prompt, inputs_show_user=i_say,
|
inputs=prompt,
|
||||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
inputs_show_user=i_say,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=history,
|
||||||
sys_prompt=system_prompt,
|
sys_prompt=system_prompt,
|
||||||
retry_times_at_unknown_error=0
|
retry_times_at_unknown_error=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. remember what has been asked / answered
|
# 8. Remember Q&A
|
||||||
yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
|
yield from update_ui_lastest_msg(
|
||||||
|
model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...',
|
||||||
|
chatbot, history, delay=0.5
|
||||||
|
)
|
||||||
rag_worker.remember_qa(i_say_to_remember, model_say)
|
rag_worker.remember_qa(i_say_to_remember, model_say)
|
||||||
history.extend([i_say, model_say])
|
history.extend([i_say, model_say])
|
||||||
|
|
||||||
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面
|
# 9. Final UI Update
|
||||||
|
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip)
|
||||||
@@ -6,7 +6,10 @@ from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_ver
|
|||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
|
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
|
||||||
from crazy_functions.diagram_fns.file_tree import FileNode
|
from crazy_functions.diagram_fns.file_tree import FileNode
|
||||||
|
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||||
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
|
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
|
|
||||||
@@ -24,12 +27,13 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
file_tree_struct.add_file(file_path, file_path)
|
file_tree_struct.add_file(file_path, file_path)
|
||||||
|
|
||||||
# <第一步,逐个文件分析,多线程>
|
# <第一步,逐个文件分析,多线程>
|
||||||
|
lang = "" if not plugin_kwargs["use_chinese"] else " (you must use Chinese)"
|
||||||
for index, fp in enumerate(file_manifest):
|
for index, fp in enumerate(file_manifest):
|
||||||
# 读取文件
|
# 读取文件
|
||||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
prefix = ""
|
prefix = ""
|
||||||
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence, the code is:\n```{file_content}```'
|
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence{lang}, the code is:\n```{file_content}```'
|
||||||
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
|
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
|
||||||
# 装载请求内容
|
# 装载请求内容
|
||||||
MAX_TOKEN_SINGLE_FILE = 2560
|
MAX_TOKEN_SINGLE_FILE = 2560
|
||||||
@@ -37,7 +41,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
inputs_array.append(i_say)
|
inputs_array.append(i_say)
|
||||||
inputs_show_user_array.append(i_say_show_user)
|
inputs_show_user_array.append(i_say_show_user)
|
||||||
history_array.append([])
|
history_array.append([])
|
||||||
sys_prompt_array.append("You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear.")
|
sys_prompt_array.append(f"You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear{lang}.")
|
||||||
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
|
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
|
||||||
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
inputs_array = inputs_array,
|
inputs_array = inputs_array,
|
||||||
@@ -50,10 +54,20 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
)
|
)
|
||||||
|
|
||||||
# <第二步,逐个文件分析,生成带注释文件>
|
# <第二步,逐个文件分析,生成带注释文件>
|
||||||
|
tasks = ["" for _ in range(len(file_manifest))]
|
||||||
|
def bark_fn(tasks):
|
||||||
|
for i in range(len(tasks)): tasks[i] = "watchdog is dead"
|
||||||
|
wd = WatchDog(timeout=10, bark_fn=lambda: bark_fn(tasks), interval=3, msg="ThreadWatcher timeout")
|
||||||
|
wd.begin_watch()
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
|
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
|
||||||
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct):
|
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct, index):
|
||||||
pcc = PythonCodeComment(llm_kwargs, language='English')
|
language = 'Chinese' if plugin_kwargs["use_chinese"] else 'English'
|
||||||
|
def observe_window_update(x):
|
||||||
|
if tasks[index] == "watchdog is dead":
|
||||||
|
raise TimeoutError("ThreadWatcher: watchdog is dead")
|
||||||
|
tasks[index] = x
|
||||||
|
pcc = PythonCodeComment(llm_kwargs, plugin_kwargs, language=language, observe_window_update=observe_window_update)
|
||||||
pcc.read_file(path=fp, brief=gpt_say)
|
pcc.read_file(path=fp, brief=gpt_say)
|
||||||
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
|
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
|
||||||
file_tree_struct.manifest[fp].revised_path = revised_path
|
file_tree_struct.manifest[fp].revised_path = revised_path
|
||||||
@@ -65,7 +79,8 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
|
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
|
||||||
html_template = f.read()
|
html_template = f.read()
|
||||||
warp = lambda x: "```python\n\n" + x + "\n\n```"
|
warp = lambda x: "```python\n\n" + x + "\n\n```"
|
||||||
from themes.theme import advanced_css
|
from themes.theme import load_dynamic_theme
|
||||||
|
_, advanced_css, _, _ = load_dynamic_theme("Default")
|
||||||
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
|
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
|
||||||
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
|
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
|
||||||
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
|
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
|
||||||
@@ -73,17 +88,21 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
file_tree_struct.manifest[fp].compare_html = compare_html_path
|
file_tree_struct.manifest[fp].compare_html = compare_html_path
|
||||||
with open(compare_html_path, 'w', encoding='utf-8') as f:
|
with open(compare_html_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(html_template)
|
f.write(html_template)
|
||||||
# print('done 1')
|
tasks[index] = ""
|
||||||
|
|
||||||
chatbot.append([None, f"正在处理:"])
|
chatbot.append([None, f"正在处理:"])
|
||||||
futures = []
|
futures = []
|
||||||
|
index = 0
|
||||||
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
|
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
|
||||||
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct)
|
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct, index)
|
||||||
|
index += 1
|
||||||
futures.append(future)
|
futures.append(future)
|
||||||
|
|
||||||
|
# <第三步,等待任务完成>
|
||||||
cnt = 0
|
cnt = 0
|
||||||
while True:
|
while True:
|
||||||
cnt += 1
|
cnt += 1
|
||||||
|
wd.feed()
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
worker_done = [h.done() for h in futures]
|
worker_done = [h.done() for h in futures]
|
||||||
remain = len(worker_done) - sum(worker_done)
|
remain = len(worker_done) - sum(worker_done)
|
||||||
@@ -92,14 +111,18 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
preview_html_list = []
|
preview_html_list = []
|
||||||
for done, fp in zip(worker_done, file_manifest):
|
for done, fp in zip(worker_done, file_manifest):
|
||||||
if not done: continue
|
if not done: continue
|
||||||
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
|
if hasattr(file_tree_struct.manifest[fp], 'compare_html'):
|
||||||
|
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
|
||||||
|
else:
|
||||||
|
logger.error(f"文件: {fp} 的注释结果未能成功")
|
||||||
file_links = generate_file_link(preview_html_list)
|
file_links = generate_file_link(preview_html_list)
|
||||||
|
|
||||||
yield from update_ui_lastest_msg(
|
yield from update_ui_lastest_msg(
|
||||||
f"剩余源文件数量: {remain}.\n\n" +
|
f"当前任务: <br/>{'<br/>'.join(tasks)}.<br/>" +
|
||||||
f"已完成的文件: {sum(worker_done)}.\n\n" +
|
f"剩余源文件数量: {remain}.<br/>" +
|
||||||
|
f"已完成的文件: {sum(worker_done)}.<br/>" +
|
||||||
file_links +
|
file_links +
|
||||||
"\n\n" +
|
"<br/>" +
|
||||||
''.join(['.']*(cnt % 10 + 1)
|
''.join(['.']*(cnt % 10 + 1)
|
||||||
), chatbot=chatbot, history=history, delay=0)
|
), chatbot=chatbot, history=history, delay=0)
|
||||||
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
|
||||||
@@ -120,6 +143,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
@CatchException
|
@CatchException
|
||||||
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
history = [] # 清空历史,以免输入溢出
|
history = [] # 清空历史,以免输入溢出
|
||||||
|
plugin_kwargs["use_chinese"] = plugin_kwargs.get("use_chinese", False)
|
||||||
import glob, os
|
import glob, os
|
||||||
if os.path.exists(txt):
|
if os.path.exists(txt):
|
||||||
project_folder = txt
|
project_folder = txt
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
|
||||||
|
from toolbox import get_conf, update_ui
|
||||||
|
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
|
||||||
|
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||||
|
|
||||||
|
class SourceCodeComment_Wrap(GptAcademicPluginTemplate):
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def define_arg_selection_menu(self):
|
||||||
|
"""
|
||||||
|
定义插件的二级选项菜单
|
||||||
|
"""
|
||||||
|
gui_definition = {
|
||||||
|
"main_input":
|
||||||
|
ArgProperty(title="路径", description="程序路径(上传文件后自动填写)", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||||
|
"use_chinese":
|
||||||
|
ArgProperty(title="注释语言", options=["英文", "中文"], default_value="英文", description="无", type="dropdown").model_dump_json(),
|
||||||
|
# "use_emoji":
|
||||||
|
# ArgProperty(title="在注释中使用emoji", options=["禁止", "允许"], default_value="禁止", description="无", type="dropdown").model_dump_json(),
|
||||||
|
}
|
||||||
|
return gui_definition
|
||||||
|
|
||||||
|
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
"""
|
||||||
|
执行插件
|
||||||
|
"""
|
||||||
|
if plugin_kwargs["use_chinese"] == "中文":
|
||||||
|
plugin_kwargs["use_chinese"] = True
|
||||||
|
else:
|
||||||
|
plugin_kwargs["use_chinese"] = False
|
||||||
|
|
||||||
|
yield from 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
@@ -68,6 +68,7 @@ Be aware:
|
|||||||
1. You must NOT modify the indent of code.
|
1. You must NOT modify the indent of code.
|
||||||
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
|
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
|
||||||
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
|
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
|
||||||
|
4. Besides adding a docstring, use the ⭐ symbol to annotate the most core and important line of code within the function, explaining its role.
|
||||||
|
|
||||||
------------------ Example ------------------
|
------------------ Example ------------------
|
||||||
INPUT:
|
INPUT:
|
||||||
@@ -116,10 +117,66 @@ def zip_result(folder):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
revise_funtion_prompt_chinese = '''
|
||||||
|
您需要阅读以下代码,并根据以下说明修订源代码({FILE_BASENAME}):
|
||||||
|
1. 如果源代码中包含函数的话, 你应该分析给定函数实现了什么功能
|
||||||
|
2. 如果源代码中包含函数的话, 你需要为函数添加docstring, docstring必须使用中文
|
||||||
|
|
||||||
|
请注意:
|
||||||
|
1. 你不得修改代码的缩进
|
||||||
|
2. 你无权更改或翻译代码中的非注释部分,也不允许添加空行
|
||||||
|
3. 使用 {LANG} 添加注释和文档字符串。不要翻译代码中已有的中文
|
||||||
|
4. 除了添加docstring之外, 使用⭐符号给该函数中最核心、最重要的一行代码添加注释,并说明其作用
|
||||||
|
|
||||||
|
------------------ 示例 ------------------
|
||||||
|
INPUT:
|
||||||
|
```
|
||||||
|
L0000 |
|
||||||
|
L0001 |def zip_result(folder):
|
||||||
|
L0002 | t = gen_time_str()
|
||||||
|
L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
|
||||||
|
L0004 | return os.path.join(get_log_folder(), f"result.zip")
|
||||||
|
L0005 |
|
||||||
|
L0006 |
|
||||||
|
```
|
||||||
|
|
||||||
|
OUTPUT:
|
||||||
|
|
||||||
|
<instruction_1_purpose>
|
||||||
|
该函数用于压缩指定文件夹,并返回生成的`zip`文件的路径。
|
||||||
|
</instruction_1_purpose>
|
||||||
|
<instruction_2_revised_code>
|
||||||
|
```
|
||||||
|
def zip_result(folder):
|
||||||
|
"""
|
||||||
|
该函数将指定的文件夹压缩成ZIP文件, 并将其存储在日志文件夹中。
|
||||||
|
|
||||||
|
输入参数:
|
||||||
|
folder (str): 需要压缩的文件夹的路径。
|
||||||
|
返回值:
|
||||||
|
str: 日志文件夹中创建的ZIP文件的路径。
|
||||||
|
"""
|
||||||
|
t = gen_time_str()
|
||||||
|
zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ 执行文件夹的压缩
|
||||||
|
return os.path.join(get_log_folder(), f"result.zip")
|
||||||
|
```
|
||||||
|
</instruction_2_revised_code>
|
||||||
|
------------------ End of Example ------------------
|
||||||
|
|
||||||
|
|
||||||
|
------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
|
||||||
|
```
|
||||||
|
{THE_CODE}
|
||||||
|
```
|
||||||
|
{INDENT_REMINDER}
|
||||||
|
{BRIEF_REMINDER}
|
||||||
|
{HINT_REMINDER}
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
class PythonCodeComment():
|
class PythonCodeComment():
|
||||||
|
|
||||||
def __init__(self, llm_kwargs, language) -> None:
|
def __init__(self, llm_kwargs, plugin_kwargs, language, observe_window_update) -> None:
|
||||||
self.original_content = ""
|
self.original_content = ""
|
||||||
self.full_context = []
|
self.full_context = []
|
||||||
self.full_context_with_line_no = []
|
self.full_context_with_line_no = []
|
||||||
@@ -127,7 +184,13 @@ class PythonCodeComment():
|
|||||||
self.page_limit = 100 # 100 lines of code each page
|
self.page_limit = 100 # 100 lines of code each page
|
||||||
self.ignore_limit = 20
|
self.ignore_limit = 20
|
||||||
self.llm_kwargs = llm_kwargs
|
self.llm_kwargs = llm_kwargs
|
||||||
|
self.plugin_kwargs = plugin_kwargs
|
||||||
self.language = language
|
self.language = language
|
||||||
|
self.observe_window_update = observe_window_update
|
||||||
|
if self.language == "chinese":
|
||||||
|
self.core_prompt = revise_funtion_prompt_chinese
|
||||||
|
else:
|
||||||
|
self.core_prompt = revise_funtion_prompt
|
||||||
self.path = None
|
self.path = None
|
||||||
self.file_basename = None
|
self.file_basename = None
|
||||||
self.file_brief = ""
|
self.file_brief = ""
|
||||||
@@ -258,7 +321,7 @@ class PythonCodeComment():
|
|||||||
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
|
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
|
||||||
self.llm_kwargs['temperature'] = 0
|
self.llm_kwargs['temperature'] = 0
|
||||||
result = predict_no_ui_long_connection(
|
result = predict_no_ui_long_connection(
|
||||||
inputs=revise_funtion_prompt.format(
|
inputs=self.core_prompt.format(
|
||||||
LANG=self.language,
|
LANG=self.language,
|
||||||
FILE_BASENAME=self.file_basename,
|
FILE_BASENAME=self.file_basename,
|
||||||
THE_CODE=code,
|
THE_CODE=code,
|
||||||
@@ -348,6 +411,7 @@ class PythonCodeComment():
|
|||||||
try:
|
try:
|
||||||
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
|
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
|
||||||
next_batch, line_no_start, line_no_end = self.get_next_batch()
|
next_batch, line_no_start, line_no_end = self.get_next_batch()
|
||||||
|
self.observe_window_update(f"正在处理{self.file_basename} - {line_no_start}/{len(self.full_context)}\n")
|
||||||
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
|
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
|
||||||
|
|
||||||
hint = None
|
hint = None
|
||||||
|
|||||||
@@ -1,39 +1,47 @@
|
|||||||
import ast
|
import token
|
||||||
|
import tokenize
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
|
||||||
class CommentRemover(ast.NodeTransformer):
|
|
||||||
def visit_FunctionDef(self, node):
|
|
||||||
# 移除函数的文档字符串
|
|
||||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
|
||||||
isinstance(node.body[0].value, ast.Str)):
|
|
||||||
node.body = node.body[1:]
|
|
||||||
self.generic_visit(node)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_ClassDef(self, node):
|
def remove_python_comments(input_source: str) -> str:
|
||||||
# 移除类的文档字符串
|
source_flag = copy.copy(input_source)
|
||||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
source = io.StringIO(input_source)
|
||||||
isinstance(node.body[0].value, ast.Str)):
|
ls = input_source.split('\n')
|
||||||
node.body = node.body[1:]
|
prev_toktype = token.INDENT
|
||||||
self.generic_visit(node)
|
readline = source.readline
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_Module(self, node):
|
def get_char_index(lineno, col):
|
||||||
# 移除模块的文档字符串
|
# find the index of the char in the source code
|
||||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
if lineno == 1:
|
||||||
isinstance(node.body[0].value, ast.Str)):
|
return len('\n'.join(ls[:(lineno-1)])) + col
|
||||||
node.body = node.body[1:]
|
else:
|
||||||
self.generic_visit(node)
|
return len('\n'.join(ls[:(lineno-1)])) + col + 1
|
||||||
return node
|
|
||||||
|
def replace_char_between(start_lineno, start_col, end_lineno, end_col, source, replace_char, ls):
|
||||||
|
# replace char between start_lineno, start_col and end_lineno, end_col with replace_char, but keep '\n' and ' '
|
||||||
|
b = get_char_index(start_lineno, start_col)
|
||||||
|
e = get_char_index(end_lineno, end_col)
|
||||||
|
for i in range(b, e):
|
||||||
|
if source[i] == '\n':
|
||||||
|
source = source[:i] + '\n' + source[i+1:]
|
||||||
|
elif source[i] == ' ':
|
||||||
|
source = source[:i] + ' ' + source[i+1:]
|
||||||
|
else:
|
||||||
|
source = source[:i] + replace_char + source[i+1:]
|
||||||
|
return source
|
||||||
|
|
||||||
|
tokgen = tokenize.generate_tokens(readline)
|
||||||
|
for toktype, ttext, (slineno, scol), (elineno, ecol), ltext in tokgen:
|
||||||
|
if toktype == token.STRING and (prev_toktype == token.INDENT):
|
||||||
|
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||||
|
elif toktype == token.STRING and (prev_toktype == token.NEWLINE):
|
||||||
|
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||||
|
elif toktype == tokenize.COMMENT:
|
||||||
|
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||||
|
prev_toktype = toktype
|
||||||
|
return source_flag
|
||||||
|
|
||||||
def remove_python_comments(source_code):
|
|
||||||
# 解析源代码为 AST
|
|
||||||
tree = ast.parse(source_code)
|
|
||||||
# 移除注释
|
|
||||||
transformer = CommentRemover()
|
|
||||||
tree = transformer.visit(tree)
|
|
||||||
# 将处理后的 AST 转换回源代码
|
|
||||||
return ast.unparse(tree)
|
|
||||||
|
|
||||||
# 示例使用
|
# 示例使用
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -0,0 +1,450 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from docx import Document
|
||||||
|
from docx.enum.style import WD_STYLE_TYPE
|
||||||
|
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||||
|
from docx.oxml.ns import qn
|
||||||
|
from docx.shared import Inches, Cm
|
||||||
|
from docx.shared import Pt, RGBColor, Inches
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentFormatter(ABC):
|
||||||
|
"""文档格式化基类,定义文档格式化的基本接口"""
|
||||||
|
|
||||||
|
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
|
||||||
|
self.final_summary = final_summary
|
||||||
|
self.file_summaries_map = file_summaries_map
|
||||||
|
self.failed_files = failed_files
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
"""格式化失败文件列表"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
"""格式化文件总结内容"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_document(self) -> str:
|
||||||
|
"""创建完整文档"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WordFormatter(DocumentFormatter):
|
||||||
|
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.doc = Document()
|
||||||
|
self._setup_document()
|
||||||
|
self._create_styles()
|
||||||
|
# 初始化三级标题编号系统
|
||||||
|
self.numbers = {
|
||||||
|
1: 0, # 一级标题编号
|
||||||
|
2: 0, # 二级标题编号
|
||||||
|
3: 0 # 三级标题编号
|
||||||
|
}
|
||||||
|
|
||||||
|
def _setup_document(self):
|
||||||
|
"""设置文档基本格式,包括页面设置和页眉"""
|
||||||
|
sections = self.doc.sections
|
||||||
|
for section in sections:
|
||||||
|
# 设置页面大小为A4
|
||||||
|
section.page_width = Cm(21)
|
||||||
|
section.page_height = Cm(29.7)
|
||||||
|
# 设置页边距
|
||||||
|
section.top_margin = Cm(3.7) # 上边距37mm
|
||||||
|
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||||
|
section.left_margin = Cm(2.8) # 左边距28mm
|
||||||
|
section.right_margin = Cm(2.6) # 右边距26mm
|
||||||
|
# 设置页眉页脚距离
|
||||||
|
section.header_distance = Cm(2.0)
|
||||||
|
section.footer_distance = Cm(2.0)
|
||||||
|
|
||||||
|
# 添加页眉
|
||||||
|
header = section.header
|
||||||
|
header_para = header.paragraphs[0]
|
||||||
|
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
|
||||||
|
header_run = header_para.add_run("该文档由GPT-academic生成")
|
||||||
|
header_run.font.name = '仿宋'
|
||||||
|
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
header_run.font.size = Pt(9)
|
||||||
|
|
||||||
|
def _create_styles(self):
|
||||||
|
"""创建文档样式"""
|
||||||
|
# 创建正文样式
|
||||||
|
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
style.font.name = '仿宋'
|
||||||
|
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
style.font.size = Pt(14)
|
||||||
|
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
style.paragraph_format.space_after = Pt(0)
|
||||||
|
style.paragraph_format.first_line_indent = Pt(28)
|
||||||
|
|
||||||
|
# 创建各级标题样式
|
||||||
|
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
|
||||||
|
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
|
||||||
|
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
|
||||||
|
"""创建标题样式"""
|
||||||
|
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
style.font.name = font_name
|
||||||
|
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||||
|
style.font.size = Pt(font_size)
|
||||||
|
style.font.bold = True
|
||||||
|
style.paragraph_format.alignment = alignment
|
||||||
|
style.paragraph_format.space_before = Pt(12)
|
||||||
|
style.paragraph_format.space_after = Pt(12)
|
||||||
|
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
return style
|
||||||
|
|
||||||
|
def _get_heading_number(self, level: int) -> str:
|
||||||
|
"""
|
||||||
|
生成标题编号
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: 标题级别 (0-3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化的标题编号
|
||||||
|
"""
|
||||||
|
if level == 0: # 主标题不需要编号
|
||||||
|
return ""
|
||||||
|
|
||||||
|
self.numbers[level] += 1 # 增加当前级别的编号
|
||||||
|
|
||||||
|
# 重置下级标题编号
|
||||||
|
for i in range(level + 1, 4):
|
||||||
|
self.numbers[i] = 0
|
||||||
|
|
||||||
|
# 根据级别返回不同格式的编号
|
||||||
|
if level == 1:
|
||||||
|
return f"{self.numbers[1]}. "
|
||||||
|
elif level == 2:
|
||||||
|
return f"{self.numbers[1]}.{self.numbers[2]} "
|
||||||
|
elif level == 3:
|
||||||
|
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _add_heading(self, text: str, level: int):
|
||||||
|
"""
|
||||||
|
添加带编号的标题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 标题文本
|
||||||
|
level: 标题级别 (0-3)
|
||||||
|
"""
|
||||||
|
style_map = {
|
||||||
|
0: 'Title_Custom',
|
||||||
|
1: 'Heading1_Custom',
|
||||||
|
2: 'Heading2_Custom',
|
||||||
|
3: 'Heading3_Custom'
|
||||||
|
}
|
||||||
|
|
||||||
|
number = self._get_heading_number(level)
|
||||||
|
paragraph = self.doc.add_paragraph(style=style_map[level])
|
||||||
|
|
||||||
|
if number:
|
||||||
|
number_run = paragraph.add_run(number)
|
||||||
|
font_size = 22 if level == 1 else (18 if level == 2 else 16)
|
||||||
|
self._get_run_style(number_run, '黑体', font_size, True)
|
||||||
|
|
||||||
|
text_run = paragraph.add_run(text)
|
||||||
|
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
|
||||||
|
self._get_run_style(text_run, '黑体', font_size, True)
|
||||||
|
|
||||||
|
# 主标题添加日期
|
||||||
|
if level == 0:
|
||||||
|
date_paragraph = self.doc.add_paragraph()
|
||||||
|
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||||
|
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||||
|
self._get_run_style(date_run, '仿宋', 16, False)
|
||||||
|
|
||||||
|
return paragraph
|
||||||
|
|
||||||
|
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
|
||||||
|
"""设置文本运行对象的样式"""
|
||||||
|
run.font.name = font_name
|
||||||
|
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||||
|
run.font.size = Pt(font_size)
|
||||||
|
run.font.bold = bold
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
"""格式化失败文件列表"""
|
||||||
|
result = []
|
||||||
|
if not self.failed_files:
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
result.append("处理失败文件:")
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
result.append(f"• {os.path.basename(fp)}: {reason}")
|
||||||
|
|
||||||
|
self._add_heading("处理失败文件", 1)
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
self._add_content(f"• {os.path.basename(fp)}: {reason}", indent=False)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
def _add_content(self, text: str, indent: bool = True):
|
||||||
|
"""添加正文内容"""
|
||||||
|
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
|
||||||
|
if not indent:
|
||||||
|
paragraph.paragraph_format.first_line_indent = Pt(0)
|
||||||
|
return paragraph
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
"""
|
||||||
|
格式化文件总结内容,确保正确的标题层级
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 格式化后的文件总结字符串
|
||||||
|
|
||||||
|
标题层级规则:
|
||||||
|
1. 一级标题为"各文件详细总结"
|
||||||
|
2. 如果文件有目录路径:
|
||||||
|
- 目录路径作为二级标题 (2.1, 2.2 等)
|
||||||
|
- 该目录下所有文件作为三级标题 (2.1.1, 2.1.2 等)
|
||||||
|
3. 如果文件没有目录路径:
|
||||||
|
- 文件直接作为二级标题 (2.1, 2.2 等)
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
# 首先对文件路径进行分组整理
|
||||||
|
file_groups = {}
|
||||||
|
for path in sorted(self.file_summaries_map.keys()):
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path not in file_groups:
|
||||||
|
file_groups[dir_path] = []
|
||||||
|
file_groups[dir_path].append(path)
|
||||||
|
|
||||||
|
# 处理没有目录的文件
|
||||||
|
root_files = file_groups.get("", [])
|
||||||
|
if root_files:
|
||||||
|
for path in sorted(root_files):
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
result.append(f"\n📄 {file_name}")
|
||||||
|
result.append(self.file_summaries_map[path])
|
||||||
|
# 无目录的文件作为二级标题
|
||||||
|
self._add_heading(f"📄 {file_name}", 2)
|
||||||
|
self._add_content(self.file_summaries_map[path])
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 处理有目录的文件
|
||||||
|
for dir_path in sorted(file_groups.keys()):
|
||||||
|
if dir_path == "": # 跳过已处理的根目录文件
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 添加目录作为二级标题
|
||||||
|
result.append(f"\n📁 {dir_path}")
|
||||||
|
self._add_heading(f"📁 {dir_path}", 2)
|
||||||
|
|
||||||
|
# 该目录下的所有文件作为三级标题
|
||||||
|
for path in sorted(file_groups[dir_path]):
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
result.append(f"\n📄 {file_name}")
|
||||||
|
result.append(self.file_summaries_map[path])
|
||||||
|
|
||||||
|
# 添加文件名作为三级标题
|
||||||
|
self._add_heading(f"📄 {file_name}", 3)
|
||||||
|
self._add_content(self.file_summaries_map[path])
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def create_document(self):
|
||||||
|
"""创建完整Word文档并返回文档对象"""
|
||||||
|
# 重置所有编号
|
||||||
|
for level in self.numbers:
|
||||||
|
self.numbers[level] = 0
|
||||||
|
|
||||||
|
# 添加主标题
|
||||||
|
self._add_heading("文档总结报告", 0)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 添加总体摘要
|
||||||
|
self._add_heading("总体摘要", 1)
|
||||||
|
self._add_content(self.final_summary)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 添加失败文件列表(如果有)
|
||||||
|
if self.failed_files:
|
||||||
|
self.format_failed_files()
|
||||||
|
|
||||||
|
# 添加文件详细总结
|
||||||
|
self._add_heading("各文件详细总结", 1)
|
||||||
|
self.format_file_summaries()
|
||||||
|
|
||||||
|
return self.doc
|
||||||
|
|
||||||
|
|
||||||
|
class MarkdownFormatter(DocumentFormatter):
|
||||||
|
"""Markdown格式文档生成器"""
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
if not self.failed_files:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_text = ["\n## ⚠️ 处理失败的文件"]
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
|
||||||
|
formatted_text.append("\n---")
|
||||||
|
return "\n".join(formatted_text)
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
formatted_text = []
|
||||||
|
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||||
|
current_dir = ""
|
||||||
|
|
||||||
|
for path in sorted_paths:
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path != current_dir:
|
||||||
|
if dir_path:
|
||||||
|
formatted_text.append(f"\n## 📁 {dir_path}")
|
||||||
|
current_dir = dir_path
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
formatted_text.append(f"\n### 📄 {file_name}")
|
||||||
|
formatted_text.append(self.file_summaries_map[path])
|
||||||
|
formatted_text.append("\n---")
|
||||||
|
|
||||||
|
return "\n".join(formatted_text)
|
||||||
|
|
||||||
|
def create_document(self) -> str:
|
||||||
|
document = [
|
||||||
|
"# 📑 文档总结报告",
|
||||||
|
"\n## 总体摘要",
|
||||||
|
self.final_summary
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.failed_files:
|
||||||
|
document.append(self.format_failed_files())
|
||||||
|
|
||||||
|
document.extend([
|
||||||
|
"\n# 📚 各文件详细总结",
|
||||||
|
self.format_file_summaries()
|
||||||
|
])
|
||||||
|
|
||||||
|
return "\n".join(document)
|
||||||
|
|
||||||
|
|
||||||
|
class HtmlFormatter(DocumentFormatter):
|
||||||
|
"""HTML格式文档生成器"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.css_styles = """
|
||||||
|
body {
|
||||||
|
font-family: "Microsoft YaHei", Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
max-width: 1000px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 20px;
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
color: #2c3e50;
|
||||||
|
border-bottom: 2px solid #eee;
|
||||||
|
padding-bottom: 10px;
|
||||||
|
font-size: 24px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
color: #34495e;
|
||||||
|
margin-top: 30px;
|
||||||
|
font-size: 20px;
|
||||||
|
border-left: 4px solid #3498db;
|
||||||
|
padding-left: 10px;
|
||||||
|
}
|
||||||
|
h3 {
|
||||||
|
color: #2c3e50;
|
||||||
|
font-size: 18px;
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
.summary {
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 5px;
|
||||||
|
margin: 20px 0;
|
||||||
|
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
.details {
|
||||||
|
margin-top: 40px;
|
||||||
|
}
|
||||||
|
.failed-files {
|
||||||
|
background-color: #fff3f3;
|
||||||
|
padding: 15px;
|
||||||
|
border-left: 4px solid #e74c3c;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
.file-summary {
|
||||||
|
background-color: #fff;
|
||||||
|
padding: 15px;
|
||||||
|
margin: 15px 0;
|
||||||
|
border-radius: 4px;
|
||||||
|
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
if not self.failed_files:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
failed_files_html = ['<div class="failed-files">']
|
||||||
|
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
|
||||||
|
failed_files_html.append("<ul>")
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
|
||||||
|
failed_files_html.append("</ul></div>")
|
||||||
|
return "\n".join(failed_files_html)
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
formatted_html = []
|
||||||
|
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||||
|
current_dir = ""
|
||||||
|
|
||||||
|
for path in sorted_paths:
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path != current_dir:
|
||||||
|
if dir_path:
|
||||||
|
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
|
||||||
|
current_dir = dir_path
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
formatted_html.append('<div class="file-summary">')
|
||||||
|
formatted_html.append(f'<h3>📄 {file_name}</h3>')
|
||||||
|
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
|
||||||
|
return "\n".join(formatted_html)
|
||||||
|
|
||||||
|
def create_document(self) -> str:
|
||||||
|
return f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset='utf-8'>
|
||||||
|
<title>文档总结报告</title>
|
||||||
|
<style>{self.css_styles}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>📑 文档总结报告</h1>
|
||||||
|
<h2>总体摘要</h2>
|
||||||
|
<div class="summary">{self.final_summary}</div>
|
||||||
|
{self.format_failed_files()}
|
||||||
|
<div class="details">
|
||||||
|
<h2>📚 各文件详细总结</h2>
|
||||||
|
{self.format_file_summaries()}
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,387 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义异常类定义
|
||||||
|
class FoldingError(Exception):
|
||||||
|
"""折叠相关的自定义异常基类"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FormattingError(FoldingError):
|
||||||
|
"""格式化过程中的错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataError(FoldingError):
|
||||||
|
"""元数据相关的错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(FoldingError):
|
||||||
|
"""验证错误"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FoldingStyle(Enum):
|
||||||
|
"""折叠样式枚举"""
|
||||||
|
SIMPLE = auto() # 简单折叠
|
||||||
|
DETAILED = auto() # 详细折叠(带有额外信息)
|
||||||
|
NESTED = auto() # 嵌套折叠
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FoldingOptions:
|
||||||
|
"""折叠选项配置"""
|
||||||
|
style: FoldingStyle = FoldingStyle.DETAILED
|
||||||
|
code_language: Optional[str] = None # 代码块的语言
|
||||||
|
show_timestamp: bool = False # 是否显示时间戳
|
||||||
|
indent_level: int = 0 # 缩进级别
|
||||||
|
custom_css: Optional[str] = None # 自定义CSS类
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T') # 用于泛型类型
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMetadata(ABC):
|
||||||
|
"""元数据基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证元数据的有效性"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _validate_non_empty_str(self, value: Optional[str]) -> bool:
|
||||||
|
"""验证字符串非空"""
|
||||||
|
return bool(value and value.strip())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileMetadata(BaseMetadata):
|
||||||
|
"""文件元数据"""
|
||||||
|
rel_path: str
|
||||||
|
size: float
|
||||||
|
last_modified: Optional[datetime] = None
|
||||||
|
mime_type: Optional[str] = None
|
||||||
|
encoding: str = 'utf-8'
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证文件元数据的有效性"""
|
||||||
|
try:
|
||||||
|
if not self._validate_non_empty_str(self.rel_path):
|
||||||
|
return False
|
||||||
|
if self.size < 0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"File metadata validation error: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ContentFormatter(ABC, Generic[T]):
|
||||||
|
"""内容格式化抽象基类
|
||||||
|
|
||||||
|
支持泛型类型参数,可以指定具体的元数据类型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format(self,
|
||||||
|
content: str,
|
||||||
|
metadata: T,
|
||||||
|
options: Optional[FoldingOptions] = None) -> str:
|
||||||
|
"""格式化内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 需要格式化的内容
|
||||||
|
metadata: 类型化的元数据
|
||||||
|
options: 折叠选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的内容
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FormattingError: 格式化过程中的错误
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _create_summary(self, metadata: T) -> str:
|
||||||
|
"""创建折叠摘要,可被子类重写"""
|
||||||
|
return str(metadata)
|
||||||
|
|
||||||
|
def _format_content_block(self,
|
||||||
|
content: str,
|
||||||
|
options: Optional[FoldingOptions]) -> str:
|
||||||
|
"""格式化内容块,处理代码块等特殊格式"""
|
||||||
|
if not options:
|
||||||
|
return content
|
||||||
|
|
||||||
|
if options.code_language:
|
||||||
|
return f"```{options.code_language}\n{content}\n```"
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _add_indent(self, text: str, level: int) -> str:
|
||||||
|
"""添加缩进"""
|
||||||
|
if level <= 0:
|
||||||
|
return text
|
||||||
|
indent = " " * level
|
||||||
|
return "\n".join(indent + line for line in text.splitlines())
|
||||||
|
|
||||||
|
|
||||||
|
class FileContentFormatter(ContentFormatter[FileMetadata]):
|
||||||
|
"""文件内容格式化器"""
|
||||||
|
|
||||||
|
def format(self,
|
||||||
|
content: str,
|
||||||
|
metadata: FileMetadata,
|
||||||
|
options: Optional[FoldingOptions] = None) -> str:
|
||||||
|
"""格式化文件内容"""
|
||||||
|
if not metadata.validate():
|
||||||
|
raise MetadataError("Invalid file metadata")
|
||||||
|
|
||||||
|
try:
|
||||||
|
options = options or FoldingOptions()
|
||||||
|
|
||||||
|
# 构建摘要信息
|
||||||
|
summary_parts = [
|
||||||
|
f"{metadata.rel_path} ({metadata.size:.2f}MB)",
|
||||||
|
f"Type: {metadata.mime_type}" if metadata.mime_type else None,
|
||||||
|
(f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
if metadata.last_modified and options.show_timestamp else None)
|
||||||
|
]
|
||||||
|
summary = " | ".join(filter(None, summary_parts))
|
||||||
|
|
||||||
|
# 构建HTML类
|
||||||
|
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
|
||||||
|
|
||||||
|
# 格式化内容
|
||||||
|
formatted_content = self._format_content_block(content, options)
|
||||||
|
|
||||||
|
# 组装最终结果
|
||||||
|
result = (
|
||||||
|
f'<details{css_class}><summary>{summary}</summary>\n\n'
|
||||||
|
f'{formatted_content}\n\n'
|
||||||
|
f'</details>\n\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._add_indent(result, options.indent_level)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error formatting file content: {str(e)}")
|
||||||
|
raise FormattingError(f"Failed to format file content: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class ContentFoldingManager:
|
||||||
|
"""内容折叠管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化折叠管理器"""
|
||||||
|
self._formatters: Dict[str, ContentFormatter] = {}
|
||||||
|
self._register_default_formatters()
|
||||||
|
|
||||||
|
def _register_default_formatters(self) -> None:
|
||||||
|
"""注册默认的格式化器"""
|
||||||
|
self.register_formatter('file', FileContentFormatter())
|
||||||
|
|
||||||
|
def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
|
||||||
|
"""注册新的格式化器"""
|
||||||
|
if not isinstance(formatter, ContentFormatter):
|
||||||
|
raise TypeError("Formatter must implement ContentFormatter interface")
|
||||||
|
self._formatters[name] = formatter
|
||||||
|
|
||||||
|
def _guess_language(self, extension: str) -> Optional[str]:
|
||||||
|
"""根据文件扩展名猜测编程语言"""
|
||||||
|
extension = extension.lower().lstrip('.')
|
||||||
|
language_map = {
|
||||||
|
'py': 'python',
|
||||||
|
'js': 'javascript',
|
||||||
|
'java': 'java',
|
||||||
|
'cpp': 'cpp',
|
||||||
|
'cs': 'csharp',
|
||||||
|
'html': 'html',
|
||||||
|
'css': 'css',
|
||||||
|
'md': 'markdown',
|
||||||
|
'json': 'json',
|
||||||
|
'xml': 'xml',
|
||||||
|
'sql': 'sql',
|
||||||
|
'sh': 'bash',
|
||||||
|
'yaml': 'yaml',
|
||||||
|
'yml': 'yaml',
|
||||||
|
'txt': None # 纯文本不需要语言标识
|
||||||
|
}
|
||||||
|
return language_map.get(extension)
|
||||||
|
|
||||||
|
def format_content(self,
|
||||||
|
content: str,
|
||||||
|
formatter_type: str,
|
||||||
|
metadata: Union[FileMetadata],
|
||||||
|
options: Optional[FoldingOptions] = None) -> str:
|
||||||
|
"""格式化内容"""
|
||||||
|
formatter = self._formatters.get(formatter_type)
|
||||||
|
if not formatter:
|
||||||
|
raise KeyError(f"No formatter registered for type: {formatter_type}")
|
||||||
|
|
||||||
|
if not isinstance(metadata, FileMetadata):
|
||||||
|
raise TypeError("Invalid metadata type")
|
||||||
|
|
||||||
|
return formatter.format(content, metadata, options)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PaperMetadata(BaseMetadata):
|
||||||
|
"""论文元数据"""
|
||||||
|
title: str
|
||||||
|
authors: str
|
||||||
|
abstract: str
|
||||||
|
catalogs: str
|
||||||
|
arxiv_id: str = ""
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证论文元数据的有效性"""
|
||||||
|
try:
|
||||||
|
if not self._validate_non_empty_str(self.title):
|
||||||
|
return False
|
||||||
|
if not self._validate_non_empty_str(self.authors):
|
||||||
|
return False
|
||||||
|
if not self._validate_non_empty_str(self.abstract):
|
||||||
|
return False
|
||||||
|
if not self._validate_non_empty_str(self.catalogs):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Paper metadata validation error: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class PaperContentFormatter(ContentFormatter[PaperMetadata]):
|
||||||
|
"""论文内容格式化器"""
|
||||||
|
|
||||||
|
def format(self,
|
||||||
|
fragments: list[SectionFragment],
|
||||||
|
metadata: PaperMetadata,
|
||||||
|
options: Optional[FoldingOptions] = None) -> str:
|
||||||
|
"""格式化论文内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragments: 论文片段列表
|
||||||
|
metadata: 论文元数据
|
||||||
|
options: 折叠选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的论文内容
|
||||||
|
"""
|
||||||
|
if not metadata.validate():
|
||||||
|
raise MetadataError("Invalid paper metadata")
|
||||||
|
|
||||||
|
try:
|
||||||
|
options = options or FoldingOptions()
|
||||||
|
|
||||||
|
# 1. 生成标题部分(不折叠)
|
||||||
|
result = [f"# {metadata.title}\n"]
|
||||||
|
|
||||||
|
# 2. 生成作者信息(折叠)
|
||||||
|
result.append(self._create_folded_section(
|
||||||
|
"Authors",
|
||||||
|
metadata.authors,
|
||||||
|
options
|
||||||
|
))
|
||||||
|
|
||||||
|
# 3. 生成摘要(折叠)
|
||||||
|
result.append(self._create_folded_section(
|
||||||
|
"Abstract",
|
||||||
|
metadata.abstract,
|
||||||
|
options
|
||||||
|
))
|
||||||
|
|
||||||
|
# 4. 生成目录树(折叠)
|
||||||
|
result.append(self._create_folded_section(
|
||||||
|
"Table of Contents",
|
||||||
|
f"```\n{metadata.catalogs}\n```",
|
||||||
|
options
|
||||||
|
))
|
||||||
|
|
||||||
|
# 5. 按章节组织并生成内容
|
||||||
|
sections = self._organize_sections(fragments)
|
||||||
|
for section, section_fragments in sections.items():
|
||||||
|
# 拼接该章节的所有内容
|
||||||
|
section_content = "\n\n".join(
|
||||||
|
fragment.content for fragment in section_fragments
|
||||||
|
)
|
||||||
|
|
||||||
|
result.append(self._create_folded_section(
|
||||||
|
section,
|
||||||
|
section_content,
|
||||||
|
options
|
||||||
|
))
|
||||||
|
|
||||||
|
# 6. 生成参考文献(折叠)
|
||||||
|
# 收集所有非空的参考文献
|
||||||
|
all_refs = "\n".join(filter(None,
|
||||||
|
(fragment.bibliography for fragment in fragments)
|
||||||
|
))
|
||||||
|
if all_refs:
|
||||||
|
result.append(self._create_folded_section(
|
||||||
|
"Bibliography",
|
||||||
|
f"```bibtex\n{all_refs}\n```",
|
||||||
|
options
|
||||||
|
))
|
||||||
|
|
||||||
|
return "\n\n".join(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error formatting paper content: {str(e)}")
|
||||||
|
raise FormattingError(f"Failed to format paper content: {str(e)}")
|
||||||
|
|
||||||
|
def _create_folded_section(self,
|
||||||
|
title: str,
|
||||||
|
content: str,
|
||||||
|
options: FoldingOptions) -> str:
|
||||||
|
"""创建折叠区块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: 区块标题
|
||||||
|
content: 区块内容
|
||||||
|
options: 折叠选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的折叠区块
|
||||||
|
"""
|
||||||
|
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
|
||||||
|
|
||||||
|
result = (
|
||||||
|
f'<details{css_class}><summary>{title}</summary>\n\n'
|
||||||
|
f'{content}\n\n'
|
||||||
|
f'</details>'
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._add_indent(result, options.indent_level)
|
||||||
|
|
||||||
|
def _organize_sections(self,
|
||||||
|
fragments: list[SectionFragment]
|
||||||
|
) -> Dict[str, list[SectionFragment]]:
|
||||||
|
"""将片段按章节分组
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragments: 论文片段列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, list[SectionFragment]]: 按章节分组的片段字典
|
||||||
|
"""
|
||||||
|
sections: Dict[str, list[SectionFragment]] = {}
|
||||||
|
|
||||||
|
for fragment in fragments:
|
||||||
|
section = fragment.current_section or "Uncategorized"
|
||||||
|
if section not in sections:
|
||||||
|
sections[section] = []
|
||||||
|
sections[section].append(fragment)
|
||||||
|
|
||||||
|
return sections
|
||||||
@@ -0,0 +1,354 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectionFragment:
|
||||||
|
"""Arxiv论文片段数据类"""
|
||||||
|
title: str
|
||||||
|
authors: str
|
||||||
|
abstract: str
|
||||||
|
catalogs: str
|
||||||
|
arxiv_id: str = ""
|
||||||
|
current_section: str = "Introduction"
|
||||||
|
content: str = ''
|
||||||
|
bibliography: str = ''
|
||||||
|
|
||||||
|
|
||||||
|
class PaperHtmlFormatter:
|
||||||
|
"""HTML格式论文文档生成器"""
|
||||||
|
|
||||||
|
def __init__(self, fragments: List[SectionFragment], output_dir: Path):
|
||||||
|
self.fragments = fragments
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.css_styles = """
|
||||||
|
:root {
|
||||||
|
--primary-color: #1a73e8;
|
||||||
|
--secondary-color: #34495e;
|
||||||
|
--background-color: #f8f9fa;
|
||||||
|
--text-color: #2c3e50;
|
||||||
|
--border-color: #e0e0e0;
|
||||||
|
--code-bg-color: #f6f8fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: "Source Serif Pro", "Times New Roman", serif;
|
||||||
|
line-height: 1.8;
|
||||||
|
max-width: 1000px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 2rem;
|
||||||
|
color: var(--text-color);
|
||||||
|
background-color: var(--background-color);
|
||||||
|
font-size: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
background: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 2px 12px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
color: var(--primary-color);
|
||||||
|
font-size: 2.2em;
|
||||||
|
text-align: center;
|
||||||
|
margin: 1.5rem 0;
|
||||||
|
padding-bottom: 1rem;
|
||||||
|
border-bottom: 3px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
h2 {
|
||||||
|
color: var(--secondary-color);
|
||||||
|
font-size: 1.8em;
|
||||||
|
margin-top: 2rem;
|
||||||
|
padding-left: 1rem;
|
||||||
|
border-left: 4px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
h3 {
|
||||||
|
color: var(--text-color);
|
||||||
|
font-size: 1.5em;
|
||||||
|
margin-top: 1.5rem;
|
||||||
|
border-bottom: 2px solid var(--border-color);
|
||||||
|
padding-bottom: 0.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.authors {
|
||||||
|
text-align: center;
|
||||||
|
color: var(--secondary-color);
|
||||||
|
font-size: 1.1em;
|
||||||
|
margin: 1rem 0 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.abstract-container {
|
||||||
|
background: var(--background-color);
|
||||||
|
padding: 1.5rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin: 2rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.abstract-title {
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--primary-color);
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.abstract-content {
|
||||||
|
font-style: italic;
|
||||||
|
line-height: 1.7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.toc {
|
||||||
|
background: white;
|
||||||
|
padding: 1.5rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin: 2rem 0;
|
||||||
|
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
.toc-title {
|
||||||
|
color: var(--primary-color);
|
||||||
|
font-size: 1.4em;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.section-content {
|
||||||
|
background: white;
|
||||||
|
padding: 1.5rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin: 1.5rem 0;
|
||||||
|
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
.fragment {
|
||||||
|
margin: 2rem 0;
|
||||||
|
padding-left: 1rem;
|
||||||
|
border-left: 3px solid var(--border-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.fragment:hover {
|
||||||
|
border-left-color: var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.bibliography {
|
||||||
|
background: var(--code-bg-color);
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-family: "Source Code Pro", monospace;
|
||||||
|
font-size: 0.9em;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
pre {
|
||||||
|
background: var(--code-bg-color);
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow-x: auto;
|
||||||
|
font-family: "Source Code Pro", monospace;
|
||||||
|
}
|
||||||
|
|
||||||
|
.paper-info {
|
||||||
|
background: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin: 2rem 0;
|
||||||
|
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.arxiv-id {
|
||||||
|
text-align: center;
|
||||||
|
color: #666;
|
||||||
|
font-size: 0.9em;
|
||||||
|
margin: 1rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.section-title {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
color: var(--secondary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.section-icon {
|
||||||
|
color: var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
@media print {
|
||||||
|
body {
|
||||||
|
background: white;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
box-shadow: none;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _sanitize_html(self, text: str) -> str:
|
||||||
|
"""清理HTML特殊字符"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
replacements = {
|
||||||
|
"&": "&",
|
||||||
|
"<": "<",
|
||||||
|
">": ">",
|
||||||
|
'"': """,
|
||||||
|
"'": "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
for old, new in replacements.items():
|
||||||
|
text = text.replace(old, new)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _create_section_id(self, section: str) -> str:
|
||||||
|
"""创建section的ID"""
|
||||||
|
section = section.strip() or "uncategorized"
|
||||||
|
# 移除特殊字符,转换为小写并用连字符替换空格
|
||||||
|
section_id = re.sub(r'[^\w\s-]', '', section.lower())
|
||||||
|
return section_id.replace(' ', '-')
|
||||||
|
|
||||||
|
def format_paper_info(self) -> str:
|
||||||
|
"""格式化论文基本信息"""
|
||||||
|
if not self.fragments:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
first_fragment = self.fragments[0]
|
||||||
|
paper_info = ['<div class="paper-info">']
|
||||||
|
|
||||||
|
# 添加标题
|
||||||
|
if first_fragment.title:
|
||||||
|
paper_info.append(f'<h1>{self._sanitize_html(first_fragment.title)}</h1>')
|
||||||
|
|
||||||
|
# 添加arXiv ID
|
||||||
|
if first_fragment.arxiv_id:
|
||||||
|
paper_info.append(f'<div class="arxiv-id">arXiv: {self._sanitize_html(first_fragment.arxiv_id)}</div>')
|
||||||
|
|
||||||
|
# 添加作者
|
||||||
|
if first_fragment.authors:
|
||||||
|
paper_info.append(f'<div class="authors">{self._sanitize_html(first_fragment.authors)}</div>')
|
||||||
|
|
||||||
|
# 添加摘要
|
||||||
|
if first_fragment.abstract:
|
||||||
|
paper_info.append('<div class="abstract-container">')
|
||||||
|
paper_info.append('<div class="abstract-title">Abstract</div>')
|
||||||
|
paper_info.append(f'<div class="abstract-content">{self._sanitize_html(first_fragment.abstract)}</div>')
|
||||||
|
paper_info.append('</div>')
|
||||||
|
|
||||||
|
# 添加目录结构
|
||||||
|
if first_fragment.catalogs:
|
||||||
|
paper_info.append('<h2>Document Structure</h2>')
|
||||||
|
paper_info.append('<pre>')
|
||||||
|
paper_info.append(self._sanitize_html(first_fragment.catalogs))
|
||||||
|
paper_info.append('</pre>')
|
||||||
|
|
||||||
|
paper_info.append('</div>')
|
||||||
|
return '\n'.join(paper_info)
|
||||||
|
|
||||||
|
def format_table_of_contents(self, sections: Dict[str, List[SectionFragment]]) -> str:
|
||||||
|
"""生成目录"""
|
||||||
|
toc = ['<div class="toc">']
|
||||||
|
toc.append('<div class="toc-title">Table of Contents</div>')
|
||||||
|
toc.append('<nav>')
|
||||||
|
|
||||||
|
for section in sections:
|
||||||
|
section_id = self._create_section_id(section)
|
||||||
|
clean_section = section.strip() or "Uncategorized"
|
||||||
|
toc.append(f'<div><a href="#{section_id}">{self._sanitize_html(clean_section)} '
|
||||||
|
f'</a></div>')
|
||||||
|
|
||||||
|
toc.append('</nav>')
|
||||||
|
toc.append('</div>')
|
||||||
|
return '\n'.join(toc)
|
||||||
|
|
||||||
|
def format_sections(self) -> str:
|
||||||
|
"""格式化论文各部分内容"""
|
||||||
|
sections = {}
|
||||||
|
for fragment in self.fragments:
|
||||||
|
section = fragment.current_section or "Uncategorized"
|
||||||
|
if section not in sections:
|
||||||
|
sections[section] = []
|
||||||
|
sections[section].append(fragment)
|
||||||
|
|
||||||
|
formatted_html = ['<div class="content">']
|
||||||
|
formatted_html.append(self.format_table_of_contents(sections))
|
||||||
|
|
||||||
|
# 生成各部分内容
|
||||||
|
for section, fragments in sections.items():
|
||||||
|
section_id = self._create_section_id(section)
|
||||||
|
formatted_html.append(f'<h2 id="{section_id}">')
|
||||||
|
formatted_html.append(f'<span class="section-title">')
|
||||||
|
formatted_html.append(f'<span class="section-icon">§</span>')
|
||||||
|
formatted_html.append(f'{self._sanitize_html(section)}')
|
||||||
|
formatted_html.append('</span>')
|
||||||
|
formatted_html.append('</h2>')
|
||||||
|
|
||||||
|
formatted_html.append('<div class="section-content">')
|
||||||
|
|
||||||
|
for i, fragment in enumerate(fragments, 1):
|
||||||
|
formatted_html.append('<div class="fragment">')
|
||||||
|
|
||||||
|
# 添加内容
|
||||||
|
if fragment.content:
|
||||||
|
formatted_html.append(
|
||||||
|
f'<div class="fragment-content">{self._sanitize_html(fragment.content)}</div>'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加参考文献
|
||||||
|
if fragment.bibliography:
|
||||||
|
formatted_html.append('<div class="bibliography">')
|
||||||
|
formatted_html.append(f'{self._sanitize_html(fragment.bibliography)}')
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
return '\n'.join(formatted_html)
|
||||||
|
|
||||||
|
def save_html(self) -> Path:
|
||||||
|
"""保存HTML文档"""
|
||||||
|
try:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"paper_content_{timestamp}.html"
|
||||||
|
file_path = self.output_dir / filename
|
||||||
|
|
||||||
|
html_content = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>{self._sanitize_html(self.fragments[0].title if self.fragments else 'Paper Content')}</title>
|
||||||
|
<style>
|
||||||
|
{self.css_styles}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
{self.format_paper_info()}
|
||||||
|
{self.format_sections()}
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(html_content)
|
||||||
|
|
||||||
|
print(f"HTML document saved to: {file_path}")
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving HTML document: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 使用示例:
|
||||||
|
# formatter = PaperHtmlFormatter(fragments, output_dir)
|
||||||
|
# output_path = formatter.save_html()
|
||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
|
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder, gen_time_str
|
||||||
from toolbox import get_conf, promote_file_to_downloadzone
|
from toolbox import get_conf, promote_file_to_downloadzone
|
||||||
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
|
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
|
||||||
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
||||||
@@ -468,3 +468,70 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
|
|||||||
except:
|
except:
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
logger.error('writing html result failed:', trimmed_format_exc())
|
logger.error('writing html result failed:', trimmed_format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
def upload_to_gptac_cloud_if_user_allow(chatbot, arxiv_id):
|
||||||
|
try:
|
||||||
|
# 如果用户允许,我们将arxiv论文PDF上传到GPTAC学术云
|
||||||
|
from toolbox import map_file_to_sha256
|
||||||
|
# 检查是否顺利,如果没有生成预期的文件,则跳过
|
||||||
|
is_result_good = False
|
||||||
|
for file_path in chatbot._cookies.get("files_to_promote", []):
|
||||||
|
if file_path.endswith('translate_zh.pdf'):
|
||||||
|
is_result_good = True
|
||||||
|
if not is_result_good:
|
||||||
|
return
|
||||||
|
# 上传文件
|
||||||
|
for file_path in chatbot._cookies.get("files_to_promote", []):
|
||||||
|
align_name = None
|
||||||
|
# normalized name
|
||||||
|
for name in ['translate_zh.pdf', 'comparison.pdf']:
|
||||||
|
if file_path.endswith(name): align_name = name
|
||||||
|
# if match any align name
|
||||||
|
if align_name:
|
||||||
|
logger.info(f'Uploading to GPTAC cloud as the user has set `allow_cloud_io`: {file_path}')
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
import requests
|
||||||
|
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_upload'
|
||||||
|
files = {'file': (align_name, f, 'application/octet-stream')}
|
||||||
|
data = {
|
||||||
|
'arxiv_id': arxiv_id,
|
||||||
|
'file_hash': map_file_to_sha256(file_path),
|
||||||
|
'language': 'zh',
|
||||||
|
'trans_prompt': 'to_be_implemented',
|
||||||
|
'llm_model': 'to_be_implemented',
|
||||||
|
'llm_model_param': 'to_be_implemented',
|
||||||
|
}
|
||||||
|
resp = requests.post(url=url, files=files, data=data, timeout=30)
|
||||||
|
logger.info(f'Uploading terminate ({resp.status_code})`: {file_path}')
|
||||||
|
except:
|
||||||
|
# 如果上传失败,不会中断程序,因为这是次要功能
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_gptac_cloud(arxiv_id, chatbot):
|
||||||
|
import requests
|
||||||
|
success = False
|
||||||
|
downloaded = []
|
||||||
|
try:
|
||||||
|
for pdf_target in ['translate_zh.pdf', 'comparison.pdf']:
|
||||||
|
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_exist'
|
||||||
|
data = {
|
||||||
|
'arxiv_id': arxiv_id,
|
||||||
|
'name': pdf_target,
|
||||||
|
}
|
||||||
|
resp = requests.post(url=url, data=data)
|
||||||
|
cache_hit_result = resp.text.strip('"')
|
||||||
|
if cache_hit_result.startswith("http"):
|
||||||
|
url = cache_hit_result
|
||||||
|
logger.info(f'Downloading from GPTAC cloud: {url}')
|
||||||
|
resp = requests.get(url=url, timeout=30)
|
||||||
|
target = os.path.join(get_log_folder(plugin_name='gptac_cloud'), gen_time_str(), pdf_target)
|
||||||
|
os.makedirs(os.path.dirname(target), exist_ok=True)
|
||||||
|
with open(target, 'wb') as f:
|
||||||
|
f.write(resp.content)
|
||||||
|
new_path = promote_file_to_downloadzone(target, chatbot=chatbot)
|
||||||
|
success = True
|
||||||
|
downloaded.append(new_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return success, downloaded
|
||||||
|
|||||||
@@ -6,12 +6,16 @@ class SafeUnpickler(pickle.Unpickler):
|
|||||||
def get_safe_classes(self):
|
def get_safe_classes(self):
|
||||||
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
||||||
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
|
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
|
||||||
|
from numpy.core.multiarray import scalar
|
||||||
|
from numpy import dtype
|
||||||
# 定义允许的安全类
|
# 定义允许的安全类
|
||||||
safe_classes = {
|
safe_classes = {
|
||||||
# 在这里添加其他安全的类
|
# 在这里添加其他安全的类
|
||||||
'LatexPaperFileGroup': LatexPaperFileGroup,
|
'LatexPaperFileGroup': LatexPaperFileGroup,
|
||||||
'LatexPaperSplit': LatexPaperSplit,
|
'LatexPaperSplit': LatexPaperSplit,
|
||||||
'LinkedListNode': LinkedListNode,
|
'LinkedListNode': LinkedListNode,
|
||||||
|
'scalar': scalar,
|
||||||
|
'dtype': dtype,
|
||||||
}
|
}
|
||||||
return safe_classes
|
return safe_classes
|
||||||
|
|
||||||
@@ -22,8 +26,6 @@ class SafeUnpickler(pickle.Unpickler):
|
|||||||
for class_name in self.safe_classes.keys():
|
for class_name in self.safe_classes.keys():
|
||||||
if (class_name in f'{module}.{name}'):
|
if (class_name in f'{module}.{name}'):
|
||||||
match_class_name = class_name
|
match_class_name = class_name
|
||||||
if module == 'numpy' or module.startswith('numpy.'):
|
|
||||||
return super().find_class(module, name)
|
|
||||||
if match_class_name is not None:
|
if match_class_name is not None:
|
||||||
return self.safe_classes[match_class_name]
|
return self.safe_classes[match_class_name]
|
||||||
# 如果尝试加载未授权的类,则抛出异常
|
# 如果尝试加载未授权的类,则抛出异常
|
||||||
|
|||||||
@@ -712,134 +712,137 @@ def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
|
|||||||
# 内部链接:跳转到文档中的某个页面
|
# 内部链接:跳转到文档中的某个页面
|
||||||
dest = action.get("/D") # 目标页或目标位置
|
dest = action.get("/D") # 目标页或目标位置
|
||||||
# if dest and annot.idnum in page2_annot_id:
|
# if dest and annot.idnum in page2_annot_id:
|
||||||
if dest in pdf2_reader.named_destinations:
|
# if dest in pdf2_reader.named_destinations:
|
||||||
# 获取原始文件中跳转信息,包括跳转页面
|
if dest and page2.annotations:
|
||||||
destination = pdf2_reader.named_destinations[
|
if annot in page2.annotations:
|
||||||
dest
|
# 获取原始文件中跳转信息,包括跳转页面
|
||||||
]
|
destination = pdf2_reader.named_destinations[
|
||||||
page_number = (
|
dest
|
||||||
pdf2_reader.get_destination_page_number(
|
|
||||||
destination
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
|
||||||
# “/D”:[10,'/XYZ',100,100,0]
|
|
||||||
if destination.dest_array[1] == "/XYZ":
|
|
||||||
annot_obj["/A"].update(
|
|
||||||
{
|
|
||||||
NameObject("/D"): ArrayObject(
|
|
||||||
[
|
|
||||||
NumberObject(page_number),
|
|
||||||
destination.dest_array[1],
|
|
||||||
FloatObject(
|
|
||||||
destination.dest_array[
|
|
||||||
2
|
|
||||||
]
|
|
||||||
+ int(
|
|
||||||
page1.mediaBox.getWidth()
|
|
||||||
)
|
|
||||||
),
|
|
||||||
destination.dest_array[3],
|
|
||||||
destination.dest_array[4],
|
|
||||||
]
|
|
||||||
) # 确保键和值是 PdfObject
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
annot_obj["/A"].update(
|
|
||||||
{
|
|
||||||
NameObject("/D"): ArrayObject(
|
|
||||||
[
|
|
||||||
NumberObject(page_number),
|
|
||||||
destination.dest_array[1],
|
|
||||||
]
|
|
||||||
) # 确保键和值是 PdfObject
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
rect = annot_obj.get("/Rect")
|
|
||||||
# 更新点击坐标
|
|
||||||
rect = ArrayObject(
|
|
||||||
[
|
|
||||||
FloatObject(
|
|
||||||
rect[0]
|
|
||||||
+ int(page1.mediaBox.getWidth())
|
|
||||||
),
|
|
||||||
rect[1],
|
|
||||||
FloatObject(
|
|
||||||
rect[2]
|
|
||||||
+ int(page1.mediaBox.getWidth())
|
|
||||||
),
|
|
||||||
rect[3],
|
|
||||||
]
|
]
|
||||||
)
|
page_number = (
|
||||||
annot_obj.update(
|
pdf2_reader.get_destination_page_number(
|
||||||
{
|
destination
|
||||||
NameObject(
|
)
|
||||||
"/Rect"
|
)
|
||||||
): rect # 确保键和值是 PdfObject
|
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||||
}
|
# “/D”:[10,'/XYZ',100,100,0]
|
||||||
)
|
if destination.dest_array[1] == "/XYZ":
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
FloatObject(
|
||||||
|
destination.dest_array[
|
||||||
|
2
|
||||||
|
]
|
||||||
|
+ int(
|
||||||
|
page1.mediaBox.getWidth()
|
||||||
|
)
|
||||||
|
),
|
||||||
|
destination.dest_array[3],
|
||||||
|
destination.dest_array[4],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rect = annot_obj.get("/Rect")
|
||||||
|
# 更新点击坐标
|
||||||
|
rect = ArrayObject(
|
||||||
|
[
|
||||||
|
FloatObject(
|
||||||
|
rect[0]
|
||||||
|
+ int(page1.mediaBox.getWidth())
|
||||||
|
),
|
||||||
|
rect[1],
|
||||||
|
FloatObject(
|
||||||
|
rect[2]
|
||||||
|
+ int(page1.mediaBox.getWidth())
|
||||||
|
),
|
||||||
|
rect[3],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
annot_obj.update(
|
||||||
|
{
|
||||||
|
NameObject(
|
||||||
|
"/Rect"
|
||||||
|
): rect # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
# if dest and annot.idnum in page1_annot_id:
|
# if dest and annot.idnum in page1_annot_id:
|
||||||
if dest in pdf1_reader.named_destinations:
|
# if dest in pdf1_reader.named_destinations:
|
||||||
|
if dest and page1.annotations:
|
||||||
# 获取原始文件中跳转信息,包括跳转页面
|
if annot in page1.annotations:
|
||||||
destination = pdf1_reader.named_destinations[
|
# 获取原始文件中跳转信息,包括跳转页面
|
||||||
dest
|
destination = pdf1_reader.named_destinations[
|
||||||
]
|
dest
|
||||||
page_number = (
|
|
||||||
pdf1_reader.get_destination_page_number(
|
|
||||||
destination
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
|
||||||
# “/D”:[10,'/XYZ',100,100,0]
|
|
||||||
if destination.dest_array[1] == "/XYZ":
|
|
||||||
annot_obj["/A"].update(
|
|
||||||
{
|
|
||||||
NameObject("/D"): ArrayObject(
|
|
||||||
[
|
|
||||||
NumberObject(page_number),
|
|
||||||
destination.dest_array[1],
|
|
||||||
FloatObject(
|
|
||||||
destination.dest_array[
|
|
||||||
2
|
|
||||||
]
|
|
||||||
),
|
|
||||||
destination.dest_array[3],
|
|
||||||
destination.dest_array[4],
|
|
||||||
]
|
|
||||||
) # 确保键和值是 PdfObject
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
annot_obj["/A"].update(
|
|
||||||
{
|
|
||||||
NameObject("/D"): ArrayObject(
|
|
||||||
[
|
|
||||||
NumberObject(page_number),
|
|
||||||
destination.dest_array[1],
|
|
||||||
]
|
|
||||||
) # 确保键和值是 PdfObject
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
rect = annot_obj.get("/Rect")
|
|
||||||
rect = ArrayObject(
|
|
||||||
[
|
|
||||||
FloatObject(rect[0]),
|
|
||||||
rect[1],
|
|
||||||
FloatObject(rect[2]),
|
|
||||||
rect[3],
|
|
||||||
]
|
]
|
||||||
)
|
page_number = (
|
||||||
annot_obj.update(
|
pdf1_reader.get_destination_page_number(
|
||||||
{
|
destination
|
||||||
NameObject(
|
)
|
||||||
"/Rect"
|
)
|
||||||
): rect # 确保键和值是 PdfObject
|
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||||
}
|
# “/D”:[10,'/XYZ',100,100,0]
|
||||||
)
|
if destination.dest_array[1] == "/XYZ":
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
FloatObject(
|
||||||
|
destination.dest_array[
|
||||||
|
2
|
||||||
|
]
|
||||||
|
),
|
||||||
|
destination.dest_array[3],
|
||||||
|
destination.dest_array[4],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rect = annot_obj.get("/Rect")
|
||||||
|
rect = ArrayObject(
|
||||||
|
[
|
||||||
|
FloatObject(rect[0]),
|
||||||
|
rect[1],
|
||||||
|
FloatObject(rect[2]),
|
||||||
|
rect[3],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
annot_obj.update(
|
||||||
|
{
|
||||||
|
NameObject(
|
||||||
|
"/Rect"
|
||||||
|
): rect # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
elif "/S" in action and action["/S"] == "/URI":
|
elif "/S" in action and action["/S"] == "/URI":
|
||||||
# 外部链接:跳转到某个URI
|
# 外部链接:跳转到某个URI
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from toolbox import promote_file_to_downloadzone, extract_archive
|
|||||||
from toolbox import generate_file_link, zip_folder
|
from toolbox import generate_file_link, zip_folder
|
||||||
from crazy_functions.crazy_utils import get_files_from_everything
|
from crazy_functions.crazy_utils import get_files_from_everything
|
||||||
from shared_utils.colorful import *
|
from shared_utils.colorful import *
|
||||||
|
from loguru import logger
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
def refresh_key(doc2x_api_key):
|
def refresh_key(doc2x_api_key):
|
||||||
import requests, json
|
import requests, json
|
||||||
@@ -22,105 +24,140 @@ def refresh_key(doc2x_api_key):
|
|||||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||||
return doc2x_api_key
|
return doc2x_api_key
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def 解析PDF_DOC2X_转Latex(pdf_file_path):
|
def 解析PDF_DOC2X_转Latex(pdf_file_path):
|
||||||
|
zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format='tex')
|
||||||
|
return unzipped_folder
|
||||||
|
|
||||||
|
|
||||||
|
def 解析PDF_DOC2X(pdf_file_path, format='tex'):
|
||||||
|
"""
|
||||||
|
format: 'tex', 'md', 'docx'
|
||||||
|
"""
|
||||||
import requests, json, os
|
import requests, json, os
|
||||||
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
|
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
|
||||||
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
|
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
|
||||||
|
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
|
||||||
doc2x_api_key = DOC2X_API_KEY
|
doc2x_api_key = DOC2X_API_KEY
|
||||||
if doc2x_api_key.startswith('sk-'):
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
|
|
||||||
else:
|
|
||||||
doc2x_api_key = refresh_key(doc2x_api_key)
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
|
|
||||||
|
|
||||||
|
|
||||||
|
# < ------ 第1步:上传 ------ >
|
||||||
|
logger.info("Doc2x 第1步:上传")
|
||||||
|
with open(pdf_file_path, 'rb') as file:
|
||||||
|
res = requests.post(
|
||||||
|
"https://v2.doc2x.noedgeai.com/api/v2/parse/pdf",
|
||||||
|
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||||
|
data=file
|
||||||
|
)
|
||||||
|
# res_json = []
|
||||||
|
if res.status_code == 200:
|
||||||
|
res_json = res.json()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||||
|
uuid = res_json['data']['uid']
|
||||||
|
|
||||||
|
# < ------ 第2步:轮询等待 ------ >
|
||||||
|
logger.info("Doc2x 第2步:轮询等待")
|
||||||
|
params = {'uid': uuid}
|
||||||
|
while True:
|
||||||
|
res = requests.get(
|
||||||
|
'https://v2.doc2x.noedgeai.com/api/v2/parse/status',
|
||||||
|
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
res_json = res.json()
|
||||||
|
if res_json['data']['status'] == "success":
|
||||||
|
break
|
||||||
|
elif res_json['data']['status'] == "processing":
|
||||||
|
time.sleep(3)
|
||||||
|
logger.info(f"Doc2x is processing at {res_json['data']['progress']}%")
|
||||||
|
elif res_json['data']['status'] == "failed":
|
||||||
|
raise RuntimeError(f"Doc2x return an error: {res_json}")
|
||||||
|
|
||||||
|
|
||||||
|
# < ------ 第3步:提交转化 ------ >
|
||||||
|
logger.info("Doc2x 第3步:提交转化")
|
||||||
|
data = {
|
||||||
|
"uid": uuid,
|
||||||
|
"to": format,
|
||||||
|
"formula_mode": "dollar",
|
||||||
|
"filename": "output"
|
||||||
|
}
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
url,
|
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse',
|
||||||
files={"file": open(pdf_file_path, "rb")},
|
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||||
data={"ocr": "1"},
|
json=data
|
||||||
headers={"Authorization": "Bearer " + doc2x_api_key}
|
|
||||||
)
|
)
|
||||||
res_json = []
|
|
||||||
if res.status_code == 200:
|
if res.status_code == 200:
|
||||||
decoded = res.content.decode("utf-8")
|
res_json = res.json()
|
||||||
for z_decoded in decoded.split('\n'):
|
|
||||||
if len(z_decoded) == 0: continue
|
|
||||||
assert z_decoded.startswith("data: ")
|
|
||||||
z_decoded = z_decoded[len("data: "):]
|
|
||||||
decoded_json = json.loads(z_decoded)
|
|
||||||
res_json.append(decoded_json)
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||||
|
|
||||||
uuid = res_json[0]['uuid']
|
|
||||||
to = "latex" # latex, md, docx
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
|
|
||||||
|
|
||||||
res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
|
# < ------ 第4步:等待结果 ------ >
|
||||||
latex_zip_path = os.path.join(latex_dir, gen_time_str() + '.zip')
|
logger.info("Doc2x 第4步:等待结果")
|
||||||
latex_unzip_path = os.path.join(latex_dir, gen_time_str())
|
params = {'uid': uuid}
|
||||||
if res.status_code == 200:
|
while True:
|
||||||
with open(latex_zip_path, "wb") as f: f.write(res.content)
|
res = requests.get(
|
||||||
else:
|
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result',
|
||||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
res_json = res.json()
|
||||||
|
if res_json['data']['status'] == "success":
|
||||||
|
break
|
||||||
|
elif res_json['data']['status'] == "processing":
|
||||||
|
time.sleep(3)
|
||||||
|
logger.info(f"Doc2x still processing")
|
||||||
|
elif res_json['data']['status'] == "failed":
|
||||||
|
raise RuntimeError(f"Doc2x return an error: {res_json}")
|
||||||
|
|
||||||
|
|
||||||
|
# < ------ 第5步:最后的处理 ------ >
|
||||||
|
logger.info("Doc2x 第5步:最后的处理")
|
||||||
|
|
||||||
|
if format=='tex':
|
||||||
|
target_path = latex_dir
|
||||||
|
if format=='md':
|
||||||
|
target_path = markdown_dir
|
||||||
|
os.makedirs(target_path, exist_ok=True)
|
||||||
|
|
||||||
|
max_attempt = 3
|
||||||
|
# < ------ 下载 ------ >
|
||||||
|
for attempt in range(max_attempt):
|
||||||
|
try:
|
||||||
|
result_url = res_json['data']['url']
|
||||||
|
res = requests.get(result_url)
|
||||||
|
zip_path = os.path.join(target_path, gen_time_str() + '.zip')
|
||||||
|
unzip_path = os.path.join(target_path, gen_time_str())
|
||||||
|
if res.status_code == 200:
|
||||||
|
with open(zip_path, "wb") as f: f.write(res.content)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < max_attempt - 1:
|
||||||
|
logger.error(f"Failed to download latex file, retrying... {e}")
|
||||||
|
time.sleep(3)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# < ------ 解压 ------ >
|
||||||
import zipfile
|
import zipfile
|
||||||
with zipfile.ZipFile(latex_zip_path, 'r') as zip_ref:
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||||
zip_ref.extractall(latex_unzip_path)
|
zip_ref.extractall(unzip_path)
|
||||||
|
return zip_path, unzip_path
|
||||||
|
|
||||||
return latex_unzip_path
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
|
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
|
||||||
|
|
||||||
|
|
||||||
def pdf2markdown(filepath):
|
def pdf2markdown(filepath):
|
||||||
import requests, json, os
|
chatbot.append((None, f"Doc2x 解析中"))
|
||||||
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
|
|
||||||
doc2x_api_key = DOC2X_API_KEY
|
|
||||||
if doc2x_api_key.startswith('sk-'):
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
|
|
||||||
else:
|
|
||||||
doc2x_api_key = refresh_key(doc2x_api_key)
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
|
|
||||||
|
|
||||||
chatbot.append((None, "加载PDF文件,发送至DOC2X解析..."))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
res = requests.post(
|
md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format='md')
|
||||||
url,
|
|
||||||
files={"file": open(filepath, "rb")},
|
|
||||||
data={"ocr": "1"},
|
|
||||||
headers={"Authorization": "Bearer " + doc2x_api_key}
|
|
||||||
)
|
|
||||||
res_json = []
|
|
||||||
if res.status_code == 200:
|
|
||||||
decoded = res.content.decode("utf-8")
|
|
||||||
for z_decoded in decoded.split('\n'):
|
|
||||||
if len(z_decoded) == 0: continue
|
|
||||||
assert z_decoded.startswith("data: ")
|
|
||||||
z_decoded = z_decoded[len("data: "):]
|
|
||||||
decoded_json = json.loads(z_decoded)
|
|
||||||
res_json.append(decoded_json)
|
|
||||||
if 'limit exceeded' in decoded_json.get('status', ''):
|
|
||||||
raise RuntimeError("Doc2x API 页数受限,请联系 Doc2x 方面,并更换新的 API 秘钥。")
|
|
||||||
else:
|
|
||||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
|
||||||
uuid = res_json[0]['uuid']
|
|
||||||
to = "md" # latex, md, docx
|
|
||||||
url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
|
|
||||||
|
|
||||||
chatbot.append((None, f"读取解析: {url} ..."))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
|
|
||||||
md_zip_path = os.path.join(markdown_dir, gen_time_str() + '.zip')
|
|
||||||
if res.status_code == 200:
|
|
||||||
with open(md_zip_path, "wb") as f: f.write(res.content)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
|
||||||
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
|
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
|
||||||
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
|
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|||||||
@@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivDownloader:
|
||||||
|
"""用于下载arXiv论文源码的下载器"""
|
||||||
|
|
||||||
|
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
|
||||||
|
"""
|
||||||
|
初始化下载器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root_dir: 保存下载文件的根目录
|
||||||
|
proxies: 代理服务器设置,例如 {"http": "http://proxy:port", "https": "https://proxy:port"}
|
||||||
|
"""
|
||||||
|
self.root_dir = Path(root_dir)
|
||||||
|
self.root_dir.mkdir(exist_ok=True)
|
||||||
|
self.proxies = proxies
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_and_extract(self, arxiv_id: str) -> str:
|
||||||
|
"""
|
||||||
|
下载并解压arxiv论文源码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id: arXiv论文ID,例如"2103.00020"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 解压后的文件目录路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 当下载失败时抛出
|
||||||
|
"""
|
||||||
|
paper_dir = self.root_dir / arxiv_id
|
||||||
|
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if paper_dir.exists() and any(paper_dir.iterdir()):
|
||||||
|
logging.info(f"Using cached version for {arxiv_id}")
|
||||||
|
return str(paper_dir)
|
||||||
|
|
||||||
|
paper_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
urls = [
|
||||||
|
f"https://arxiv.org/src/{arxiv_id}",
|
||||||
|
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||||
|
]
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
logging.info(f"Downloading from {url}")
|
||||||
|
response = requests.get(url, proxies=self.proxies)
|
||||||
|
if response.status_code == 200:
|
||||||
|
tar_path.write_bytes(response.content)
|
||||||
|
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||||
|
tar.extractall(path=paper_dir)
|
||||||
|
return str(paper_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Download failed for {url}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||||
|
|
||||||
|
def download_paper(self, arxiv_id: str) -> str:
|
||||||
|
"""
|
||||||
|
下载指定的arXiv论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id: arXiv论文ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 论文文件所在的目录路径
|
||||||
|
"""
|
||||||
|
return self._download_and_extract(arxiv_id)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""测试下载功能"""
|
||||||
|
# 配置代理(如果需要)
|
||||||
|
proxies = {
|
||||||
|
"http": "http://your-proxy:port",
|
||||||
|
"https": "https://your-proxy:port"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建下载器实例(如果不需要代理,可以不传入proxies参数)
|
||||||
|
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
|
||||||
|
|
||||||
|
# 测试下载一篇论文(这里使用一个示例ID)
|
||||||
|
try:
|
||||||
|
paper_id = "2103.00020" # 这是一个示例ID
|
||||||
|
paper_dir = downloader.download_paper(paper_id)
|
||||||
|
print(f"Successfully downloaded paper to: {paper_dir}")
|
||||||
|
|
||||||
|
# 检查下载的文件
|
||||||
|
paper_path = Path(paper_dir)
|
||||||
|
if paper_path.exists():
|
||||||
|
print("Downloaded files:")
|
||||||
|
for file in paper_path.rglob("*"):
|
||||||
|
if file.is_file():
|
||||||
|
print(f"- {file.relative_to(paper_path)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading paper: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,836 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import tarfile
|
||||||
|
import time
|
||||||
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Dict, Set
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
|
||||||
|
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
|
||||||
|
|
||||||
|
|
||||||
|
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path ) -> Path:
|
||||||
|
"""
|
||||||
|
Save all fragments to a single structured markdown file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragments: List of SectionFragment objects
|
||||||
|
output_dir: Output directory path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the generated markdown file
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
output_path = Path(output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate filename
|
||||||
|
filename = f"paper_latex_content_{timestamp}.md"
|
||||||
|
file_path = output_path/ filename
|
||||||
|
|
||||||
|
# Group fragments by section
|
||||||
|
sections = {}
|
||||||
|
for fragment in fragments:
|
||||||
|
section = fragment.current_section or "Uncategorized"
|
||||||
|
if section not in sections:
|
||||||
|
sections[section] = []
|
||||||
|
sections[section].append(fragment)
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
# Write document header
|
||||||
|
f.write("# Document Fragments Analysis\n\n")
|
||||||
|
f.write("## Overview\n")
|
||||||
|
f.write(f"- Total Fragments: {len(fragments)}\n")
|
||||||
|
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||||
|
|
||||||
|
# Add paper information if available
|
||||||
|
if fragments and (fragments[0].title or fragments[0].abstract):
|
||||||
|
f.write("\n## Paper Information\n")
|
||||||
|
if fragments[0].title:
|
||||||
|
f.write(f"### Title\n{fragments[0].title}\n")
|
||||||
|
if fragments[0].authors:
|
||||||
|
f.write(f"\n### Authors\n{fragments[0].authors}\n")
|
||||||
|
if fragments[0].abstract:
|
||||||
|
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
|
||||||
|
|
||||||
|
# Write section tree if available
|
||||||
|
if fragments and fragments[0].catalogs:
|
||||||
|
f.write("\n## Section Tree\n")
|
||||||
|
f.write("```\n") # 添加代码块开始标记
|
||||||
|
f.write(fragments[0].catalogs)
|
||||||
|
f.write("\n```") # 添加代码块结束标记
|
||||||
|
|
||||||
|
# Generate table of contents
|
||||||
|
f.write("\n## Table of Contents\n")
|
||||||
|
for section in sections:
|
||||||
|
clean_section = section.strip() or "Uncategorized"
|
||||||
|
fragment_count = len(sections[section])
|
||||||
|
f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) "
|
||||||
|
f"({fragment_count} fragments)\n")
|
||||||
|
|
||||||
|
# Write content sections
|
||||||
|
f.write("\n## Content\n")
|
||||||
|
for section, section_fragments in sections.items():
|
||||||
|
clean_section = section.strip() or "Uncategorized"
|
||||||
|
f.write(f"\n### {clean_section}\n")
|
||||||
|
|
||||||
|
# Write each fragment
|
||||||
|
for i, fragment in enumerate(section_fragments, 1):
|
||||||
|
f.write(f"\n#### Fragment {i}\n")
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
f.write("**Metadata:**\n")
|
||||||
|
metadata = [
|
||||||
|
f"- Section: {fragment.current_section}",
|
||||||
|
f"- Length: {len(fragment.content)} chars",
|
||||||
|
f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
|
||||||
|
]
|
||||||
|
f.write("\n".join(filter(None, metadata)) + "\n")
|
||||||
|
|
||||||
|
# Content
|
||||||
|
f.write("\n**Content:**\n")
|
||||||
|
f.write("\n")
|
||||||
|
f.write(fragment.content)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
# Bibliography if exists
|
||||||
|
if fragment.bibliography:
|
||||||
|
f.write("\n**Bibliography:**\n")
|
||||||
|
f.write("```bibtex\n")
|
||||||
|
f.write(fragment.bibliography)
|
||||||
|
f.write("\n```\n")
|
||||||
|
|
||||||
|
# Add separator
|
||||||
|
if i < len(section_fragments):
|
||||||
|
f.write("\n---\n")
|
||||||
|
|
||||||
|
# Add statistics
|
||||||
|
f.write("\n## Statistics\n")
|
||||||
|
|
||||||
|
# Length distribution
|
||||||
|
lengths = [len(f.content) for f in fragments]
|
||||||
|
f.write("\n### Length Distribution\n")
|
||||||
|
f.write(f"- Minimum: {min(lengths)} chars\n")
|
||||||
|
f.write(f"- Maximum: {max(lengths)} chars\n")
|
||||||
|
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
|
||||||
|
|
||||||
|
# Section distribution
|
||||||
|
f.write("\n### Section Distribution\n")
|
||||||
|
for section, section_fragments in sections.items():
|
||||||
|
percentage = (len(section_fragments) / len(fragments)) * 100
|
||||||
|
f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")
|
||||||
|
|
||||||
|
print(f"Fragments saved to: {file_path}")
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
# 定义各种引用命令的模式
|
||||||
|
CITATION_PATTERNS = [
|
||||||
|
# 基本的 \cite{} 格式
|
||||||
|
r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# natbib 格式
|
||||||
|
r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# biblatex 格式
|
||||||
|
r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# 自定义 [cite:...] 格式
|
||||||
|
r'\[cite:([^\]]+)\]',
|
||||||
|
]
|
||||||
|
|
||||||
|
# 编译所有模式
|
||||||
|
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivSplitter:
|
||||||
|
"""Arxiv论文智能分割器"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
root_dir: str = "gpt_log/arxiv_cache",
|
||||||
|
proxies: Optional[Dict[str, str]] = None,
|
||||||
|
cache_ttl: int = 7 * 24 * 60 * 60):
|
||||||
|
"""
|
||||||
|
初始化分割器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_range: 字符数范围(最小值, 最大值)
|
||||||
|
root_dir: 缓存根目录
|
||||||
|
proxies: 代理设置
|
||||||
|
cache_ttl: 缓存过期时间(秒)
|
||||||
|
"""
|
||||||
|
self.root_dir = Path(root_dir)
|
||||||
|
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.proxies = proxies or {}
|
||||||
|
self.cache_ttl = cache_ttl
|
||||||
|
|
||||||
|
# 动态计算最优线程数
|
||||||
|
import multiprocessing
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
# 根据CPU核心数动态设置,但设置上限防止过度并发
|
||||||
|
self.document_structure = DocumentStructure()
|
||||||
|
self.document_parser = EssayStructureParser()
|
||||||
|
|
||||||
|
self.max_workers = min(32, cpu_count * 2)
|
||||||
|
|
||||||
|
# 初始化TeX处理器
|
||||||
|
self.tex_processor = TexUtils()
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
self._setup_logging()
|
||||||
|
|
||||||
|
def _setup_logging(self):
|
||||||
|
"""配置日志"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||||
|
"""规范化ArXiv ID"""
|
||||||
|
if 'arxiv.org/' in input_str.lower():
|
||||||
|
# 处理URL格式
|
||||||
|
if '/pdf/' in input_str:
|
||||||
|
arxiv_id = input_str.split('/pdf/')[-1]
|
||||||
|
else:
|
||||||
|
arxiv_id = input_str.split('/abs/')[-1]
|
||||||
|
# 移除版本号和其他后缀
|
||||||
|
return arxiv_id.split('v')[0].strip()
|
||||||
|
return input_str.split('v')[0].strip()
|
||||||
|
|
||||||
|
def _check_cache(self, paper_dir: Path) -> bool:
|
||||||
|
"""
|
||||||
|
检查缓存是否有效,包括文件完整性检查
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_dir: 论文目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果缓存有效返回True,否则返回False
|
||||||
|
"""
|
||||||
|
if not paper_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查目录中是否存在必要文件
|
||||||
|
has_tex_files = False
|
||||||
|
has_main_tex = False
|
||||||
|
|
||||||
|
for file_path in paper_dir.rglob("*"):
|
||||||
|
if file_path.suffix == '.tex':
|
||||||
|
has_tex_files = True
|
||||||
|
content = self.tex_processor.read_file(str(file_path))
|
||||||
|
if content and r'\documentclass' in content:
|
||||||
|
has_main_tex = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not (has_tex_files and has_main_tex):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查缓存时间
|
||||||
|
cache_time = paper_dir.stat().st_mtime
|
||||||
|
if (time.time() - cache_time) < self.cache_ttl:
|
||||||
|
self.logger.info(f"Using valid cache for {paper_dir.name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
|
||||||
|
"""
|
||||||
|
异步下载论文,包含重试机制和临时文件处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id: ArXiv论文ID
|
||||||
|
paper_dir: 目标目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 下载成功返回True,否则返回False
|
||||||
|
"""
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
|
||||||
|
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
|
||||||
|
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
paper_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 尝试使用 ArxivDownloader 下载
|
||||||
|
try:
|
||||||
|
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
|
||||||
|
downloaded_dir = downloader.download_paper(arxiv_id)
|
||||||
|
if downloaded_dir:
|
||||||
|
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
|
||||||
|
|
||||||
|
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
|
||||||
|
urls = [
|
||||||
|
f"https://arxiv.org/src/{arxiv_id}",
|
||||||
|
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
retry_delay = 1 # 初始重试延迟(秒)
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, proxy=self.proxies.get('http')) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# 写入临时文件
|
||||||
|
temp_tar_path.write_bytes(content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证tar文件完整性并解压
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
|
||||||
|
|
||||||
|
# 下载成功后移动临时文件到最终位置
|
||||||
|
temp_tar_path.rename(final_tar_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Invalid tar file: {str(e)}")
|
||||||
|
if temp_tar_path.exists():
|
||||||
|
temp_tar_path.unlink()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
continue
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _process_tar_file(self, tar_path: Path, extract_path: Path):
|
||||||
|
"""处理tar文件的同步操作"""
|
||||||
|
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||||
|
tar.testall() # 验证文件完整性
|
||||||
|
tar.extractall(path=extract_path) # 解压文件
|
||||||
|
|
||||||
|
def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
|
||||||
|
"""
|
||||||
|
Process citations in document structure and add referenced literature for each section
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_structure: DocumentStructure object
|
||||||
|
ref_bib: String containing references separated by newlines
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated DocumentStructure object
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create a copy to avoid modifying the original
|
||||||
|
doc = deepcopy(doc_structure)
|
||||||
|
|
||||||
|
# Parse references into a mapping
|
||||||
|
ref_map = self._parse_references(ref_bib)
|
||||||
|
if not ref_map:
|
||||||
|
self.logger.warning("No valid references found in ref_bib")
|
||||||
|
return doc
|
||||||
|
|
||||||
|
# Process all sections recursively
|
||||||
|
self._process_section_references(doc.toc, ref_map)
|
||||||
|
|
||||||
|
return doc
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing references: {str(e)}")
|
||||||
|
return doc_structure # Return original if processing fails
|
||||||
|
|
||||||
|
def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Recursively process sections to add references
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sections: List of Section objects
|
||||||
|
ref_map: Mapping of citation keys to full references
|
||||||
|
"""
|
||||||
|
for section in sections:
|
||||||
|
if section.content:
|
||||||
|
# Find citations in current section
|
||||||
|
cited_refs = self.find_citations(section.content)
|
||||||
|
|
||||||
|
if cited_refs:
|
||||||
|
# Get full references for citations
|
||||||
|
full_refs = []
|
||||||
|
for ref_key in cited_refs:
|
||||||
|
ref_text = ref_map.get(ref_key)
|
||||||
|
if ref_text:
|
||||||
|
full_refs.append(ref_text)
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Reference not found for citation key: {ref_key}")
|
||||||
|
|
||||||
|
# Add references to section content
|
||||||
|
if full_refs:
|
||||||
|
section.bibliography = "\n\n".join(full_refs)
|
||||||
|
|
||||||
|
# Process subsections recursively
|
||||||
|
if section.subsections:
|
||||||
|
self._process_section_references(section.subsections, ref_map)
|
||||||
|
|
||||||
|
def _parse_references(self, ref_bib: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parse reference string into a mapping of citation keys to full references
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ref_bib: Reference string with references separated by newlines
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping citation keys to full reference text
|
||||||
|
"""
|
||||||
|
ref_map = {}
|
||||||
|
current_ref = []
|
||||||
|
current_key = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for line in ref_bib.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# New reference entry
|
||||||
|
if line.startswith('@'):
|
||||||
|
# Save previous reference if exists
|
||||||
|
if current_key and current_ref:
|
||||||
|
ref_map[current_key] = '\n'.join(current_ref)
|
||||||
|
current_ref = []
|
||||||
|
|
||||||
|
# Extract key from new reference
|
||||||
|
key_match = re.search(r'{(.*?),', line)
|
||||||
|
if key_match:
|
||||||
|
current_key = key_match.group(1)
|
||||||
|
current_ref.append(line)
|
||||||
|
else:
|
||||||
|
if current_ref is not None:
|
||||||
|
current_ref.append(line)
|
||||||
|
|
||||||
|
# Save last reference
|
||||||
|
if current_key and current_ref:
|
||||||
|
ref_map[current_key] = '\n'.join(current_ref)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error parsing references: {str(e)}")
|
||||||
|
|
||||||
|
return ref_map
|
||||||
|
|
||||||
|
# 编译一次正则表达式以提高效率
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clean_citation_key(key: str) -> str:
|
||||||
|
"""Clean individual citation key."""
|
||||||
|
return key.strip().strip(',').strip()
|
||||||
|
|
||||||
|
def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
|
||||||
|
"""Extract and clean individual citation keys from a group."""
|
||||||
|
try:
|
||||||
|
# 分割多个引用键(支持逗号和分号分隔)
|
||||||
|
separators = '[,;]'
|
||||||
|
keys = re.split(separators, keys_str)
|
||||||
|
# 清理并过滤空键
|
||||||
|
return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)}
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
def find_citations(self, content: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Find citation keys in text content in various formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text content to search for citations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of unique citation keys
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Supported formats include:
|
||||||
|
- \cite{key1,key2}
|
||||||
|
- \cite[p. 1]{key}
|
||||||
|
- \citep{key}
|
||||||
|
- \citet{key}
|
||||||
|
- [cite:key1, key2]
|
||||||
|
- And many other variants
|
||||||
|
"""
|
||||||
|
citations = set()
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return citations
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 对每个编译好的模式进行搜索
|
||||||
|
for pattern in COMPILED_PATTERNS:
|
||||||
|
matches = pattern.finditer(content)
|
||||||
|
for match in matches:
|
||||||
|
# 获取捕获组中的引用键
|
||||||
|
keys_str = match.group(1)
|
||||||
|
if keys_str:
|
||||||
|
# 提取并添加所有引用键
|
||||||
|
new_keys = self._extract_keys_from_group(keys_str)
|
||||||
|
citations.update(new_keys)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error finding citations: {str(e)}")
|
||||||
|
|
||||||
|
# 移除明显无效的键
|
||||||
|
citations = {key for key in citations
|
||||||
|
if key and not key.startswith(('\\', '{', '}', '[', ']'))}
|
||||||
|
|
||||||
|
return citations
|
||||||
|
|
||||||
|
def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
|
||||||
|
"""
|
||||||
|
Find citations and their surrounding context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text content to search for citations
|
||||||
|
context_chars: Number of characters of context to include before/after
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping citation keys to lists of context strings
|
||||||
|
"""
|
||||||
|
contexts = {}
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
try:
|
||||||
|
for pattern in COMPILED_PATTERNS:
|
||||||
|
matches = pattern.finditer(content)
|
||||||
|
for match in matches:
|
||||||
|
# 获取匹配的位置
|
||||||
|
start = max(0, match.start() - context_chars)
|
||||||
|
end = min(len(content), match.end() + context_chars)
|
||||||
|
|
||||||
|
# 获取上下文
|
||||||
|
context = content[start:end]
|
||||||
|
|
||||||
|
# 获取并处理引用键
|
||||||
|
keys_str = match.group(1)
|
||||||
|
keys = self._extract_keys_from_group(keys_str)
|
||||||
|
|
||||||
|
# 为每个键添加上下文
|
||||||
|
for key in keys:
|
||||||
|
if key not in contexts:
|
||||||
|
contexts[key] = []
|
||||||
|
contexts[key].append(context)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error finding citation contexts: {str(e)}")
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
|
||||||
|
"""
|
||||||
|
Process ArXiv paper and convert to list of SectionFragments.
|
||||||
|
Each fragment represents the smallest section unit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id_or_url: ArXiv paper ID or URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[SectionFragment]: List of processed paper fragments
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||||
|
paper_dir = self.root_dir / arxiv_id
|
||||||
|
|
||||||
|
# Check if paper directory exists, if not, try to download
|
||||||
|
if not paper_dir.exists():
|
||||||
|
self.logger.info(f"Downloading paper {arxiv_id}")
|
||||||
|
await self.download_paper(arxiv_id, paper_dir)
|
||||||
|
|
||||||
|
# Find main TeX file
|
||||||
|
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
|
||||||
|
if not main_tex:
|
||||||
|
raise RuntimeError(f"No main TeX file found in {paper_dir}")
|
||||||
|
|
||||||
|
# 读取主 TeX 文件内容
|
||||||
|
main_tex_content = read_tex_file(main_tex)
|
||||||
|
|
||||||
|
# Get all related TeX files and references
|
||||||
|
tex_files = self.tex_processor.resolve_includes(main_tex)
|
||||||
|
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
|
||||||
|
|
||||||
|
if not tex_files:
|
||||||
|
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
|
||||||
|
|
||||||
|
# Reset document structure for new processing
|
||||||
|
self.document_structure = DocumentStructure()
|
||||||
|
|
||||||
|
# 提取作者信息
|
||||||
|
author_extractor = LatexAuthorExtractor()
|
||||||
|
authors = author_extractor.extract_authors(main_tex_content)
|
||||||
|
self.document_structure.authors = authors # 保存到文档结构中
|
||||||
|
|
||||||
|
# Process each TeX file
|
||||||
|
for file_path in tex_files:
|
||||||
|
self.logger.info(f"Processing TeX file: {file_path}")
|
||||||
|
tex_content = read_tex_file(file_path)
|
||||||
|
if tex_content:
|
||||||
|
additional_doc = self.document_parser.parse(tex_content)
|
||||||
|
self.document_structure = self.document_structure.merge(additional_doc)
|
||||||
|
|
||||||
|
# Process references if available
|
||||||
|
if ref_bib:
|
||||||
|
self.document_structure = self.process_references(self.document_structure, ref_bib)
|
||||||
|
self.logger.info("Successfully processed references")
|
||||||
|
else:
|
||||||
|
self.logger.info("No references found to process")
|
||||||
|
|
||||||
|
# Generate table of contents once
|
||||||
|
section_tree = self.document_structure.generate_toc_tree()
|
||||||
|
|
||||||
|
# Convert DocumentStructure to SectionFragments
|
||||||
|
fragments = self._convert_to_fragments(
|
||||||
|
doc_structure=self.document_structure,
|
||||||
|
arxiv_id=arxiv_id,
|
||||||
|
section_tree=section_tree
|
||||||
|
)
|
||||||
|
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _convert_to_fragments(self,
|
||||||
|
doc_structure: DocumentStructure,
|
||||||
|
arxiv_id: str,
|
||||||
|
section_tree: str) -> List[SectionFragment]:
|
||||||
|
"""
|
||||||
|
Convert DocumentStructure to list of SectionFragments.
|
||||||
|
Creates a fragment for each leaf section in the document hierarchy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_structure: Source DocumentStructure
|
||||||
|
arxiv_id: ArXiv paper ID
|
||||||
|
section_tree: Pre-generated table of contents tree
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[SectionFragment]: List of paper fragments
|
||||||
|
"""
|
||||||
|
fragments = []
|
||||||
|
|
||||||
|
# Create a base template for all fragments to avoid repetitive assignments
|
||||||
|
base_fragment_template = {
|
||||||
|
'title': doc_structure.title,
|
||||||
|
'authors': doc_structure.authors,
|
||||||
|
'abstract': doc_structure.abstract,
|
||||||
|
'catalogs': section_tree,
|
||||||
|
'arxiv_id': arxiv_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_leaf_sections(section: Section, path: List[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Recursively find all leaf sections and create fragments.
|
||||||
|
A leaf section is one that has content but no subsections, or has neither.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
section: Current section being processed
|
||||||
|
path: List of section titles forming the path to current section
|
||||||
|
"""
|
||||||
|
if path is None:
|
||||||
|
path = []
|
||||||
|
|
||||||
|
current_path = path + [section.title]
|
||||||
|
|
||||||
|
if not section.subsections:
|
||||||
|
# This is a leaf section, create a fragment if it has content
|
||||||
|
if section.content or section.bibliography:
|
||||||
|
fragment = SectionFragment(
|
||||||
|
**base_fragment_template,
|
||||||
|
current_section="/".join(current_path),
|
||||||
|
content=self._clean_content(section.content),
|
||||||
|
bibliography=section.bibliography
|
||||||
|
)
|
||||||
|
if self._validate_fragment(fragment):
|
||||||
|
fragments.append(fragment)
|
||||||
|
else:
|
||||||
|
# Process each subsection
|
||||||
|
for subsection in section.subsections:
|
||||||
|
get_leaf_sections(subsection, current_path)
|
||||||
|
|
||||||
|
# Process all top-level sections
|
||||||
|
for section in doc_structure.toc:
|
||||||
|
get_leaf_sections(section)
|
||||||
|
|
||||||
|
# Add a fragment for the abstract if it exists
|
||||||
|
if doc_structure.abstract:
|
||||||
|
abstract_fragment = SectionFragment(
|
||||||
|
**base_fragment_template,
|
||||||
|
current_section="Abstract",
|
||||||
|
content=self._clean_content(doc_structure.abstract)
|
||||||
|
)
|
||||||
|
if self._validate_fragment(abstract_fragment):
|
||||||
|
fragments.insert(0, abstract_fragment)
|
||||||
|
|
||||||
|
self.logger.info(f"Created {len(fragments)} fragments")
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
def _validate_fragment(self, fragment: SectionFragment) -> bool:
|
||||||
|
"""
|
||||||
|
Validate if the fragment has all required fields with meaningful content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragment: SectionFragment to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if fragment is valid, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return all([
|
||||||
|
fragment.title.strip(),
|
||||||
|
fragment.catalogs.strip(),
|
||||||
|
fragment.current_section.strip(),
|
||||||
|
fragment.content.strip() or fragment.bibliography.strip()
|
||||||
|
])
|
||||||
|
except AttributeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clean_content(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean and normalize content text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw content text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Cleaned content text
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove excessive whitespace
|
||||||
|
content = re.sub(r'\s+', ' ', content)
|
||||||
|
|
||||||
|
# Remove remaining LaTeX artifacts
|
||||||
|
content = re.sub(r'\\item\s*', '• ', content) # Convert \item to bullet points
|
||||||
|
content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands
|
||||||
|
|
||||||
|
# Clean special characters
|
||||||
|
content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines
|
||||||
|
content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines
|
||||||
|
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, List[Path]]:
|
||||||
|
"""
|
||||||
|
同步处理 ArXiv 文档并返回分割后的片段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
splitter: ArxivSplitter 实例
|
||||||
|
arxiv_id: ArXiv 文档ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 分割后的文档片段列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from crazy_functions.doc_fns.tex_html_formatter import PaperHtmlFormatter
|
||||||
|
# 创建一个异步函数来执行异步操作
|
||||||
|
async def _process():
|
||||||
|
return await splitter.process(arxiv_id)
|
||||||
|
|
||||||
|
# 使用 asyncio.run() 运行异步函数
|
||||||
|
output_files=[]
|
||||||
|
fragments = asyncio.run(_process())
|
||||||
|
file_save_path = splitter.root_dir / "arxiv_fragments"
|
||||||
|
# 保存片段到文件
|
||||||
|
try:
|
||||||
|
md_output_dir = save_fragments_to_file(
|
||||||
|
fragments,
|
||||||
|
output_dir = file_save_path
|
||||||
|
)
|
||||||
|
output_files.append(md_output_dir)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
# 创建论文格式化器
|
||||||
|
formatter = PaperContentFormatter()
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
# 创建格式化选项
|
||||||
|
|
||||||
|
metadata = PaperMetadata(
|
||||||
|
title=fragments[0].title,
|
||||||
|
authors=fragments[0].authors,
|
||||||
|
abstract=fragments[0].abstract,
|
||||||
|
catalogs=fragments[0].catalogs,
|
||||||
|
arxiv_id=fragments[0].arxiv_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化内容
|
||||||
|
formatted_content = formatter.format(fragments, metadata)
|
||||||
|
|
||||||
|
try:
|
||||||
|
html_formatter = PaperHtmlFormatter(fragments, file_save_path)
|
||||||
|
html_output_dir = html_formatter.save_html()
|
||||||
|
output_files.append(html_output_dir)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return fragments, formatted_content, output_files
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
|
||||||
|
raise
|
||||||
|
def test_arxiv_splitter():
|
||||||
|
"""测试ArXiv分割器的功能"""
|
||||||
|
|
||||||
|
# 测试配置
|
||||||
|
test_cases = [
|
||||||
|
{
|
||||||
|
"arxiv_id": "2411.03663",
|
||||||
|
"expected_title": "Large Language Models and Simple Scripts",
|
||||||
|
"min_fragments": 10,
|
||||||
|
},
|
||||||
|
# {
|
||||||
|
# "arxiv_id": "1805.10988",
|
||||||
|
# "expected_title": "RAG vs Fine-tuning",
|
||||||
|
# "min_fragments": 15,
|
||||||
|
# }
|
||||||
|
]
|
||||||
|
|
||||||
|
# 创建分割器实例
|
||||||
|
splitter = ArxivSplitter(
|
||||||
|
root_dir="private_upload/default_user"
|
||||||
|
)
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
print(f"\nTesting paper: {case['arxiv_id']}")
|
||||||
|
try:
|
||||||
|
# fragments = await splitter.process(case['arxiv_id'])
|
||||||
|
fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
|
||||||
|
# 保存fragments
|
||||||
|
for fragment in fragments:
|
||||||
|
# 长度检查
|
||||||
|
print((fragment.content))
|
||||||
|
print(len(fragment.content))
|
||||||
|
# 类型检查
|
||||||
|
print(output_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_arxiv_splitter()
|
||||||
@@ -0,0 +1,177 @@
|
|||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class LatexAuthorExtractor:
|
||||||
|
def __init__(self):
|
||||||
|
# Patterns for matching author blocks with balanced braces
|
||||||
|
self.author_block_patterns = [
|
||||||
|
# Standard LaTeX patterns with optional arguments
|
||||||
|
r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
# Conference and journal specific patterns
|
||||||
|
r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
# Academic publisher specific patterns
|
||||||
|
r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Cleaning patterns for LaTeX commands and formatting
|
||||||
|
self.cleaning_patterns = [
|
||||||
|
# Text formatting commands - preserve content
|
||||||
|
(r'\\textbf\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\textit\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\emph\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\texttt\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\textrm\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\text\{([^}]+)\}', r'\1'),
|
||||||
|
|
||||||
|
# Affiliation and footnote markers
|
||||||
|
(r'\$\^{[^}]+}\$', ''),
|
||||||
|
(r'\^{[^}]+}', ''),
|
||||||
|
(r'\\thanks\{[^}]+\}', ''),
|
||||||
|
(r'\\footnote\{[^}]+\}', ''),
|
||||||
|
|
||||||
|
# Email and contact formatting
|
||||||
|
(r'\\email\{([^}]+)\}', r'\1'),
|
||||||
|
(r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
|
||||||
|
|
||||||
|
# Institution formatting
|
||||||
|
(r'\\inst\{[^}]+\}', ''),
|
||||||
|
(r'\\affil\{[^}]+\}', ''),
|
||||||
|
|
||||||
|
# Special characters and symbols
|
||||||
|
(r'\\&', '&'),
|
||||||
|
(r'\\\\\s*', ' '),
|
||||||
|
(r'\\,', ' '),
|
||||||
|
(r'\\;', ' '),
|
||||||
|
(r'\\quad', ' '),
|
||||||
|
(r'\\qquad', ' '),
|
||||||
|
|
||||||
|
# Math mode content
|
||||||
|
(r'\$[^$]+\$', ''),
|
||||||
|
|
||||||
|
# Common symbols
|
||||||
|
(r'\\dagger', '†'),
|
||||||
|
(r'\\ddagger', '‡'),
|
||||||
|
(r'\\ast', '*'),
|
||||||
|
(r'\\star', '★'),
|
||||||
|
|
||||||
|
# Remove remaining LaTeX commands
|
||||||
|
(r'\\[a-zA-Z]+', ''),
|
||||||
|
|
||||||
|
# Clean up remaining special characters
|
||||||
|
(r'[\\{}]', '')
|
||||||
|
]
|
||||||
|
|
||||||
|
def extract_author_block(self, text: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract the complete author block from LaTeX text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Input LaTeX text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: Extracted author block or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for pattern in self.author_block_patterns:
|
||||||
|
match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
except (AttributeError, IndexError) as e:
|
||||||
|
print(f"Error extracting author block: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clean_tex_commands(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove LaTeX commands and formatting from text while preserving content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Text containing LaTeX commands
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Cleaned text with commands removed
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
cleaned_text = text
|
||||||
|
|
||||||
|
# Apply cleaning patterns
|
||||||
|
for pattern, replacement in self.cleaning_patterns:
|
||||||
|
cleaned_text = re.sub(pattern, replacement, cleaned_text)
|
||||||
|
|
||||||
|
# Clean up whitespace
|
||||||
|
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
|
||||||
|
cleaned_text = cleaned_text.strip()
|
||||||
|
|
||||||
|
return cleaned_text
|
||||||
|
|
||||||
|
def extract_authors(self, text: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract and clean author information from LaTeX text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): Input LaTeX text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: Cleaned author information or None if extraction fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract author block
|
||||||
|
author_block = self.extract_author_block(text)
|
||||||
|
if not author_block:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clean LaTeX commands
|
||||||
|
cleaned_authors = self.clean_tex_commands(author_block)
|
||||||
|
return cleaned_authors or None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing text: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_author_extractor():
|
||||||
|
"""Test the LatexAuthorExtractor with sample inputs."""
|
||||||
|
test_cases = [
|
||||||
|
# Basic test case
|
||||||
|
(r"\author{John Doe}", "John Doe"),
|
||||||
|
|
||||||
|
# Test with multiple authors
|
||||||
|
(r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
|
||||||
|
|
||||||
|
# Test with affiliations
|
||||||
|
(r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
extractor = LatexAuthorExtractor()
|
||||||
|
|
||||||
|
for i, (input_tex, expected) in enumerate(test_cases, 1):
|
||||||
|
result = extractor.extract_authors(input_tex)
|
||||||
|
print(f"\nTest case {i}:")
|
||||||
|
print(f"Input: {input_tex[:50]}...")
|
||||||
|
print(f"Expected: {expected[:50]}...")
|
||||||
|
print(f"Got: {result[:50]}...")
|
||||||
|
print(f"Pass: {bool(result and result.strip() == expected.strip())}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_author_extractor()
|
||||||
@@ -0,0 +1,290 @@
|
|||||||
|
"""
|
||||||
|
LaTeX Document Parser
|
||||||
|
|
||||||
|
This module provides functionality for parsing and extracting structured information from LaTeX documents,
|
||||||
|
including metadata, document structure, and content. It uses modular design and clean architecture principles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def read_tex_file(file_path):
|
||||||
|
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||||
|
for encoding in encodings:
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding=encoding) as f:
|
||||||
|
return f.read()
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DocumentStructure:
|
||||||
|
title: str = ''
|
||||||
|
authors: str = ''
|
||||||
|
abstract: str = ''
|
||||||
|
toc: List[Section] = field(default_factory=list)
|
||||||
|
metadata: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def merge(self, other: 'DocumentStructure', strategy: str = 'smart') -> 'DocumentStructure':
|
||||||
|
"""
|
||||||
|
Merge this document structure with another one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: Another DocumentStructure to merge with
|
||||||
|
strategy: Merge strategy - 'smart' (default) or 'append'
|
||||||
|
'smart' - Intelligently merge sections with same titles
|
||||||
|
'append' - Simply append sections from other document
|
||||||
|
"""
|
||||||
|
merged = deepcopy(self)
|
||||||
|
|
||||||
|
# Merge title if needed
|
||||||
|
if not merged.title and other.title:
|
||||||
|
merged.title = other.title
|
||||||
|
|
||||||
|
# Merge abstract
|
||||||
|
merged.abstract = self._merge_abstract(merged.abstract, other.abstract)
|
||||||
|
|
||||||
|
# Merge metadata
|
||||||
|
merged.metadata.update(other.metadata)
|
||||||
|
|
||||||
|
if strategy == 'append':
|
||||||
|
merged.toc.extend(deepcopy(other.toc))
|
||||||
|
else: # smart merge
|
||||||
|
# Create sections lookup for efficient merging
|
||||||
|
sections_map = {s.title: s for s in merged.toc}
|
||||||
|
|
||||||
|
for other_section in other.toc:
|
||||||
|
if other_section.title in sections_map:
|
||||||
|
# Merge existing section
|
||||||
|
idx = next(i for i, s in enumerate(merged.toc)
|
||||||
|
if s.title == other_section.title)
|
||||||
|
merged.toc[idx] = merged.toc[idx].merge(other_section)
|
||||||
|
else:
|
||||||
|
# Add new section
|
||||||
|
merged.toc.append(deepcopy(other_section))
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_abstract(abstract1: str, abstract2: str) -> str:
|
||||||
|
"""Merge abstracts intelligently."""
|
||||||
|
if not abstract1:
|
||||||
|
return abstract2
|
||||||
|
if not abstract2:
|
||||||
|
return abstract1
|
||||||
|
# Combine non-empty abstracts with a separator
|
||||||
|
return f"{abstract1}\n\n{abstract2}"
|
||||||
|
|
||||||
|
def generate_toc_tree(self, indent_char: str = " ", abstract_preview_length: int = 0) -> str:
|
||||||
|
"""
|
||||||
|
Generate a tree-like string representation of the table of contents including abstract.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indent_char: Character(s) used for indentation. Default is two spaces.
|
||||||
|
abstract_preview_length: Maximum length of abstract preview. Default is 200 characters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted string showing the hierarchical document structure with abstract
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _format_section(section: Section, level: int = 0) -> str:
|
||||||
|
# Create the current section line with proper indentation
|
||||||
|
current_line = f"{indent_char * level}{'•' if level > 0 else '○'} {section.title}\n"
|
||||||
|
|
||||||
|
# Recursively process subsections
|
||||||
|
subsections = ""
|
||||||
|
if section.subsections:
|
||||||
|
subsections = "".join(_format_section(subsec, level + 1)
|
||||||
|
for subsec in section.subsections)
|
||||||
|
|
||||||
|
return current_line + subsections
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Add document title if it exists
|
||||||
|
if self.title:
|
||||||
|
result.append(f"《{self.title}》\n")
|
||||||
|
|
||||||
|
# Add abstract if it exists
|
||||||
|
if self.abstract:
|
||||||
|
result.append("\n□ Abstract:")
|
||||||
|
# Format abstract content with word wrap
|
||||||
|
abstract_preview = self.abstract[:abstract_preview_length]
|
||||||
|
if len(self.abstract) > abstract_preview_length:
|
||||||
|
abstract_preview += "..."
|
||||||
|
|
||||||
|
# Split abstract into lines and indent them
|
||||||
|
wrapped_lines = []
|
||||||
|
current_line = ""
|
||||||
|
for word in abstract_preview.split():
|
||||||
|
if len(current_line) + len(word) + 1 <= 80: # 80 characters per line
|
||||||
|
current_line = (current_line + " " + word).strip()
|
||||||
|
else:
|
||||||
|
wrapped_lines.append(current_line)
|
||||||
|
current_line = word
|
||||||
|
if current_line:
|
||||||
|
wrapped_lines.append(current_line)
|
||||||
|
|
||||||
|
# Add formatted abstract lines
|
||||||
|
for line in wrapped_lines:
|
||||||
|
result.append(f"\n{indent_char}{line}")
|
||||||
|
result.append("\n") # Add extra newline after abstract
|
||||||
|
|
||||||
|
# Add table of contents header if there are sections
|
||||||
|
if self.toc:
|
||||||
|
result.append("\n◈ Table of Contents:\n")
|
||||||
|
|
||||||
|
# Add all top-level sections and their subsections
|
||||||
|
result.extend(_format_section(section, 0) for section in self.toc)
|
||||||
|
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseExtractor(ABC):
|
||||||
|
"""Base class for LaTeX content extractors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract(self, content: str) -> str:
|
||||||
|
"""Extract specific content from LaTeX document."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TitleExtractor(BaseExtractor):
|
||||||
|
"""Extracts title from LaTeX document."""
|
||||||
|
|
||||||
|
PATTERNS = [
|
||||||
|
r'\\title{(.+?)}',
|
||||||
|
r'\\title\[.*?\]{(.+?)}',
|
||||||
|
r'\\Title{(.+?)}',
|
||||||
|
r'\\TITLE{(.+?)}',
|
||||||
|
r'\\begin{document}\s*\\section[*]?{(.+?)}',
|
||||||
|
r'\\maketitle\s*\\section[*]?{(.+?)}',
|
||||||
|
r'\\chapter[*]?{(.+?)}'
|
||||||
|
]
|
||||||
|
|
||||||
|
def extract(self, content: str) -> str:
|
||||||
|
"""Extract title using defined patterns."""
|
||||||
|
for pattern in self.PATTERNS:
|
||||||
|
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
|
||||||
|
for match in matches:
|
||||||
|
title = match.group(1).strip()
|
||||||
|
if title:
|
||||||
|
return clean_latex_commands(title)
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractExtractor(BaseExtractor):
|
||||||
|
"""Extracts abstract from LaTeX document."""
|
||||||
|
|
||||||
|
PATTERNS = [
|
||||||
|
r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||||
|
r'\\abstract{(.*?)}',
|
||||||
|
r'\\ABSTRACT{(.*?)}',
|
||||||
|
r'\\Abstract{(.*?)}',
|
||||||
|
r'\\begin{Abstract}(.*?)\\end{Abstract}',
|
||||||
|
r'\\section[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\section|\Z)',
|
||||||
|
r'\\chapter[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\chapter|\Z)'
|
||||||
|
]
|
||||||
|
|
||||||
|
def extract(self, content: str) -> str:
|
||||||
|
"""Extract abstract using defined patterns."""
|
||||||
|
for pattern in self.PATTERNS:
|
||||||
|
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
|
||||||
|
for match in matches:
|
||||||
|
abstract = match.group(1).strip()
|
||||||
|
if abstract:
|
||||||
|
return clean_latex_commands(abstract)
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
class EssayStructureParser:
|
||||||
|
"""Main class for parsing LaTeX documents."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.title_extractor = TitleExtractor()
|
||||||
|
self.abstract_extractor = AbstractExtractor()
|
||||||
|
self.section_extractor = EnhancedSectionExtractor() # Using the enhanced extractor
|
||||||
|
|
||||||
|
def parse(self, content: str) -> DocumentStructure:
|
||||||
|
"""Parse LaTeX document and extract structured information."""
|
||||||
|
try:
|
||||||
|
content = self._preprocess_content(content)
|
||||||
|
|
||||||
|
return DocumentStructure(
|
||||||
|
title=self.title_extractor.extract(content),
|
||||||
|
abstract=self.abstract_extractor.extract(content),
|
||||||
|
toc=self.section_extractor.extract(content)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing LaTeX document: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _preprocess_content(self, content: str) -> str:
|
||||||
|
"""Preprocess LaTeX content for parsing."""
|
||||||
|
# Remove comments
|
||||||
|
content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
|
||||||
|
"""Print document structure in a readable format."""
|
||||||
|
print(f"Title: {doc.title}\n")
|
||||||
|
print(f"Abstract: {doc.abstract}\n")
|
||||||
|
print("Table of Contents:")
|
||||||
|
|
||||||
|
def print_section(section: Section, indent: int = 0):
|
||||||
|
print(" " * indent + f"- {section.title}")
|
||||||
|
if section.content:
|
||||||
|
preview = section.content[:max_content_length]
|
||||||
|
if len(section.content) > max_content_length:
|
||||||
|
preview += "..."
|
||||||
|
print(" " * (indent + 1) + f"Content: {preview}")
|
||||||
|
for subsection in section.subsections:
|
||||||
|
print_section(subsection, indent + 1)
|
||||||
|
|
||||||
|
for section in doc.toc:
|
||||||
|
print_section(section)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# Test with a file
|
||||||
|
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||||
|
main_tex = read_tex_file(file_path)
|
||||||
|
|
||||||
|
# Parse main file
|
||||||
|
parser = EssayStructureParser()
|
||||||
|
main_doc = parser.parse(main_tex)
|
||||||
|
|
||||||
|
# Merge other documents
|
||||||
|
file_path_list = [
|
||||||
|
"test_cache/2411.03663/1_intro.tex",
|
||||||
|
"test_cache/2411.03663/0_abstract.tex",
|
||||||
|
"test_cache/2411.03663/2_pre.tex",
|
||||||
|
"test_cache/2411.03663/3_method.tex",
|
||||||
|
"test_cache/2411.03663/4_experiment.tex",
|
||||||
|
"test_cache/2411.03663/5_related_work.tex",
|
||||||
|
"test_cache/2411.03663/6_conclu.tex",
|
||||||
|
"test_cache/2411.03663/reference.bib"
|
||||||
|
]
|
||||||
|
for file_path in file_path_list:
|
||||||
|
tex_content = read_tex_file(file_path)
|
||||||
|
additional_doc = parser.parse(tex_content)
|
||||||
|
main_doc = main_doc.merge(additional_doc)
|
||||||
|
|
||||||
|
tree = main_doc.generate_toc_tree()
|
||||||
|
pretty_print_structure(main_doc)
|
||||||
@@ -0,0 +1,329 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Set, Dict, Pattern, Optional, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class EnvType(Enum):
|
||||||
|
"""Environment classification types."""
|
||||||
|
PRESERVE = "preserve" # Preserve complete environment including commands
|
||||||
|
REMOVE = "remove" # Remove environment completely
|
||||||
|
EXTRACT = "extract" # Extract and clean content
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LatexConfig:
|
||||||
|
"""Configuration for LaTeX processing."""
|
||||||
|
preserve_envs: Set[str] = field(default_factory=lambda: {
|
||||||
|
# Math environments - preserve complete content
|
||||||
|
'equation', 'equation*', 'align', 'align*', 'displaymath',
|
||||||
|
'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*',
|
||||||
|
'multline', 'multline*', 'flalign', 'flalign*',
|
||||||
|
'alignat', 'alignat*', 'cases', 'split', 'aligned',
|
||||||
|
# Tables and figures - preserve structure and content
|
||||||
|
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
|
||||||
|
'figure', 'figure*', 'subfigure', 'wrapfigure',
|
||||||
|
'minipage', 'tabbing', 'verbatim', 'longtable',
|
||||||
|
'sidewaystable', 'sidewaysfigure', 'floatrow',
|
||||||
|
# Arrays and matrices
|
||||||
|
'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix',
|
||||||
|
'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*',
|
||||||
|
# Algorithms and code
|
||||||
|
'algorithm', 'algorithmic', 'lstlisting', 'verbatim',
|
||||||
|
'minted', 'listing', 'algorithmic*', 'algorithm2e',
|
||||||
|
# Theorems and proofs
|
||||||
|
'theorem', 'proof', 'definition', 'lemma', 'corollary',
|
||||||
|
'proposition', 'example', 'remark', 'note', 'claim',
|
||||||
|
'axiom', 'property', 'assumption', 'conjecture', 'observation',
|
||||||
|
# Bibliography
|
||||||
|
'thebibliography', 'bibliography', 'references'
|
||||||
|
})
|
||||||
|
|
||||||
|
# 引用类命令的特殊处理配置
|
||||||
|
citation_commands: Set[str] = field(default_factory=lambda: {
|
||||||
|
# Basic citations
|
||||||
|
'cite', 'citep', 'citet', 'citeyear', 'citeauthor',
|
||||||
|
'citeyearpar', 'citetext', 'citenum',
|
||||||
|
# Natbib citations
|
||||||
|
'citefullauthor', 'citealp', 'citealt', 'citename',
|
||||||
|
'citepalias', 'citetalias', 'citetext',
|
||||||
|
# Cross-references
|
||||||
|
'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref',
|
||||||
|
'Cref', 'vref', 'Vref', 'fref', 'pref',
|
||||||
|
# Hyperref
|
||||||
|
'hyperref', 'href', 'url',
|
||||||
|
# Labels
|
||||||
|
'label', 'tag'
|
||||||
|
})
|
||||||
|
|
||||||
|
preserve_commands: Set[str] = field(default_factory=lambda: {
|
||||||
|
# Text formatting
|
||||||
|
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
|
||||||
|
'section', 'subsection', 'subsubsection', 'paragraph', 'part',
|
||||||
|
'chapter', 'title', 'author', 'date', 'thanks',
|
||||||
|
# Math operators and symbols
|
||||||
|
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf',
|
||||||
|
'partial', 'nabla', 'implies', 'iff', 'therefore',
|
||||||
|
'exists', 'forall', 'in', 'subset', 'subseteq',
|
||||||
|
# Greek letters and math symbols
|
||||||
|
'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
|
||||||
|
'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu',
|
||||||
|
'nu', 'xi', 'pi', 'rho', 'sigma', 'tau',
|
||||||
|
'upsilon', 'phi', 'chi', 'psi', 'omega',
|
||||||
|
'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
|
||||||
|
'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
|
||||||
|
# Math commands
|
||||||
|
'left', 'right', 'big', 'Big', 'bigg', 'Bigg',
|
||||||
|
'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb',
|
||||||
|
'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop',
|
||||||
|
'operatorname', 'overline', 'underline', 'overbrace',
|
||||||
|
'underbrace', 'overset', 'underset', 'stackrel',
|
||||||
|
# Spacing and alignment
|
||||||
|
'quad', 'qquad', 'hspace', 'vspace', 'medskip',
|
||||||
|
'bigskip', 'smallskip', 'hfill', 'vfill', 'centering',
|
||||||
|
'raggedright', 'raggedleft'
|
||||||
|
})
|
||||||
|
|
||||||
|
remove_commands: Set[str] = field(default_factory=lambda: {
|
||||||
|
# Document setup
|
||||||
|
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
|
||||||
|
'bibliographystyle', 'frontmatter', 'mainmatter',
|
||||||
|
'newtheorem', 'theoremstyle', 'proofname',
|
||||||
|
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
|
||||||
|
'newenvironment',
|
||||||
|
# Layout and spacing
|
||||||
|
'pagestyle', 'thispagestyle', 'newpage', 'clearpage',
|
||||||
|
'pagebreak', 'linebreak', 'newline', 'setlength',
|
||||||
|
'setcounter', 'addtocounter', 'makeatletter',
|
||||||
|
'makeatother', 'pagenumbering'
|
||||||
|
})
|
||||||
|
|
||||||
|
latex_chars: Dict[str, str] = field(default_factory=lambda: {
|
||||||
|
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
|
||||||
|
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
|
||||||
|
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
|
||||||
|
'\\textasciitilde': '~', '\\textasciicircum': '^'
|
||||||
|
})
|
||||||
|
|
||||||
|
# 保留原始格式的特殊命令模式
|
||||||
|
special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [
|
||||||
|
(r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'),
|
||||||
|
(r'\\ref\*?{([^}]+)}', r'\\ref{\1}'),
|
||||||
|
(r'\\label{([^}]+)}', r'\\label{\1}'),
|
||||||
|
(r'\\eqref{([^}]+)}', r'\\eqref{\1}'),
|
||||||
|
(r'\\autoref{([^}]+)}', r'\\autoref{\1}'),
|
||||||
|
(r'\\url{([^}]+)}', r'\\url{\1}'),
|
||||||
|
(r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class LatexCleaner:
|
||||||
|
"""Enhanced LaTeX text cleaner that preserves mathematical content and citations."""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[LatexConfig] = None):
|
||||||
|
self.config = config or LatexConfig()
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
# 初始化正则表达式缓存
|
||||||
|
self._regex_cache = {}
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128)
|
||||||
|
def _get_env_pattern(self, env_name: str) -> Pattern:
|
||||||
|
"""Get cached regex pattern for environment matching."""
|
||||||
|
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
|
||||||
|
|
||||||
|
def _get_env_type(self, env_name: str) -> EnvType:
|
||||||
|
"""Determine environment processing type."""
|
||||||
|
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
|
||||||
|
return EnvType.PRESERVE
|
||||||
|
elif env_name in {'comment'}:
|
||||||
|
return EnvType.REMOVE
|
||||||
|
return EnvType.EXTRACT
|
||||||
|
|
||||||
|
def _preserve_special_commands(self, text: str) -> str:
|
||||||
|
"""Preserve special commands like citations and references with their complete structure."""
|
||||||
|
for pattern, replacement in self.config.special_command_patterns:
|
||||||
|
if pattern not in self._regex_cache:
|
||||||
|
self._regex_cache[pattern] = re.compile(pattern)
|
||||||
|
|
||||||
|
def replace_func(match):
|
||||||
|
# 保持原始命令格式
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
text = self._regex_cache[pattern].sub(replace_func, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _process_environment(self, match: re.Match) -> str:
|
||||||
|
"""Process LaTeX environments while preserving complete content for special environments."""
|
||||||
|
try:
|
||||||
|
env_name = match.group(1)
|
||||||
|
content = match.group(2)
|
||||||
|
env_type = self._get_env_type(env_name)
|
||||||
|
|
||||||
|
if env_type == EnvType.PRESERVE:
|
||||||
|
# 完整保留环境内容
|
||||||
|
complete_env = match.group(0)
|
||||||
|
return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n"
|
||||||
|
elif env_type == EnvType.REMOVE:
|
||||||
|
return ' '
|
||||||
|
else:
|
||||||
|
# 处理嵌套环境
|
||||||
|
return self._clean_nested_environments(content)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}")
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
def _preserve_inline_math(self, text: str) -> str:
|
||||||
|
"""Preserve complete inline math content."""
|
||||||
|
|
||||||
|
def preserve_math(match):
|
||||||
|
return f" {match.group(0)} "
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
(r'\$[^$]+\$', preserve_math),
|
||||||
|
(r'\\[\(\[].*?\\[\)\]]', preserve_math),
|
||||||
|
(r'\\begin{math}.*?\\end{math}', preserve_math)
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern, handler in patterns:
|
||||||
|
if pattern not in self._regex_cache:
|
||||||
|
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||||
|
text = self._regex_cache[pattern].sub(handler, text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _clean_nested_environments(self, text: str) -> str:
|
||||||
|
"""Process nested environments recursively."""
|
||||||
|
pattern = r'\\begin{(\w+)}(.*?)\\end{\1}'
|
||||||
|
if pattern not in self._regex_cache:
|
||||||
|
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||||
|
|
||||||
|
return self._regex_cache[pattern].sub(self._process_environment, text)
|
||||||
|
|
||||||
|
def _clean_commands(self, text: str) -> str:
|
||||||
|
"""Clean LaTeX commands while preserving important content."""
|
||||||
|
# 首先处理特殊命令
|
||||||
|
text = self._preserve_special_commands(text)
|
||||||
|
|
||||||
|
# 保留内联数学
|
||||||
|
text = self._preserve_inline_math(text)
|
||||||
|
|
||||||
|
# 移除指定的命令
|
||||||
|
for cmd in self.config.remove_commands:
|
||||||
|
if cmd not in self._regex_cache:
|
||||||
|
self._regex_cache[cmd] = re.compile(
|
||||||
|
fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*'
|
||||||
|
)
|
||||||
|
text = self._regex_cache[cmd].sub('', text)
|
||||||
|
|
||||||
|
# 处理带内容的命令
|
||||||
|
def handle_command(match: re.Match) -> str:
|
||||||
|
cmd = match.group(1).rstrip('*')
|
||||||
|
if cmd in self.config.preserve_commands or cmd in self.config.citation_commands:
|
||||||
|
return match.group(0) # 完整保留命令和内容
|
||||||
|
return ' '
|
||||||
|
|
||||||
|
if 'command_pattern' not in self._regex_cache:
|
||||||
|
self._regex_cache['command_pattern'] = re.compile(
|
||||||
|
r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
text = self._regex_cache['command_pattern'].sub(handle_command, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _normalize_text(self, text: str) -> str:
|
||||||
|
"""Normalize text while preserving special content markers."""
|
||||||
|
# 替换特殊字符
|
||||||
|
for char, replacement in self.config.latex_chars.items():
|
||||||
|
text = text.replace(char, replacement)
|
||||||
|
|
||||||
|
# 清理空白字符,同时保留环境标记
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text)
|
||||||
|
text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text)
|
||||||
|
|
||||||
|
# 保持块级环境之间的分隔
|
||||||
|
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||||
|
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
def clean_text(self, text: str) -> str:
|
||||||
|
"""Clean LaTeX text while preserving mathematical content, citations, and special environments."""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 移除注释
|
||||||
|
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 处理环境
|
||||||
|
text = self._clean_nested_environments(text)
|
||||||
|
|
||||||
|
# 清理命令并规范化
|
||||||
|
text = self._clean_commands(text)
|
||||||
|
text = self._normalize_text(text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error cleaning text: {e}")
|
||||||
|
return text # 发生错误时返回原始文本
|
||||||
|
|
||||||
|
|
||||||
|
def clean_latex_commands(text: str) -> str:
|
||||||
|
"""Convenience function for quick text cleaning with default config."""
|
||||||
|
cleaner = LatexCleaner()
|
||||||
|
return cleaner.clean_text(text)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
text = r"""
|
||||||
|
\documentclass{article}
|
||||||
|
\begin{document}
|
||||||
|
|
||||||
|
\section{Introduction}
|
||||||
|
This is a reference to \cite{smith2020} and equation \eqref{eq:main}.
|
||||||
|
|
||||||
|
\begin{equation}\label{eq:main}
|
||||||
|
E = mc^2 \times \sum_{i=1}^{n} x_i
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
See Figure \ref{fig:example} for details.
|
||||||
|
|
||||||
|
\begin{figure}
|
||||||
|
\includegraphics{image.png}
|
||||||
|
\caption{Example figure\label
|
||||||
|
\textbf{Important} result: $E=mc^2$ and
|
||||||
|
\begin{equation}
|
||||||
|
F = ma
|
||||||
|
\end{equation}
|
||||||
|
\label{sec:intro}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Custom configuration
|
||||||
|
config = LatexConfig(
|
||||||
|
preserve_envs={},
|
||||||
|
preserve_commands={'textbf', 'emph'},
|
||||||
|
latex_chars={'~': ' ', '\\&': '&'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_tex_file(file_path):
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
content = file.read()
|
||||||
|
return content
|
||||||
|
except FileNotFoundError:
|
||||||
|
return "文件未找到,请检查路径是否正确。"
|
||||||
|
except Exception as e:
|
||||||
|
return f"读取文件时发生错误: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
# 使用函数
|
||||||
|
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||||
|
content = read_tex_file(file_path)
|
||||||
|
cleaner = LatexCleaner(config)
|
||||||
|
text = cleaner.clean_text(text)
|
||||||
|
print(text)
|
||||||
@@ -0,0 +1,396 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LaTeXPatterns:
|
||||||
|
"""LaTeX模式存储类,用于集中管理所有LaTeX相关的正则表达式模式"""
|
||||||
|
special_envs = {
|
||||||
|
'math': [
|
||||||
|
# 基础数学环境
|
||||||
|
r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\$\$.*?\$\$',
|
||||||
|
r'\$[^$]+\$',
|
||||||
|
# 矩阵环境
|
||||||
|
r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 数组环境
|
||||||
|
r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 其他数学环境
|
||||||
|
r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'table': [
|
||||||
|
# 基础表格环境
|
||||||
|
r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 复杂表格环境
|
||||||
|
r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 自定义表格环境
|
||||||
|
r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 表格注释环境
|
||||||
|
r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'figure': [
|
||||||
|
# 图片环境
|
||||||
|
r'\\begin{figure\*?}.*?\\end{figure\*?}',
|
||||||
|
r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 图片插入命令
|
||||||
|
r'\\includegraphics(\[.*?\])?\{.*?\}',
|
||||||
|
# tikz 图形环境
|
||||||
|
r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 其他图形环境
|
||||||
|
r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'algorithm': [
|
||||||
|
# 算法环境
|
||||||
|
r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 代码块环境
|
||||||
|
r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 伪代码环境
|
||||||
|
r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'list': [
|
||||||
|
# 列表环境
|
||||||
|
r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 自定义列表环境
|
||||||
|
r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'theorem': [
|
||||||
|
# 定理类环境
|
||||||
|
r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 其他证明环境
|
||||||
|
r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'box': [
|
||||||
|
# 文本框环境
|
||||||
|
r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
|
||||||
|
# 强调环境
|
||||||
|
r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'quote': [
|
||||||
|
# 引用环境
|
||||||
|
r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'bibliography': [
|
||||||
|
# 参考文献环境
|
||||||
|
r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
|
||||||
|
],
|
||||||
|
|
||||||
|
'index': [
|
||||||
|
# 索引环境
|
||||||
|
r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
|
||||||
|
r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# 章节模式
|
||||||
|
section_patterns = [
|
||||||
|
# 基础章节命令
|
||||||
|
r'\\chapter\{([^}]+)\}',
|
||||||
|
r'\\section\{([^}]+)\}',
|
||||||
|
r'\\subsection\{([^}]+)\}',
|
||||||
|
r'\\subsubsection\{([^}]+)\}',
|
||||||
|
r'\\paragraph\{([^}]+)\}',
|
||||||
|
r'\\subparagraph\{([^}]+)\}',
|
||||||
|
|
||||||
|
# 带星号的变体(不编号)
|
||||||
|
r'\\chapter\*\{([^}]+)\}',
|
||||||
|
r'\\section\*\{([^}]+)\}',
|
||||||
|
r'\\subsection\*\{([^}]+)\}',
|
||||||
|
r'\\subsubsection\*\{([^}]+)\}',
|
||||||
|
r'\\paragraph\*\{([^}]+)\}',
|
||||||
|
r'\\subparagraph\*\{([^}]+)\}',
|
||||||
|
|
||||||
|
# 特殊章节
|
||||||
|
r'\\part\{([^}]+)\}',
|
||||||
|
r'\\part\*\{([^}]+)\}',
|
||||||
|
r'\\appendix\{([^}]+)\}',
|
||||||
|
|
||||||
|
# 前言部分
|
||||||
|
r'\\frontmatter\{([^}]+)\}',
|
||||||
|
r'\\mainmatter\{([^}]+)\}',
|
||||||
|
r'\\backmatter\{([^}]+)\}',
|
||||||
|
|
||||||
|
# 目录相关
|
||||||
|
r'\\tableofcontents',
|
||||||
|
r'\\listoffigures',
|
||||||
|
r'\\listoftables',
|
||||||
|
|
||||||
|
# 自定义章节命令
|
||||||
|
r'\\addchap\{([^}]+)\}', # KOMA-Script类
|
||||||
|
r'\\addsec\{([^}]+)\}', # KOMA-Script类
|
||||||
|
r'\\minisec\{([^}]+)\}', # KOMA-Script类
|
||||||
|
|
||||||
|
# 带可选参数的章节命令
|
||||||
|
r'\\chapter\[([^]]+)\]\{([^}]+)\}',
|
||||||
|
r'\\section\[([^]]+)\]\{([^}]+)\}',
|
||||||
|
r'\\subsection\[([^]]+)\]\{([^}]+)\}'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 包含模式
|
||||||
|
include_patterns = [
|
||||||
|
r'\\(input|include|subfile)\{([^}]+)\}'
|
||||||
|
]
|
||||||
|
|
||||||
|
metadata_patterns = {
|
||||||
|
# 标题相关
|
||||||
|
'title': [
|
||||||
|
r'\\title\{([^}]+)\}',
|
||||||
|
r'\\Title\{([^}]+)\}',
|
||||||
|
r'\\doctitle\{([^}]+)\}',
|
||||||
|
r'\\subtitle\{([^}]+)\}',
|
||||||
|
r'\\chapter\*?\{([^}]+)\}', # 第一章可能作为标题
|
||||||
|
r'\\maketitle\s*\\section\*?\{([^}]+)\}' # 第一节可能作为标题
|
||||||
|
],
|
||||||
|
|
||||||
|
# 摘要相关
|
||||||
|
'abstract': [
|
||||||
|
r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||||
|
r'\\abstract\{([^}]+)\}',
|
||||||
|
r'\\begin{摘要}(.*?)\\end{摘要}',
|
||||||
|
r'\\begin{Summary}(.*?)\\end{Summary}',
|
||||||
|
r'\\begin{synopsis}(.*?)\\end{synopsis}',
|
||||||
|
r'\\begin{abstracten}(.*?)\\end{abstracten}' # 英文摘要
|
||||||
|
],
|
||||||
|
|
||||||
|
# 作者信息
|
||||||
|
'author': [
|
||||||
|
r'\\author\{([^}]+)\}',
|
||||||
|
r'\\Author\{([^}]+)\}',
|
||||||
|
r'\\authorinfo\{([^}]+)\}',
|
||||||
|
r'\\authors\{([^}]+)\}',
|
||||||
|
r'\\author\[([^]]+)\]\{([^}]+)\}', # 带附加信息的作者
|
||||||
|
r'\\begin{authors}(.*?)\\end{authors}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 日期相关
|
||||||
|
'date': [
|
||||||
|
r'\\date\{([^}]+)\}',
|
||||||
|
r'\\Date\{([^}]+)\}',
|
||||||
|
r'\\submitdate\{([^}]+)\}',
|
||||||
|
r'\\publishdate\{([^}]+)\}',
|
||||||
|
r'\\revisiondate\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 关键词
|
||||||
|
'keywords': [
|
||||||
|
r'\\keywords\{([^}]+)\}',
|
||||||
|
r'\\Keywords\{([^}]+)\}',
|
||||||
|
r'\\begin{keywords}(.*?)\\end{keywords}',
|
||||||
|
r'\\key\{([^}]+)\}',
|
||||||
|
r'\\begin{关键词}(.*?)\\end{关键词}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 机构/单位
|
||||||
|
'institution': [
|
||||||
|
r'\\institute\{([^}]+)\}',
|
||||||
|
r'\\institution\{([^}]+)\}',
|
||||||
|
r'\\affiliation\{([^}]+)\}',
|
||||||
|
r'\\organization\{([^}]+)\}',
|
||||||
|
r'\\department\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 学科/主题
|
||||||
|
'subject': [
|
||||||
|
r'\\subject\{([^}]+)\}',
|
||||||
|
r'\\Subject\{([^}]+)\}',
|
||||||
|
r'\\field\{([^}]+)\}',
|
||||||
|
r'\\discipline\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 版本信息
|
||||||
|
'version': [
|
||||||
|
r'\\version\{([^}]+)\}',
|
||||||
|
r'\\revision\{([^}]+)\}',
|
||||||
|
r'\\release\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 许可证/版权
|
||||||
|
'license': [
|
||||||
|
r'\\license\{([^}]+)\}',
|
||||||
|
r'\\copyright\{([^}]+)\}',
|
||||||
|
r'\\begin{license}(.*?)\\end{license}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 联系方式
|
||||||
|
'contact': [
|
||||||
|
r'\\email\{([^}]+)\}',
|
||||||
|
r'\\phone\{([^}]+)\}',
|
||||||
|
r'\\address\{([^}]+)\}',
|
||||||
|
r'\\contact\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 致谢
|
||||||
|
'acknowledgments': [
|
||||||
|
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
|
||||||
|
r'\\acknowledgments\{([^}]+)\}',
|
||||||
|
r'\\thanks\{([^}]+)\}',
|
||||||
|
r'\\begin{致谢}(.*?)\\end{致谢}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 项目/基金
|
||||||
|
'funding': [
|
||||||
|
r'\\funding\{([^}]+)\}',
|
||||||
|
r'\\grant\{([^}]+)\}',
|
||||||
|
r'\\project\{([^}]+)\}',
|
||||||
|
r'\\support\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 分类号/编号
|
||||||
|
'classification': [
|
||||||
|
r'\\classification\{([^}]+)\}',
|
||||||
|
r'\\serialnumber\{([^}]+)\}',
|
||||||
|
r'\\id\{([^}]+)\}',
|
||||||
|
r'\\doi\{([^}]+)\}'
|
||||||
|
],
|
||||||
|
|
||||||
|
# 语言
|
||||||
|
'language': [
|
||||||
|
r'\\documentlanguage\{([^}]+)\}',
|
||||||
|
r'\\lang\{([^}]+)\}',
|
||||||
|
r'\\language\{([^}]+)\}'
|
||||||
|
]
|
||||||
|
}
|
||||||
|
latex_only_patterns = {
|
||||||
|
# 文档类和包引入
|
||||||
|
r'\\documentclass(\[.*?\])?\{.*?\}',
|
||||||
|
r'\\usepackage(\[.*?\])?\{.*?\}',
|
||||||
|
# 常见的文档设置命令
|
||||||
|
r'\\setlength\{.*?\}\{.*?\}',
|
||||||
|
r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
|
||||||
|
r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
|
||||||
|
r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
|
||||||
|
# 页面设置相关
|
||||||
|
r'\\pagestyle\{.*?\}',
|
||||||
|
r'\\thispagestyle\{.*?\}',
|
||||||
|
# 其他常见的设置命令
|
||||||
|
r'\\bibliographystyle\{.*?\}',
|
||||||
|
r'\\bibliography\{.*?\}',
|
||||||
|
r'\\setcounter\{.*?\}\{.*?\}',
|
||||||
|
# 字体和文本设置命令
|
||||||
|
r'\\makeFNbottom',
|
||||||
|
r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}', # 匹配字体大小设置
|
||||||
|
r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
|
||||||
|
r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
|
||||||
|
r'\\renewcommand\\footnoterule\{.*?\}',
|
||||||
|
r'\\color\{.*?\}',
|
||||||
|
|
||||||
|
# 页面和节标题设置
|
||||||
|
r'\\setcounter\{secnumdepth\}\{.*?\}',
|
||||||
|
r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
|
||||||
|
r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
|
||||||
|
r'\\renewcommand\{?\\figurename\}?\{.*?\}',
|
||||||
|
|
||||||
|
# 字体样式设置
|
||||||
|
r'\\sectionfont\{.*?\}',
|
||||||
|
r'\\subsectionfont\{.*?\}',
|
||||||
|
r'\\subsubsectionfont\{.*?\}',
|
||||||
|
|
||||||
|
# 间距和布局设置
|
||||||
|
r'\\setstretch\{.*?\}',
|
||||||
|
r'\\setlength\{\\skip\\footins\}\{.*?\}',
|
||||||
|
r'\\setlength\{\\footnotesep\}\{.*?\}',
|
||||||
|
r'\\setlength\{\\jot\}\{.*?\}',
|
||||||
|
r'\\hrule\s+width\s+.*?\s+height\s+.*?',
|
||||||
|
|
||||||
|
# makeatletter 和 makeatother
|
||||||
|
r'\\makeatletter\s*',
|
||||||
|
r'\\makeatother\s*',
|
||||||
|
r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}', # 带有上标的脚注
|
||||||
|
# r'\\footnotetext\{[^}]*\}', # 普通脚注
|
||||||
|
# r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}', # 带有邮箱的脚注
|
||||||
|
# r'\\footnotetext\{.*?(?:ESI|DOI).*?\}', # 带有 DOI 或 ESI 引用的脚注
|
||||||
|
# 文档结构命令
|
||||||
|
r'\\begin\{document\}',
|
||||||
|
r'\\end\{document\}',
|
||||||
|
r'\\maketitle',
|
||||||
|
r'\\printbibliography',
|
||||||
|
r'\\newpage',
|
||||||
|
|
||||||
|
# 输入文件命令
|
||||||
|
r'\\input\{[^}]*\}',
|
||||||
|
r'\\input\{.*?\.tex\}', # 特别匹配 .tex 后缀的输入
|
||||||
|
|
||||||
|
# 脚注相关
|
||||||
|
# r'\\footnotetext\[\d+\]\{[^}]*\}', # 带编号的脚注
|
||||||
|
|
||||||
|
# 致谢环境
|
||||||
|
r'\\begin\{ack\}',
|
||||||
|
r'\\end\{ack\}',
|
||||||
|
r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}', # 匹配整个致谢环境及其内容
|
||||||
|
|
||||||
|
# 其他文档控制命令
|
||||||
|
r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
|
||||||
|
}
|
||||||
|
math_envs = [
|
||||||
|
# 基础数学环境
|
||||||
|
(r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'), # 单行公式
|
||||||
|
(r'\\begin{align\*?}.*?\\end{align\*?}', 'align'), # 多行对齐公式
|
||||||
|
(r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'), # 多行居中公式
|
||||||
|
(r'\$\$.*?\$\$', 'display'), # 行间公式
|
||||||
|
(r'\$.*?\$', 'inline'), # 行内公式
|
||||||
|
|
||||||
|
# 矩阵环境
|
||||||
|
(r'\\begin{matrix}.*?\\end{matrix}', 'matrix'), # 基础矩阵
|
||||||
|
(r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'), # 圆括号矩阵
|
||||||
|
(r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'), # 方括号矩阵
|
||||||
|
(r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'), # 竖线矩阵
|
||||||
|
(r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'), # 双竖线矩阵
|
||||||
|
(r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'), # 小号矩阵
|
||||||
|
|
||||||
|
# 数组环境
|
||||||
|
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
|
||||||
|
(r'\\begin{cases}.*?\\end{cases}', 'cases'), # 分段函数
|
||||||
|
|
||||||
|
# 多行公式环境
|
||||||
|
(r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'), # 多行单个公式
|
||||||
|
(r'\\begin{split}.*?\\end{split}', 'split'), # 拆分长公式
|
||||||
|
(r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'), # 对齐环境带间距控制
|
||||||
|
(r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'), # 完全左对齐
|
||||||
|
|
||||||
|
# 特殊数学环境
|
||||||
|
(r'\\begin{subequations}.*?\\end{subequations}', 'subequations'), # 子公式编号
|
||||||
|
(r'\\begin{gathered}.*?\\end{gathered}', 'gathered'), # 居中对齐组
|
||||||
|
(r'\\begin{aligned}.*?\\end{aligned}', 'aligned'), # 内部对齐组
|
||||||
|
|
||||||
|
# 定理类环境
|
||||||
|
(r'\\begin{theorem}.*?\\end{theorem}', 'theorem'), # 定理
|
||||||
|
(r'\\begin{lemma}.*?\\end{lemma}', 'lemma'), # 引理
|
||||||
|
(r'\\begin{proof}.*?\\end{proof}', 'proof'), # 证明
|
||||||
|
|
||||||
|
# 数学模式中的表格环境
|
||||||
|
(r'\\begin{tabular}.*?\\end{tabular}', 'tabular'), # 表格
|
||||||
|
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
|
||||||
|
|
||||||
|
# 其他专业数学环境
|
||||||
|
(r'\\begin{CD}.*?\\end{CD}', 'CD'), # 交换图
|
||||||
|
(r'\\begin{boxed}.*?\\end{boxed}', 'boxed'), # 带框公式
|
||||||
|
(r'\\begin{empheq}.*?\\end{empheq}', 'empheq'), # 强调公式
|
||||||
|
|
||||||
|
# 化学方程式环境 (需要加载 mhchem 包)
|
||||||
|
(r'\\begin{reaction}.*?\\end{reaction}', 'reaction'), # 化学反应式
|
||||||
|
(r'\\ce\{.*?\}', 'chemequation'), # 化学方程式
|
||||||
|
|
||||||
|
# 物理单位环境 (需要加载 siunitx 包)
|
||||||
|
(r'\\SI\{.*?\}\{.*?\}', 'SI'), # 物理单位
|
||||||
|
(r'\\si\{.*?\}', 'si'), # 单位
|
||||||
|
|
||||||
|
# 补充环境
|
||||||
|
(r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'), # breqn包的自动换行公式
|
||||||
|
(r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'), # breqn包的显示数学模式
|
||||||
|
(r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'), # breqn包的公式组
|
||||||
|
]
|
||||||
|
|
||||||
|
# 示例使用函数
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
@@ -0,0 +1,416 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectionLevel(Enum):
|
||||||
|
CHAPTER = 0
|
||||||
|
SECTION = 1
|
||||||
|
SUBSECTION = 2
|
||||||
|
SUBSUBSECTION = 3
|
||||||
|
PARAGRAPH = 4
|
||||||
|
SUBPARAGRAPH = 5
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value < other.value
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value <= other.value
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value > other.value
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value >= other.value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Section:
|
||||||
|
level: SectionLevel
|
||||||
|
title: str
|
||||||
|
content: str = ''
|
||||||
|
bibliography: str = ''
|
||||||
|
subsections: List['Section'] = field(default_factory=list)
|
||||||
|
|
||||||
|
def merge(self, other: 'Section') -> 'Section':
|
||||||
|
"""Merge this section with another section."""
|
||||||
|
if self.title != other.title or self.level != other.level:
|
||||||
|
raise ValueError("Can only merge sections with same title and level")
|
||||||
|
|
||||||
|
merged = deepcopy(self)
|
||||||
|
merged.content = self._merge_content(self.content, other.content)
|
||||||
|
|
||||||
|
# Create subsections lookup for efficient merging
|
||||||
|
subsections_map = {s.title: s for s in merged.subsections}
|
||||||
|
|
||||||
|
for other_subsection in other.subsections:
|
||||||
|
if other_subsection.title in subsections_map:
|
||||||
|
# Merge existing subsection
|
||||||
|
idx = next(i for i, s in enumerate(merged.subsections)
|
||||||
|
if s.title == other_subsection.title)
|
||||||
|
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
|
||||||
|
else:
|
||||||
|
# Add new subsection
|
||||||
|
merged.subsections.append(deepcopy(other_subsection))
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_content(content1: str, content2: str) -> str:
|
||||||
|
"""Merge content strings intelligently."""
|
||||||
|
if not content1:
|
||||||
|
return content2
|
||||||
|
if not content2:
|
||||||
|
return content1
|
||||||
|
# Combine non-empty contents with a separator
|
||||||
|
return f"{content1}\n\n{content2}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LatexEnvironment:
|
||||||
|
"""表示LaTeX环境的数据类"""
|
||||||
|
name: str
|
||||||
|
start: int
|
||||||
|
end: int
|
||||||
|
content: str
|
||||||
|
raw: str
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedSectionExtractor:
|
||||||
|
"""Enhanced section extractor with comprehensive content handling and hierarchy management."""
|
||||||
|
|
||||||
|
def __init__(self, preserve_environments: bool = True):
|
||||||
|
"""
|
||||||
|
初始化Section提取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preserve_environments: 是否保留特定环境(如equation, figure等)的原始LaTeX代码
|
||||||
|
"""
|
||||||
|
self.preserve_environments = preserve_environments
|
||||||
|
|
||||||
|
# Section级别定义
|
||||||
|
self.section_levels = {
|
||||||
|
'chapter': SectionLevel.CHAPTER,
|
||||||
|
'section': SectionLevel.SECTION,
|
||||||
|
'subsection': SectionLevel.SUBSECTION,
|
||||||
|
'subsubsection': SectionLevel.SUBSUBSECTION,
|
||||||
|
'paragraph': SectionLevel.PARAGRAPH,
|
||||||
|
'subparagraph': SectionLevel.SUBPARAGRAPH
|
||||||
|
}
|
||||||
|
|
||||||
|
# 需要保留的环境类型
|
||||||
|
self.important_environments = {
|
||||||
|
'equation', 'equation*', 'align', 'align*',
|
||||||
|
'figure', 'table', 'algorithm', 'algorithmic',
|
||||||
|
'definition', 'theorem', 'lemma', 'proof',
|
||||||
|
'itemize', 'enumerate', 'description'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 改进的section pattern
|
||||||
|
self.section_pattern = (
|
||||||
|
r'\\(?P<type>chapter|section|subsection|subsubsection|paragraph|subparagraph)'
|
||||||
|
r'\*?' # Optional star
|
||||||
|
r'(?:\[(?P<short>.*?)\])?' # Optional short title
|
||||||
|
r'{(?P<title>(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support
|
||||||
|
)
|
||||||
|
|
||||||
|
# 环境匹配模式
|
||||||
|
self.environment_pattern = (
|
||||||
|
r'\\begin{(?P<env_name>[^}]+)}'
|
||||||
|
r'(?P<env_content>.*?)'
|
||||||
|
r'\\end{(?P=env_name)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_environments(self, content: str) -> List[LatexEnvironment]:
|
||||||
|
"""
|
||||||
|
查找文档中的所有LaTeX环境。
|
||||||
|
支持嵌套环境的处理。
|
||||||
|
"""
|
||||||
|
environments = []
|
||||||
|
stack = []
|
||||||
|
|
||||||
|
# 使用正则表达式查找所有begin和end标记
|
||||||
|
begin_pattern = r'\\begin{([^}]+)}'
|
||||||
|
end_pattern = r'\\end{([^}]+)}'
|
||||||
|
|
||||||
|
# 组合模式来同时匹配begin和end
|
||||||
|
tokens = []
|
||||||
|
for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content):
|
||||||
|
if match.group(1): # begin标记
|
||||||
|
tokens.append(('begin', match.group(1), match.start()))
|
||||||
|
else: # end标记
|
||||||
|
tokens.append(('end', match.group(2), match.start()))
|
||||||
|
|
||||||
|
# 处理环境嵌套
|
||||||
|
for token_type, env_name, pos in tokens:
|
||||||
|
if token_type == 'begin':
|
||||||
|
stack.append((env_name, pos))
|
||||||
|
elif token_type == 'end' and stack:
|
||||||
|
if stack[-1][0] == env_name:
|
||||||
|
start_env_name, start_pos = stack.pop()
|
||||||
|
env_content = content[start_pos:pos]
|
||||||
|
raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')]
|
||||||
|
|
||||||
|
if start_env_name in self.important_environments:
|
||||||
|
environments.append(LatexEnvironment(
|
||||||
|
name=start_env_name,
|
||||||
|
start=start_pos,
|
||||||
|
end=pos + len('\\end{' + env_name + '}'),
|
||||||
|
content=env_content,
|
||||||
|
raw=raw_content
|
||||||
|
))
|
||||||
|
|
||||||
|
return sorted(environments, key=lambda x: x.start)
|
||||||
|
|
||||||
|
def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
保护重要的LaTeX环境,用占位符替换它们。
|
||||||
|
返回处理后的内容和恢复映射。
|
||||||
|
"""
|
||||||
|
environments = self._find_environments(content)
|
||||||
|
replacements = {}
|
||||||
|
|
||||||
|
# 从后向前替换,避免位置改变的问题
|
||||||
|
for env in reversed(environments):
|
||||||
|
if env.name in self.important_environments:
|
||||||
|
placeholder = f'__ENV_{len(replacements)}__'
|
||||||
|
replacements[placeholder] = env.raw
|
||||||
|
content = content[:env.start] + placeholder + content[env.end:]
|
||||||
|
|
||||||
|
return content, replacements
|
||||||
|
|
||||||
|
def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str:
|
||||||
|
"""
|
||||||
|
恢复之前保护的环境。
|
||||||
|
"""
|
||||||
|
for placeholder, original in replacements.items():
|
||||||
|
content = content.replace(placeholder, original)
|
||||||
|
return content
|
||||||
|
|
||||||
|
def extract(self, content: str) -> List[Section]:
|
||||||
|
"""
|
||||||
|
从LaTeX文档中提取sections及其内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: LaTeX文档内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Section]: 提取的section列表,包含层次结构
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 预处理:保护重要环境
|
||||||
|
if self.preserve_environments:
|
||||||
|
content, env_replacements = self._protect_environments(content)
|
||||||
|
|
||||||
|
# 查找所有sections
|
||||||
|
sections = self._find_all_sections(content)
|
||||||
|
if not sections:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 处理sections
|
||||||
|
root_sections = self._process_sections(content, sections)
|
||||||
|
|
||||||
|
# 如果需要,恢复环境
|
||||||
|
if self.preserve_environments:
|
||||||
|
for section in self._traverse_sections(root_sections):
|
||||||
|
section.content = self._restore_environments(section.content, env_replacements)
|
||||||
|
|
||||||
|
return root_sections
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting sections: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _find_all_sections(self, content: str) -> List[dict]:
|
||||||
|
"""查找所有section命令及其位置。"""
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE):
|
||||||
|
section_type = match.group('type').lower()
|
||||||
|
if section_type not in self.section_levels:
|
||||||
|
continue
|
||||||
|
|
||||||
|
section = {
|
||||||
|
'type': section_type,
|
||||||
|
'level': self.section_levels[section_type],
|
||||||
|
'title': self._clean_title(match.group('title')),
|
||||||
|
'start': match.start(),
|
||||||
|
'command_end': match.end(),
|
||||||
|
}
|
||||||
|
sections.append(section)
|
||||||
|
|
||||||
|
return sorted(sections, key=lambda x: x['start'])
|
||||||
|
|
||||||
|
def _process_sections(self, content: str, sections: List[dict]) -> List[Section]:
|
||||||
|
"""处理sections以构建层次结构和提取内容。"""
|
||||||
|
# 计算content范围
|
||||||
|
self._calculate_content_ranges(content, sections)
|
||||||
|
|
||||||
|
# 构建层次结构
|
||||||
|
root_sections = []
|
||||||
|
section_stack = []
|
||||||
|
|
||||||
|
for section_info in sections:
|
||||||
|
new_section = Section(
|
||||||
|
level=section_info['level'],
|
||||||
|
title=section_info['title'],
|
||||||
|
content=self._extract_clean_content(content, section_info),
|
||||||
|
subsections=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调整堆栈以找到正确的父section
|
||||||
|
while section_stack and section_stack[-1].level.value >= new_section.level.value:
|
||||||
|
section_stack.pop()
|
||||||
|
|
||||||
|
if section_stack:
|
||||||
|
section_stack[-1].subsections.append(new_section)
|
||||||
|
else:
|
||||||
|
root_sections.append(new_section)
|
||||||
|
|
||||||
|
section_stack.append(new_section)
|
||||||
|
|
||||||
|
return root_sections
|
||||||
|
|
||||||
|
def _calculate_content_ranges(self, content: str, sections: List[dict]):
|
||||||
|
for i, current in enumerate(sections):
|
||||||
|
content_start = current['command_end']
|
||||||
|
|
||||||
|
# 找到下一个section(无论什么级别)
|
||||||
|
content_end = len(content)
|
||||||
|
for next_section in sections[i + 1:]:
|
||||||
|
content_end = next_section['start']
|
||||||
|
break
|
||||||
|
|
||||||
|
current['content_range'] = (content_start, content_end)
|
||||||
|
|
||||||
|
def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]):
|
||||||
|
"""为每个section计算内容范围。"""
|
||||||
|
for i, current in enumerate(sections):
|
||||||
|
content_start = current['command_end']
|
||||||
|
|
||||||
|
# 找到下一个同级或更高级的section
|
||||||
|
content_end = len(content)
|
||||||
|
for next_section in sections[i + 1:]:
|
||||||
|
if next_section['level'] <= current['level']:
|
||||||
|
content_end = next_section['start']
|
||||||
|
break
|
||||||
|
|
||||||
|
current['content_range'] = (content_start, content_end)
|
||||||
|
|
||||||
|
def _extract_clean_content(self, content: str, section_info: dict) -> str:
|
||||||
|
"""提取并清理section内容。"""
|
||||||
|
start, end = section_info['content_range']
|
||||||
|
raw_content = content[start:end]
|
||||||
|
|
||||||
|
# 清理内容
|
||||||
|
clean_content = self._clean_content(raw_content)
|
||||||
|
return clean_content
|
||||||
|
|
||||||
|
def _clean_content(self, content: str) -> str:
|
||||||
|
"""清理LaTeX内容同时保留重要信息。"""
|
||||||
|
# 移除注释
|
||||||
|
content = re.sub(r'(?<!\\)%.*?\n', '\n', content)
|
||||||
|
|
||||||
|
# LaTeX命令处理规则
|
||||||
|
replacements = [
|
||||||
|
# 保留引用
|
||||||
|
(r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'),
|
||||||
|
# 保留脚注
|
||||||
|
(r'\\footnote{(.*?)}', r'[footnote:\1]'),
|
||||||
|
# 处理引用
|
||||||
|
(r'\\ref{(.*?)}', r'[ref:\1]'),
|
||||||
|
# 保留URL
|
||||||
|
(r'\\url{(.*?)}', r'[url:\1]'),
|
||||||
|
# 保留超链接
|
||||||
|
(r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'),
|
||||||
|
# 处理文本格式命令
|
||||||
|
(r'\\(?:textbf|textit|emph){(.*?)}', r'\1'),
|
||||||
|
# 保留特殊字符
|
||||||
|
(r'\\([&%$#_{}])', r'\1'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 应用所有替换规则
|
||||||
|
for pattern, replacement in replacements:
|
||||||
|
content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
# 清理多余的空白
|
||||||
|
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
def _clean_title(self, title: str) -> str:
|
||||||
|
"""清理section标题。"""
|
||||||
|
# 处理嵌套的花括号
|
||||||
|
while '{' in title:
|
||||||
|
title = re.sub(r'{([^{}]*)}', r'\1', title)
|
||||||
|
|
||||||
|
# 处理LaTeX命令
|
||||||
|
title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title)
|
||||||
|
title = re.sub(r'\\([&%$#_{}])', r'\1', title)
|
||||||
|
|
||||||
|
return title.strip()
|
||||||
|
|
||||||
|
def _traverse_sections(self, sections: List[Section]) -> List[Section]:
|
||||||
|
"""遍历所有sections(包括子sections)。"""
|
||||||
|
result = []
|
||||||
|
for section in sections:
|
||||||
|
result.append(section)
|
||||||
|
result.extend(self._traverse_sections(section.subsections))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_enhanced_extractor():
|
||||||
|
"""使用复杂的测试用例测试提取器。"""
|
||||||
|
test_content = r"""
|
||||||
|
\section{Complex Examples}
|
||||||
|
Here's a complex section with various environments.
|
||||||
|
|
||||||
|
\begin{equation}
|
||||||
|
E = mc^2
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
\subsection{Nested Environments}
|
||||||
|
This subsection has nested environments.
|
||||||
|
|
||||||
|
\begin{figure}
|
||||||
|
\begin{equation*}
|
||||||
|
f(x) = \int_0^x g(t) dt
|
||||||
|
\end{equation*}
|
||||||
|
\caption{A nested equation in a figure}
|
||||||
|
\end{figure}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
extractor = EnhancedSectionExtractor()
|
||||||
|
sections = extractor.extract(test_content)
|
||||||
|
|
||||||
|
def print_section(section, level=0):
|
||||||
|
print("\n" + " " * level + f"[{section.level.name}] {section.title}")
|
||||||
|
if section.content:
|
||||||
|
content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content
|
||||||
|
print(" " * (level + 1) + f"Content: {content_preview}")
|
||||||
|
for subsection in section.subsections:
|
||||||
|
print_section(subsection, level + 1)
|
||||||
|
|
||||||
|
print("\nExtracted Section Structure:")
|
||||||
|
for section in sections:
|
||||||
|
print_section(section)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_enhanced_extractor()
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectionFragment:
|
||||||
|
"""Arxiv论文片段数据类"""
|
||||||
|
title: str # 论文标题
|
||||||
|
authors: str
|
||||||
|
abstract: str # 论文摘要
|
||||||
|
catalogs: str # 文章各章节的目录结构
|
||||||
|
arxiv_id: str = "" # 添加 arxiv_id 属性
|
||||||
|
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
|
||||||
|
content: str = '' # 当前片段的内容
|
||||||
|
bibliography: str = '' # 当前片段的参考文献
|
||||||
@@ -0,0 +1,266 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Set, Optional
|
||||||
|
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
|
||||||
|
|
||||||
|
|
||||||
|
class TexUtils:
|
||||||
|
"""TeX文档处理器类"""
|
||||||
|
|
||||||
|
def __init__(self, ):
|
||||||
|
"""
|
||||||
|
初始化TeX处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_range: 字符数范围(最小值, 最大值)
|
||||||
|
"""
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化LaTeX环境和命令模式
|
||||||
|
self._init_patterns()
|
||||||
|
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
|
||||||
|
|
||||||
|
def _init_patterns(self):
|
||||||
|
"""初始化LaTeX模式匹配规则"""
|
||||||
|
# 特殊环境模式
|
||||||
|
self.special_envs = LaTeXPatterns.special_envs
|
||||||
|
# 章节模式
|
||||||
|
self.section_patterns = LaTeXPatterns.section_patterns
|
||||||
|
# 包含模式
|
||||||
|
self.include_patterns = LaTeXPatterns.include_patterns
|
||||||
|
# 元数据模式
|
||||||
|
self.metadata_patterns = LaTeXPatterns.metadata_patterns
|
||||||
|
|
||||||
|
def read_file(self, file_path: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
读取TeX文件内容,支持多种编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 文件内容或None
|
||||||
|
"""
|
||||||
|
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||||
|
for encoding in encodings:
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding=encoding) as f:
|
||||||
|
return f.read()
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.logger.warning(f"Failed to read {file_path} with all encodings")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_main_tex_file(self, directory: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
查找主TeX文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: 目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 主文件路径或None
|
||||||
|
"""
|
||||||
|
tex_files = list(Path(directory).rglob("*.tex"))
|
||||||
|
if not tex_files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 按优先级查找
|
||||||
|
for tex_file in tex_files:
|
||||||
|
content = self.read_file(str(tex_file))
|
||||||
|
if content:
|
||||||
|
if r'\documentclass' in content:
|
||||||
|
return str(tex_file)
|
||||||
|
if tex_file.name.lower() == 'main.tex':
|
||||||
|
return str(tex_file)
|
||||||
|
|
||||||
|
# 返回最大的tex文件
|
||||||
|
return str(max(tex_files, key=lambda x: x.stat().st_size))
|
||||||
|
|
||||||
|
def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
解析TeX文件中的include引用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tex_file: TeX文件路径
|
||||||
|
processed: 已处理的文件集合
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 相关文件路径列表
|
||||||
|
"""
|
||||||
|
if processed is None:
|
||||||
|
processed = set()
|
||||||
|
|
||||||
|
if tex_file in processed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
processed.add(tex_file)
|
||||||
|
result = [tex_file]
|
||||||
|
content = self.read_file(tex_file)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return result
|
||||||
|
|
||||||
|
base_dir = Path(tex_file).parent
|
||||||
|
for pattern in self.include_patterns:
|
||||||
|
for match in re.finditer(pattern, content):
|
||||||
|
included_file = match.group(2)
|
||||||
|
if not included_file.endswith('.tex'):
|
||||||
|
included_file += '.tex'
|
||||||
|
|
||||||
|
full_path = str(base_dir / included_file)
|
||||||
|
if os.path.exists(full_path) and full_path not in processed:
|
||||||
|
result.extend(self.resolve_includes(full_path, processed))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def resolve_references(self, tex_file: str, path_dir: str = None) -> str:
|
||||||
|
"""
|
||||||
|
解析TeX文件中的参考文献引用,返回所有引用文献的内容,只保留title、author和journal字段。
|
||||||
|
如果在tex_file目录下没找到bib文件,会在path_dir中查找。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tex_file: TeX文件路径
|
||||||
|
path_dir: 额外的参考文献搜索路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔
|
||||||
|
"""
|
||||||
|
all_references = [] # 存储所有参考文献内容
|
||||||
|
content = self.read_file(tex_file)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 扩展参考文献引用的模式
|
||||||
|
bib_patterns = [
|
||||||
|
r'\\bibliography\{([^}]+)\}',
|
||||||
|
r'\\addbibresource\{([^}]+)\}',
|
||||||
|
r'\\bibliographyfile\{([^}]+)\}',
|
||||||
|
r'\\begin\{thebibliography\}',
|
||||||
|
r'\\bibinput\{([^}]+)\}',
|
||||||
|
r'\\newrefsection\{([^}]+)\}'
|
||||||
|
]
|
||||||
|
|
||||||
|
base_dir = Path(tex_file).parent
|
||||||
|
found_in_tex_dir = False
|
||||||
|
|
||||||
|
# 首先在tex文件目录下查找显式引用的bib文件
|
||||||
|
for pattern in bib_patterns:
|
||||||
|
for match in re.finditer(pattern, content):
|
||||||
|
if not match.groups():
|
||||||
|
continue
|
||||||
|
|
||||||
|
bib_files = match.group(1).split(',')
|
||||||
|
for bib_file in bib_files:
|
||||||
|
bib_file = bib_file.strip()
|
||||||
|
if not bib_file.endswith('.bib'):
|
||||||
|
bib_file += '.bib'
|
||||||
|
|
||||||
|
full_path = str(base_dir / bib_file)
|
||||||
|
if os.path.exists(full_path):
|
||||||
|
found_in_tex_dir = True
|
||||||
|
bib_content = self.read_file(full_path)
|
||||||
|
if bib_content:
|
||||||
|
processed_refs = self._process_bib_content(bib_content)
|
||||||
|
all_references.extend(processed_refs)
|
||||||
|
|
||||||
|
# 如果在tex文件目录下没找到bib文件,且提供了额外搜索路径
|
||||||
|
if not found_in_tex_dir and path_dir:
|
||||||
|
search_dir = Path(path_dir)
|
||||||
|
try:
|
||||||
|
for bib_path in search_dir.glob('**/*.bib'):
|
||||||
|
bib_content = self.read_file(str(bib_path))
|
||||||
|
if bib_content:
|
||||||
|
processed_refs = self._process_bib_content(bib_content)
|
||||||
|
all_references.extend(processed_refs)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error searching in path_dir: {e}")
|
||||||
|
|
||||||
|
# 合并所有参考文献内容,用空行分隔
|
||||||
|
return "\n\n".join(all_references)
|
||||||
|
|
||||||
|
def _process_bib_content(self, content: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
处理bib文件内容,提取每个参考文献的特定字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: bib文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 处理后的参考文献列表
|
||||||
|
"""
|
||||||
|
processed_refs = []
|
||||||
|
# 匹配完整的参考文献条目
|
||||||
|
ref_pattern = r'@\w+\{[^@]*\}'
|
||||||
|
# 匹配参考文献类型和键值
|
||||||
|
entry_start_pattern = r'@(\w+)\{([^,]*?),'
|
||||||
|
# 匹配字段
|
||||||
|
field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}'
|
||||||
|
|
||||||
|
# 查找所有参考文献条目
|
||||||
|
for ref_match in re.finditer(ref_pattern, content, re.DOTALL):
|
||||||
|
ref_content = ref_match.group(0)
|
||||||
|
|
||||||
|
# 获取参考文献类型和键值
|
||||||
|
entry_match = re.match(entry_start_pattern, ref_content)
|
||||||
|
if not entry_match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry_type, cite_key = entry_match.groups()
|
||||||
|
|
||||||
|
# 提取需要的字段
|
||||||
|
needed_fields = {'title': None, 'author': None, 'journal': None}
|
||||||
|
for field_match in re.finditer(field_pattern, ref_content):
|
||||||
|
field_name, field_value = field_match.groups()
|
||||||
|
field_name = field_name.lower()
|
||||||
|
if field_name in needed_fields:
|
||||||
|
needed_fields[field_name] = field_value.strip()
|
||||||
|
|
||||||
|
# 构建新的参考文献条目
|
||||||
|
if any(needed_fields.values()): # 如果至少有一个需要的字段
|
||||||
|
ref_lines = [f"@{entry_type}{{{cite_key},"]
|
||||||
|
for field_name, field_value in needed_fields.items():
|
||||||
|
if field_value:
|
||||||
|
ref_lines.append(f" {field_name}={{{field_value}}},")
|
||||||
|
ref_lines[-1] = ref_lines[-1][:-1] # 移除最后一个逗号
|
||||||
|
ref_lines.append("}")
|
||||||
|
|
||||||
|
processed_refs.append("\n".join(ref_lines))
|
||||||
|
|
||||||
|
return processed_refs
|
||||||
|
|
||||||
|
def _extract_inline_references(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
从tex文件内容中提取直接写在文件中的参考文献
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: tex文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 提取的参考文献内容,如果没有找到则返回空字符串
|
||||||
|
"""
|
||||||
|
# 查找参考文献环境
|
||||||
|
bib_start = r'\\begin\{thebibliography\}'
|
||||||
|
bib_end = r'\\end\{thebibliography\}'
|
||||||
|
|
||||||
|
start_match = re.search(bib_start, content)
|
||||||
|
end_match = re.search(bib_end, content)
|
||||||
|
|
||||||
|
if start_match and end_match:
|
||||||
|
return content[start_match.start():end_match.end()]
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _preprocess_content(self, content: str) -> str:
|
||||||
|
"""预处理TeX内容"""
|
||||||
|
# 移除注释
|
||||||
|
content = re.sub(r'(?m)%.*$', '', content)
|
||||||
|
# 规范化空白字符
|
||||||
|
# content = re.sub(r'\s+', ' ', content)
|
||||||
|
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||||
|
return content.strip()
|
||||||
@@ -1,17 +1,13 @@
|
|||||||
import llama_index
|
|
||||||
import os
|
|
||||||
import atexit
|
import atexit
|
||||||
from loguru import logger
|
import os
|
||||||
from typing import List
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.schema import TextNode
|
|
||||||
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
|
||||||
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
|
||||||
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
|
|
||||||
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
|
||||||
from llama_index.core.ingestion import run_transformations
|
from llama_index.core.ingestion import run_transformations
|
||||||
from llama_index.core import PromptTemplate
|
from llama_index.core.schema import TextNode
|
||||||
from llama_index.core.response_synthesizers import TreeSummarize
|
from loguru import logger
|
||||||
|
|
||||||
|
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||||
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
|
||||||
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||||
Now, you have context information as below:
|
Now, you have context information as below:
|
||||||
@@ -72,11 +68,60 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
||||||
self.user_name = user_name
|
self.user_name = user_name
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
if auto_load_checkpoint:
|
|
||||||
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
# 确保checkpoint_dir存在
|
||||||
|
if checkpoint_dir:
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
|
||||||
|
|
||||||
|
# 初始化向量存储
|
||||||
|
if auto_load_checkpoint and self.does_checkpoint_exist():
|
||||||
|
logger.info("Loading existing vector store from checkpoint")
|
||||||
|
self.vs_index = self.load_from_checkpoint()
|
||||||
else:
|
else:
|
||||||
self.vs_index = self.create_new_vs(checkpoint_dir)
|
logger.info("Creating new vector store")
|
||||||
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
self.vs_index = self.create_new_vs()
|
||||||
|
|
||||||
|
# 注册退出时保存
|
||||||
|
atexit.register(self.save_to_checkpoint)
|
||||||
|
|
||||||
|
def add_text_to_vector_store(self, text: str) -> None:
|
||||||
|
"""添加文本到向量存储"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
|
||||||
|
node = TextNode(text=text)
|
||||||
|
nodes = run_transformations(
|
||||||
|
[node],
|
||||||
|
self.vs_index._transformations,
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
self.vs_index.insert_nodes(nodes)
|
||||||
|
|
||||||
|
# 立即保存
|
||||||
|
self.save_to_checkpoint()
|
||||||
|
|
||||||
|
if self.debug_mode:
|
||||||
|
self.inspect_vector_store()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding text to vector store: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||||
|
"""保存向量存储到检查点"""
|
||||||
|
try:
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
checkpoint_dir = self.checkpoint_dir
|
||||||
|
logger.info(f'Saving vector store to: {checkpoint_dir}')
|
||||||
|
if checkpoint_dir:
|
||||||
|
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||||
|
logger.info('Vector store saved successfully')
|
||||||
|
else:
|
||||||
|
logger.warning('No checkpoint directory specified, skipping save')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving checkpoint: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def assign_embedding_model(self):
|
def assign_embedding_model(self):
|
||||||
pass
|
pass
|
||||||
@@ -85,7 +130,7 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
# This function is for debugging
|
# This function is for debugging
|
||||||
self.vs_index.storage_context.index_store.to_dict()
|
self.vs_index.storage_context.index_store.to_dict()
|
||||||
docstore = self.vs_index.storage_context.docstore.docs
|
docstore = self.vs_index.storage_context.docstore.docs
|
||||||
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
|
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
|
||||||
logger.info('\n++ --------inspect_vector_store begin--------')
|
logger.info('\n++ --------inspect_vector_store begin--------')
|
||||||
logger.info(vector_store_preview)
|
logger.info(vector_store_preview)
|
||||||
logger.info('oo --------inspect_vector_store end--------')
|
logger.info('oo --------inspect_vector_store end--------')
|
||||||
@@ -94,20 +139,10 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
def add_documents_to_vector_store(self, document_list):
|
def add_documents_to_vector_store(self, document_list):
|
||||||
documents = [Document(text=t) for t in document_list]
|
documents = [Document(text=t) for t in document_list]
|
||||||
documents_nodes = run_transformations(
|
documents_nodes = run_transformations(
|
||||||
documents, # type: ignore
|
documents, # type: ignore
|
||||||
self.vs_index._transformations,
|
self.vs_index._transformations,
|
||||||
show_progress=True
|
show_progress=True
|
||||||
)
|
)
|
||||||
self.vs_index.insert_nodes(documents_nodes)
|
|
||||||
if self.debug_mode: self.inspect_vector_store()
|
|
||||||
|
|
||||||
def add_text_to_vector_store(self, text):
|
|
||||||
node = TextNode(text=text)
|
|
||||||
documents_nodes = run_transformations(
|
|
||||||
[node],
|
|
||||||
self.vs_index._transformations,
|
|
||||||
show_progress=True
|
|
||||||
)
|
|
||||||
self.vs_index.insert_nodes(documents_nodes)
|
self.vs_index.insert_nodes(documents_nodes)
|
||||||
if self.debug_mode: self.inspect_vector_store()
|
if self.debug_mode: self.inspect_vector_store()
|
||||||
|
|
||||||
@@ -123,8 +158,8 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
def build_prompt(self, query, nodes):
|
def build_prompt(self, query, nodes):
|
||||||
context_str = self.generate_node_array_preview(nodes)
|
context_str = self.generate_node_array_preview(nodes)
|
||||||
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
||||||
|
|
||||||
def generate_node_array_preview(self, nodes):
|
def generate_node_array_preview(self, nodes):
|
||||||
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
||||||
if self.debug_mode: logger.info(buf)
|
if self.debug_mode: logger.info(buf)
|
||||||
return buf
|
return buf
|
||||||
|
|||||||
@@ -1,20 +1,14 @@
|
|||||||
import llama_index
|
|
||||||
import os
|
|
||||||
import atexit
|
import atexit
|
||||||
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
from loguru import logger
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.schema import TextNode
|
|
||||||
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
|
||||||
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
|
||||||
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
|
|
||||||
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
|
||||||
from llama_index.core.ingestion import run_transformations
|
|
||||||
from llama_index.core import PromptTemplate
|
|
||||||
from llama_index.core.response_synthesizers import TreeSummarize
|
|
||||||
from llama_index.core import StorageContext
|
from llama_index.core import StorageContext
|
||||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||||
|
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||||
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
|
||||||
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||||
Now, you have context information as below:
|
Now, you have context information as below:
|
||||||
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
|
|||||||
|
|
||||||
def create_new_vs(self, checkpoint_dir, overwrite=False):
|
def create_new_vs(self, checkpoint_dir, overwrite=False):
|
||||||
vector_store = MilvusVectorStore(
|
vector_store = MilvusVectorStore(
|
||||||
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
|
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
|
||||||
dim=self.embed_model.embedding_dimension(),
|
dim=self.embed_model.embedding_dimension(),
|
||||||
overwrite=overwrite
|
overwrite=overwrite
|
||||||
)
|
)
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
|
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
|
||||||
|
embed_model=self.embed_model)
|
||||||
return index
|
return index
|
||||||
|
|
||||||
def purge(self):
|
def purge(self):
|
||||||
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
|
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
|
||||||
|
|
||||||
|
|
||||||
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
|
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
|
||||||
|
|
||||||
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||||
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
|
|||||||
docstore = self.vs_index.storage_context.docstore.docs
|
docstore = self.vs_index.storage_context.docstore.docs
|
||||||
if not docstore.items():
|
if not docstore.items():
|
||||||
raise ValueError("cannot inspect")
|
raise ValueError("cannot inspect")
|
||||||
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
|
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
|
||||||
except:
|
except:
|
||||||
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
|
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
|
||||||
vector_store_preview = "\n".join(
|
vector_store_preview = "\n".join(
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
|
||||||
|
'.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
|
||||||
|
|
||||||
|
|
||||||
|
def read_docx_doc(file_path):
|
||||||
|
if file_path.split(".")[-1] == "docx":
|
||||||
|
from docx import Document
|
||||||
|
doc = Document(file_path)
|
||||||
|
file_content = "\n".join([para.text for para in doc.paragraphs])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import win32com.client
|
||||||
|
word = win32com.client.Dispatch("Word.Application")
|
||||||
|
word.visible = False
|
||||||
|
# 打开文件
|
||||||
|
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
|
||||||
|
# file_content = doc.Content.Text
|
||||||
|
doc = word.ActiveDocument
|
||||||
|
file_content = doc.Range().Text
|
||||||
|
doc.Close()
|
||||||
|
word.Quit()
|
||||||
|
except:
|
||||||
|
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
||||||
|
return file_content
|
||||||
|
|
||||||
|
|
||||||
|
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text(file_path):
|
||||||
|
_, ext = os.path.splitext(file_path.lower())
|
||||||
|
|
||||||
|
# 使用 SimpleDirectoryReader 处理它支持的文件格式
|
||||||
|
if ext in ['.docx', '.doc']:
|
||||||
|
return read_docx_doc(file_path)
|
||||||
|
try:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[file_path])
|
||||||
|
documents = reader.load_data()
|
||||||
|
if len(documents) > 0:
|
||||||
|
return documents[0].text
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from llama_index.core import VectorStoreIndex
|
from typing import Any, List, Optional
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
from llama_index.core.callbacks.base import CallbackManager
|
from llama_index.core.callbacks.base import CallbackManager
|
||||||
from llama_index.core.schema import TransformComponent
|
from llama_index.core.schema import TransformComponent
|
||||||
from llama_index.core.service_context import ServiceContext
|
from llama_index.core.service_context import ServiceContext
|
||||||
@@ -13,18 +13,18 @@ from llama_index.core.storage.storage_context import StorageContext
|
|||||||
|
|
||||||
|
|
||||||
class GptacVectorStoreIndex(VectorStoreIndex):
|
class GptacVectorStoreIndex(VectorStoreIndex):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_vector_store(
|
def default_vector_store(
|
||||||
cls,
|
cls,
|
||||||
storage_context: Optional[StorageContext] = None,
|
storage_context: Optional[StorageContext] = None,
|
||||||
show_progress: bool = False,
|
show_progress: bool = False,
|
||||||
callback_manager: Optional[CallbackManager] = None,
|
callback_manager: Optional[CallbackManager] = None,
|
||||||
transformations: Optional[List[TransformComponent]] = None,
|
transformations: Optional[List[TransformComponent]] = None,
|
||||||
# deprecated
|
# deprecated
|
||||||
service_context: Optional[ServiceContext] = None,
|
service_context: Optional[ServiceContext] = None,
|
||||||
embed_model = None,
|
embed_model=None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Create index from documents.
|
"""Create index from documents.
|
||||||
|
|
||||||
@@ -36,15 +36,14 @@ class GptacVectorStoreIndex(VectorStoreIndex):
|
|||||||
storage_context = storage_context or StorageContext.from_defaults()
|
storage_context = storage_context or StorageContext.from_defaults()
|
||||||
docstore = storage_context.docstore
|
docstore = storage_context.docstore
|
||||||
callback_manager = (
|
callback_manager = (
|
||||||
callback_manager
|
callback_manager
|
||||||
or callback_manager_from_settings_or_context(Settings, service_context)
|
or callback_manager_from_settings_or_context(Settings, service_context)
|
||||||
)
|
)
|
||||||
transformations = transformations or transformations_from_settings_or_context(
|
transformations = transformations or transformations_from_settings_or_context(
|
||||||
Settings, service_context
|
Settings, service_context
|
||||||
)
|
)
|
||||||
|
|
||||||
with callback_manager.as_trace("index_construction"):
|
with callback_manager.as_trace("index_construction"):
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
nodes=[],
|
nodes=[],
|
||||||
storage_context=storage_context,
|
storage_context=storage_context,
|
||||||
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
|
|||||||
embed_model=embed_model,
|
embed_model=embed_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
from toolbox import update_ui
|
|
||||||
from toolbox import CatchException, report_exception
|
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
|
||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
|
||||||
fast_debug = False
|
|
||||||
|
|
||||||
|
|
||||||
def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
|
||||||
import time, os
|
|
||||||
# pip install python-docx 用于docx格式,跨平台
|
|
||||||
# pip install pywin32 用于doc格式,仅支持Win平台
|
|
||||||
for index, fp in enumerate(file_manifest):
|
|
||||||
if fp.split(".")[-1] == "docx":
|
|
||||||
from docx import Document
|
|
||||||
doc = Document(fp)
|
|
||||||
file_content = "\n".join([para.text for para in doc.paragraphs])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import win32com.client
|
|
||||||
word = win32com.client.Dispatch("Word.Application")
|
|
||||||
word.visible = False
|
|
||||||
# 打开文件
|
|
||||||
doc = word.Documents.Open(os.getcwd() + '/' + fp)
|
|
||||||
# file_content = doc.Content.Text
|
|
||||||
doc = word.ActiveDocument
|
|
||||||
file_content = doc.Range().Text
|
|
||||||
doc.Close()
|
|
||||||
word.Quit()
|
|
||||||
except:
|
|
||||||
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
|
||||||
|
|
||||||
# private_upload里面的文件名在解压zip后容易出现乱码(rar和7z格式正常),故可以只分析文章内容,不输入文件名
|
|
||||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
|
||||||
from request_llms.bridge_all import model_info
|
|
||||||
max_token = model_info[llm_kwargs['llm_model']]['max_token']
|
|
||||||
TOKEN_LIMIT_PER_FRAGMENT = max_token * 3 // 4
|
|
||||||
paper_fragments = breakdown_text_to_satisfy_token_limit(txt=file_content, limit=TOKEN_LIMIT_PER_FRAGMENT, llm_model=llm_kwargs['llm_model'])
|
|
||||||
this_paper_history = []
|
|
||||||
for i, paper_frag in enumerate(paper_fragments):
|
|
||||||
i_say = f'请对下面的文章片段用中文做概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{paper_frag}```'
|
|
||||||
i_say_show_user = f'请对下面的文章片段做概述: {os.path.abspath(fp)}的第{i+1}/{len(paper_fragments)}个片段。'
|
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
|
||||||
inputs=i_say,
|
|
||||||
inputs_show_user=i_say_show_user,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
chatbot=chatbot,
|
|
||||||
history=[],
|
|
||||||
sys_prompt="总结文章。"
|
|
||||||
)
|
|
||||||
|
|
||||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
|
||||||
history.extend([i_say_show_user,gpt_say])
|
|
||||||
this_paper_history.extend([i_say_show_user,gpt_say])
|
|
||||||
|
|
||||||
# 已经对该文章的所有片段总结完毕,如果文章被切分了,
|
|
||||||
if len(paper_fragments) > 1:
|
|
||||||
i_say = f"根据以上的对话,总结文章{os.path.abspath(fp)}的主要内容。"
|
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
|
||||||
inputs=i_say,
|
|
||||||
inputs_show_user=i_say,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
chatbot=chatbot,
|
|
||||||
history=this_paper_history,
|
|
||||||
sys_prompt="总结文章。"
|
|
||||||
)
|
|
||||||
|
|
||||||
history.extend([i_say,gpt_say])
|
|
||||||
this_paper_history.extend([i_say,gpt_say])
|
|
||||||
|
|
||||||
res = write_history_to_file(history)
|
|
||||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
|
||||||
chatbot.append(("完成了吗?", res))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
res = write_history_to_file(history)
|
|
||||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
|
||||||
chatbot.append(("所有文件都总结完成了吗?", res))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
|
|
||||||
@CatchException
|
|
||||||
def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
|
||||||
import glob, os
|
|
||||||
|
|
||||||
# 基本信息:功能、贡献者
|
|
||||||
chatbot.append([
|
|
||||||
"函数插件功能?",
|
|
||||||
"批量总结Word文档。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
|
||||||
try:
|
|
||||||
from docx import Document
|
|
||||||
except:
|
|
||||||
report_exception(chatbot, history,
|
|
||||||
a=f"解析项目: {txt}",
|
|
||||||
b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade python-docx pywin32```。")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 清空历史,以免输入溢出
|
|
||||||
history = []
|
|
||||||
|
|
||||||
# 检测输入参数,如没有给定输入参数,直接退出
|
|
||||||
if os.path.exists(txt):
|
|
||||||
project_folder = txt
|
|
||||||
else:
|
|
||||||
if txt == "": txt = '空空如也的输入栏'
|
|
||||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 搜索需要处理的文件清单
|
|
||||||
if txt.endswith('.docx') or txt.endswith('.doc'):
|
|
||||||
file_manifest = [txt]
|
|
||||||
else:
|
|
||||||
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.docx', recursive=True)] + \
|
|
||||||
[f for f in glob.glob(f'{project_folder}/**/*.doc', recursive=True)]
|
|
||||||
|
|
||||||
# 如果没找到任何文件
|
|
||||||
if len(file_manifest) == 0:
|
|
||||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.docx或doc文件: {txt}")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 开始正式执行任务
|
|
||||||
yield from 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
|
||||||
496
crazy_functions/批量文件询问.py
普通文件
496
crazy_functions/批量文件询问.py
普通文件
@@ -0,0 +1,496 @@
|
|||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Tuple, Dict, Generator
|
||||||
|
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||||
|
from crazy_functions.rag_fns.rag_file_support import extract_text
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
from toolbox import update_ui, CatchException, report_exception
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileFragment:
|
||||||
|
"""文件片段数据类,用于组织处理单元"""
|
||||||
|
file_path: str
|
||||||
|
content: str
|
||||||
|
rel_path: str
|
||||||
|
fragment_index: int
|
||||||
|
total_fragments: int
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDocumentSummarizer:
|
||||||
|
"""优化的文档总结器 - 批处理版本"""
|
||||||
|
|
||||||
|
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||||
|
"""初始化总结器"""
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
self.plugin_kwargs = plugin_kwargs
|
||||||
|
self.chatbot = chatbot
|
||||||
|
self.history = history
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.failed_files = []
|
||||||
|
self.file_summaries_map = {}
|
||||||
|
|
||||||
|
def _get_token_limit(self) -> int:
|
||||||
|
"""获取模型token限制"""
|
||||||
|
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||||
|
return max_token * 3 // 4
|
||||||
|
|
||||||
|
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
|
||||||
|
"""创建批处理输入"""
|
||||||
|
inputs_array = []
|
||||||
|
inputs_show_user_array = []
|
||||||
|
history_array = []
|
||||||
|
|
||||||
|
for frag in fragments:
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)},'
|
||||||
|
f'用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||||
|
f'文件内容是 ```{frag.content}```')
|
||||||
|
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
|
||||||
|
else:
|
||||||
|
i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)},'
|
||||||
|
f'内容是 ```{frag.content}```')
|
||||||
|
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
|
||||||
|
|
||||||
|
inputs_array.append(i_say)
|
||||||
|
inputs_show_user_array.append(i_say_show_user)
|
||||||
|
history_array.append([])
|
||||||
|
|
||||||
|
return inputs_array, inputs_show_user_array, history_array
|
||||||
|
|
||||||
|
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
|
||||||
|
"""包装了超时控制的文件处理函数"""
|
||||||
|
|
||||||
|
def timeout_handler():
|
||||||
|
thread = threading.current_thread()
|
||||||
|
if hasattr(thread, '_timeout_occurred'):
|
||||||
|
thread._timeout_occurred = True
|
||||||
|
|
||||||
|
# 设置超时标记
|
||||||
|
thread = threading.current_thread()
|
||||||
|
thread._timeout_occurred = False
|
||||||
|
|
||||||
|
# 设置超时定时器
|
||||||
|
timer = threading.Timer(self.watch_dog_patience, timeout_handler)
|
||||||
|
timer.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fp, project_folder = file_info
|
||||||
|
fragments = []
|
||||||
|
|
||||||
|
# 定期检查是否超时
|
||||||
|
def check_timeout():
|
||||||
|
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
|
||||||
|
raise TimeoutError("处理超时")
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "检查文件大小"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 文件大小检查
|
||||||
|
if os.path.getsize(fp) > self.max_file_size:
|
||||||
|
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
|
||||||
|
mutable_status[2] = "文件过大"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "提取文件内容"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
|
||||||
|
# 提取内容
|
||||||
|
content = extract_text(fp)
|
||||||
|
if content is None:
|
||||||
|
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
|
||||||
|
mutable_status[2] = "格式不支持"
|
||||||
|
return fragments
|
||||||
|
elif not content.strip():
|
||||||
|
self.failed_files.append((fp, "文件内容为空"))
|
||||||
|
mutable_status[2] = "内容为空"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "分割文本"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
|
||||||
|
# 分割文本
|
||||||
|
try:
|
||||||
|
paper_fragments = breakdown_text_to_satisfy_token_limit(
|
||||||
|
txt=content,
|
||||||
|
limit=self._get_token_limit(),
|
||||||
|
llm_model=self.llm_kwargs['llm_model']
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
|
||||||
|
mutable_status[2] = "分割失败"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 处理片段
|
||||||
|
rel_path = os.path.relpath(fp, project_folder)
|
||||||
|
for i, frag in enumerate(paper_fragments):
|
||||||
|
if frag.strip():
|
||||||
|
fragments.append(FileFragment(
|
||||||
|
file_path=fp,
|
||||||
|
content=frag,
|
||||||
|
rel_path=rel_path,
|
||||||
|
fragment_index=i,
|
||||||
|
total_fragments=len(paper_fragments)
|
||||||
|
))
|
||||||
|
|
||||||
|
mutable_status[2] = "处理完成"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
except TimeoutError as e:
|
||||||
|
self.failed_files.append((fp, "处理超时"))
|
||||||
|
mutable_status[2] = "处理超时"
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
self.failed_files.append((fp, f"处理失败:{str(e)}"))
|
||||||
|
mutable_status[2] = "处理异常"
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
timer.cancel()
|
||||||
|
|
||||||
|
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Generator, List
|
||||||
|
"""并行准备所有文件的处理片段"""
|
||||||
|
all_fragments = []
|
||||||
|
total_files = len(file_paths)
|
||||||
|
|
||||||
|
# 配置参数
|
||||||
|
self.refresh_interval = 0.2 # UI刷新间隔
|
||||||
|
self.watch_dog_patience = 5 # 看门狗超时时间
|
||||||
|
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
|
||||||
|
self.max_workers = min(32, len(file_paths)) # 最多32个线程
|
||||||
|
|
||||||
|
# 创建有超时控制的线程池
|
||||||
|
executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||||
|
|
||||||
|
# 用于跨线程状态传递的可变列表 - 增加文件名信息
|
||||||
|
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
|
||||||
|
|
||||||
|
# 创建文件处理任务
|
||||||
|
file_infos = [(fp, project_folder) for fp in file_paths]
|
||||||
|
|
||||||
|
# 提交所有任务,使用带超时控制的处理函数
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
self._process_single_file_with_timeout,
|
||||||
|
file_info,
|
||||||
|
mutable_status_array[i]
|
||||||
|
) for i, file_info in enumerate(file_infos)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 更新UI的计数器
|
||||||
|
cnt = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 监控任务执行
|
||||||
|
while True:
|
||||||
|
time.sleep(self.refresh_interval)
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
# 检查任务完成状态
|
||||||
|
worker_done = [f.done() for f in futures]
|
||||||
|
|
||||||
|
# 更新状态显示
|
||||||
|
status_str = ""
|
||||||
|
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
|
||||||
|
# 获取文件名(去掉路径)
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
if worker_done[i]:
|
||||||
|
status_str += f"文件 {file_name}: {desc}\n"
|
||||||
|
else:
|
||||||
|
status_str += f"文件 {file_name}: {status} {desc}\n"
|
||||||
|
|
||||||
|
# 更新UI
|
||||||
|
self.chatbot[-1] = [
|
||||||
|
"处理进度",
|
||||||
|
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
|
||||||
|
]
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 检查是否所有任务完成
|
||||||
|
if all(worker_done):
|
||||||
|
break
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 确保线程池正确关闭
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
|
||||||
|
# 收集结果
|
||||||
|
processed_files = 0
|
||||||
|
for future in futures:
|
||||||
|
try:
|
||||||
|
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
|
||||||
|
all_fragments.extend(fragments)
|
||||||
|
processed_files += 1
|
||||||
|
except concurrent.futures.TimeoutError:
|
||||||
|
# 处理获取结果超时
|
||||||
|
file_index = futures.index(future)
|
||||||
|
self.failed_files.append((file_paths[file_index], "结果获取超时"))
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他异常
|
||||||
|
file_index = futures.index(future)
|
||||||
|
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 最终进度更新
|
||||||
|
self.chatbot.append([
|
||||||
|
"文件处理完成",
|
||||||
|
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
|
||||||
|
])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
return all_fragments
|
||||||
|
|
||||||
|
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
|
||||||
|
"""批量处理文件片段"""
|
||||||
|
from collections import defaultdict
|
||||||
|
batch_size = 64 # 每批处理的片段数
|
||||||
|
max_retries = 3 # 最大重试次数
|
||||||
|
retry_delay = 5 # 重试延迟(秒)
|
||||||
|
results = defaultdict(list)
|
||||||
|
|
||||||
|
# 按批次处理
|
||||||
|
for i in range(0, len(fragments), batch_size):
|
||||||
|
batch = fragments[i:i + batch_size]
|
||||||
|
|
||||||
|
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
|
||||||
|
sys_prompt_array = ["请总结以下内容:"] * len(batch)
|
||||||
|
|
||||||
|
# 添加重试机制
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=inputs_array,
|
||||||
|
inputs_show_user_array=inputs_show_user_array,
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=history_array,
|
||||||
|
sys_prompt_array=sys_prompt_array,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
for j, frag in enumerate(batch):
|
||||||
|
summary = response_collection[j * 2 + 1]
|
||||||
|
if summary and summary.strip():
|
||||||
|
results[frag.rel_path].append({
|
||||||
|
'index': frag.fragment_index,
|
||||||
|
'summary': summary,
|
||||||
|
'total': frag.total_fragments
|
||||||
|
})
|
||||||
|
break # 成功处理,跳出重试循环
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry == max_retries - 1: # 最后一次重试失败
|
||||||
|
for frag in batch:
|
||||||
|
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
|
||||||
|
else:
|
||||||
|
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
|
||||||
|
"""准备最终总结请求"""
|
||||||
|
if not self.file_summaries_map:
|
||||||
|
return (["无可用的文件总结"], ["生成最终总结"], [[]])
|
||||||
|
|
||||||
|
summaries = list(self.file_summaries_map.values())
|
||||||
|
if all(not summary for summary in summaries):
|
||||||
|
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
|
||||||
|
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
|
||||||
|
else:
|
||||||
|
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
|
||||||
|
|
||||||
|
return ([i_say], [i_say], [summaries])
|
||||||
|
|
||||||
|
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||||
|
"""处理所有文件"""
|
||||||
|
total_files = len(file_paths)
|
||||||
|
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 1. 准备所有文件片段
|
||||||
|
# 在 process_files 函数中:
|
||||||
|
fragments = yield from self.prepare_fragments(project_folder, file_paths)
|
||||||
|
if not fragments:
|
||||||
|
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
|
||||||
|
return "没有可处理的文件内容"
|
||||||
|
|
||||||
|
# 2. 批量处理所有文件片段
|
||||||
|
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_summaries = yield from self._process_fragments_batch(fragments)
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
|
||||||
|
return "处理过程发生错误"
|
||||||
|
|
||||||
|
# 3. 为每个文件生成整体总结
|
||||||
|
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 处理每个文件的总结
|
||||||
|
for rel_path, summaries in file_summaries.items():
|
||||||
|
if len(summaries) > 1: # 多片段文件需要生成整体总结
|
||||||
|
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
|
||||||
|
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||||
|
else:
|
||||||
|
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary_texts = [s['summary'] for s in sorted_summaries]
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=[i_say],
|
||||||
|
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=[summary_texts],
|
||||||
|
sys_prompt_array=["你是一个优秀的助手,"],
|
||||||
|
)
|
||||||
|
self.file_summaries_map[rel_path] = response_collection[1]
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
|
||||||
|
self.file_summaries_map[rel_path] = "总结生成失败"
|
||||||
|
else: # 单片段文件直接使用其唯一的总结
|
||||||
|
self.file_summaries_map[rel_path] = summaries[0]['summary']
|
||||||
|
|
||||||
|
# 4. 生成最终总结
|
||||||
|
if total_files ==1:
|
||||||
|
return "文件数为1,此时不调用总结模块"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# 收集所有文件的总结用于生成最终总结
|
||||||
|
file_summaries_for_final = []
|
||||||
|
for rel_path, summary in self.file_summaries_map.items():
|
||||||
|
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
|
||||||
|
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
|
||||||
|
self.plugin_kwargs['advanced_arg'])
|
||||||
|
else:
|
||||||
|
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
|
||||||
|
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=[final_summary_prompt],
|
||||||
|
inputs_show_user_array=["生成最终总结报告"],
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=[file_summaries_for_final],
|
||||||
|
sys_prompt_array=["总结所有文件内容。"],
|
||||||
|
max_workers=1
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
|
||||||
|
return "生成总结失败"
|
||||||
|
|
||||||
|
def save_results(self, final_summary: str):
|
||||||
|
"""保存结果到文件"""
|
||||||
|
from toolbox import promote_file_to_downloadzone, write_history_to_file
|
||||||
|
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
|
||||||
|
import os
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
# 创建各种格式化器
|
||||||
|
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
|
||||||
|
result_files = []
|
||||||
|
|
||||||
|
# 保存 Markdown
|
||||||
|
md_content = md_formatter.create_document()
|
||||||
|
result_file_md = write_history_to_file(
|
||||||
|
history=[md_content], # 直接传入内容列表
|
||||||
|
file_basename=f"文档总结_{timestamp}.md"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_md)
|
||||||
|
|
||||||
|
# 保存 HTML
|
||||||
|
html_content = html_formatter.create_document()
|
||||||
|
result_file_html = write_history_to_file(
|
||||||
|
history=[html_content],
|
||||||
|
file_basename=f"文档总结_{timestamp}.html"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_html)
|
||||||
|
|
||||||
|
# 保存 Word
|
||||||
|
doc = word_formatter.create_document()
|
||||||
|
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
|
||||||
|
result_file_docx = os.path.join(
|
||||||
|
os.path.dirname(result_file_md),
|
||||||
|
f"文档总结_{timestamp}.docx"
|
||||||
|
)
|
||||||
|
doc.save(result_file_docx)
|
||||||
|
result_files.append(result_file_docx)
|
||||||
|
|
||||||
|
# 添加到下载区
|
||||||
|
for file in result_files:
|
||||||
|
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||||
|
|
||||||
|
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
|
||||||
|
@CatchException
|
||||||
|
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||||
|
history: List, system_prompt: str, user_request: str):
|
||||||
|
"""主函数 - 优化版本"""
|
||||||
|
# 初始化
|
||||||
|
import glob
|
||||||
|
import re
|
||||||
|
from crazy_functions.rag_fns.rag_file_support import supports_format
|
||||||
|
from toolbox import report_exception
|
||||||
|
|
||||||
|
summarizer = BatchDocumentSummarizer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||||
|
chatbot.append(["函数插件功能", f"作者:lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
# 验证输入路径
|
||||||
|
if not os.path.exists(txt):
|
||||||
|
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取文件列表
|
||||||
|
project_folder = txt
|
||||||
|
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||||
|
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||||
|
|
||||||
|
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||||
|
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||||
|
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||||
|
|
||||||
|
if not file_manifest:
|
||||||
|
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理所有文件并生成总结
|
||||||
|
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
summarizer.save_results(final_summary)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -180,6 +180,7 @@ version: '3'
|
|||||||
services:
|
services:
|
||||||
gpt_academic_with_latex:
|
gpt_academic_with_latex:
|
||||||
image: ghcr.io/binary-husky/gpt_academic_with_latex:master # (Auto Built by Dockerfile: docs/GithubAction+NoLocal+Latex)
|
image: ghcr.io/binary-husky/gpt_academic_with_latex:master # (Auto Built by Dockerfile: docs/GithubAction+NoLocal+Latex)
|
||||||
|
# 对于ARM64设备,请将以上镜像名称替换为 ghcr.io/binary-husky/gpt_academic_with_latex_arm:master
|
||||||
environment:
|
environment:
|
||||||
# 请查阅 `config.py` 以查看所有的配置信息
|
# 请查阅 `config.py` 以查看所有的配置信息
|
||||||
API_KEY: ' sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx '
|
API_KEY: ' sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx '
|
||||||
|
|||||||
@@ -1,35 +1,34 @@
|
|||||||
# 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
# 此Dockerfile适用于"无本地模型"的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
||||||
# - 1 修改 `config.py`
|
# - 1 修改 `config.py`
|
||||||
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
|
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
|
||||||
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
|
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
|
||||||
|
|
||||||
FROM fuqingxu/python311_texlive_ctex:latest
|
FROM menghuan1918/ubuntu_uv_ctex:latest
|
||||||
ENV PATH "$PATH:/usr/local/texlive/2022/bin/x86_64-linux"
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH "$PATH:/usr/local/texlive/2023/bin/x86_64-linux"
|
SHELL ["/bin/bash", "-c"]
|
||||||
ENV PATH "$PATH:/usr/local/texlive/2024/bin/x86_64-linux"
|
|
||||||
ENV PATH "$PATH:/usr/local/texlive/2025/bin/x86_64-linux"
|
|
||||||
ENV PATH "$PATH:/usr/local/texlive/2026/bin/x86_64-linux"
|
|
||||||
|
|
||||||
# 指定路径
|
|
||||||
WORKDIR /gpt
|
WORKDIR /gpt
|
||||||
|
|
||||||
RUN pip3 install openai numpy arxiv rich
|
# 先复制依赖文件
|
||||||
RUN pip3 install colorama Markdown pygments pymupdf
|
COPY requirements.txt .
|
||||||
RUN pip3 install python-docx pdfminer
|
|
||||||
RUN pip3 install nougat-ocr
|
|
||||||
|
|
||||||
# 装载项目文件
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
|
|
||||||
# 安装依赖
|
# 安装依赖
|
||||||
RUN pip3 install -r requirements.txt
|
RUN pip install --break-system-packages openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
|
||||||
|
&& pip install --break-system-packages -r requirements.txt \
|
||||||
|
&& if [ "$(uname -m)" = "x86_64" ]; then \
|
||||||
|
pip install --break-system-packages nougat-ocr; \
|
||||||
|
fi \
|
||||||
|
&& pip cache purge \
|
||||||
|
&& rm -rf /root/.cache/pip/*
|
||||||
|
|
||||||
# edge-tts需要的依赖
|
# 创建非root用户
|
||||||
RUN apt update && apt install ffmpeg -y
|
RUN useradd -m gptuser && chown -R gptuser /gpt
|
||||||
|
USER gptuser
|
||||||
|
|
||||||
|
# 最后才复制代码文件,这样代码更新时只需重建最后几层,可以大幅减少docker pull所需的大小
|
||||||
|
COPY --chown=gptuser:gptuser . .
|
||||||
|
|
||||||
# 可选步骤,用于预热模块
|
# 可选步骤,用于预热模块
|
||||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||||
|
|
||||||
# 启动
|
# 启动
|
||||||
CMD ["python3", "-u", "main.py"]
|
CMD ["python3", "-u", "main.py"]
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
# 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
|
||||||
# - 1 修改 `config.py`
|
|
||||||
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
|
|
||||||
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
|
|
||||||
|
|
||||||
FROM menghuan1918/ubuntu_uv_ctex:latest
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
|
||||||
SHELL ["/bin/bash", "-c"]
|
|
||||||
WORKDIR /gpt
|
|
||||||
COPY . .
|
|
||||||
RUN /root/.cargo/bin/uv venv --seed \
|
|
||||||
&& source .venv/bin/activate \
|
|
||||||
&& /root/.cargo/bin/uv pip install openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
|
|
||||||
&& /root/.cargo/bin/uv pip install -r requirements.txt \
|
|
||||||
&& /root/.cargo/bin/uv clean
|
|
||||||
|
|
||||||
# 对齐python3
|
|
||||||
RUN rm -f /usr/bin/python3 && ln -s /gpt/.venv/bin/python /usr/bin/python3
|
|
||||||
RUN rm -f /usr/bin/python && ln -s /gpt/.venv/bin/python /usr/bin/python
|
|
||||||
|
|
||||||
# 可选步骤,用于预热模块
|
|
||||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
|
||||||
|
|
||||||
# 启动
|
|
||||||
CMD ["python3", "-u", "main.py"]
|
|
||||||
216
instruction.txt
普通文件
216
instruction.txt
普通文件
@@ -0,0 +1,216 @@
|
|||||||
|
|
||||||
|
1、GPT Academic 项目结构
|
||||||
|
.
|
||||||
|
├── Dockerfile
|
||||||
|
├── LICENSE
|
||||||
|
├── README.md
|
||||||
|
├── check_proxy.py
|
||||||
|
├── config.py
|
||||||
|
├── config_private.py
|
||||||
|
├── core_functional.py
|
||||||
|
├── crazy_functional.py
|
||||||
|
├── crazy_functions
|
||||||
|
│ ├── Arxiv_论文对话.py
|
||||||
|
│ ├── Conversation_To_File.py
|
||||||
|
│ ├── Image_Generate.py
|
||||||
|
│ ├── Image_Generate_Wrap.py
|
||||||
|
│ ├── Internet_GPT.py
|
||||||
|
│ ├── Internet_GPT_Wrap.py
|
||||||
|
│ ├── Latex_Function.py
|
||||||
|
│ ├── Latex_Function_Wrap.py
|
||||||
|
│ ├── Latex全文润色.py
|
||||||
|
│ ├── Latex全文翻译.py
|
||||||
|
│ ├── Markdown_Translate.py
|
||||||
|
│ ├── PDF_Translate.py
|
||||||
|
│ ├── PDF_Translate_Wrap.py
|
||||||
|
│ ├── Rag_Interface.py
|
||||||
|
│ ├── Social_Helper.py
|
||||||
|
│ ├── SourceCode_Analyse.py
|
||||||
|
│ ├── SourceCode_Comment.py
|
||||||
|
│ ├── SourceCode_Comment_Wrap.py
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ │ ├── auto_agent.py
|
||||||
|
│ │ ├── echo_agent.py
|
||||||
|
│ │ ├── general.py
|
||||||
|
│ │ ├── persistent.py
|
||||||
|
│ │ ├── pipe.py
|
||||||
|
│ │ ├── python_comment_agent.py
|
||||||
|
│ │ ├── python_comment_compare.html
|
||||||
|
│ │ └── watchdog.py
|
||||||
|
│ ├── ast_fns
|
||||||
|
│ │ └── comment_remove.py
|
||||||
|
│ ├── chatglm微调工具.py
|
||||||
|
│ ├── crazy_utils.py
|
||||||
|
│ ├── diagram_fns
|
||||||
|
│ │ └── file_tree.py
|
||||||
|
│ ├── game_fns
|
||||||
|
│ │ ├── game_ascii_art.py
|
||||||
|
│ │ ├── game_interactive_story.py
|
||||||
|
│ │ └── game_utils.py
|
||||||
|
│ ├── gen_fns
|
||||||
|
│ │ └── gen_fns_shared.py
|
||||||
|
│ ├── ipc_fns
|
||||||
|
│ │ └── mp.py
|
||||||
|
│ ├── json_fns
|
||||||
|
│ │ ├── pydantic_io.py
|
||||||
|
│ │ └── select_tool.py
|
||||||
|
│ ├── latex_fns
|
||||||
|
│ │ ├── latex_actions.py
|
||||||
|
│ │ ├── latex_pickle_io.py
|
||||||
|
│ │ └── latex_toolbox.py
|
||||||
|
│ ├── live_audio
|
||||||
|
│ │ ├── aliyunASR.py
|
||||||
|
│ │ └── audio_io.py
|
||||||
|
│ ├── multi_stage
|
||||||
|
│ │ └── multi_stage_utils.py
|
||||||
|
│ ├── rag_essay_fns
|
||||||
|
│ │ └── multi_stage_utils.py
|
||||||
|
│ ├── pdf_fns
|
||||||
|
│ │ ├── breakdown_txt.py
|
||||||
|
│ │ ├── parse_pdf.py
|
||||||
|
│ │ ├── parse_pdf_grobid.py
|
||||||
|
│ │ ├── parse_pdf_legacy.py
|
||||||
|
│ │ ├── parse_pdf_via_doc2x.py
|
||||||
|
│ │ ├── parse_word.py
|
||||||
|
│ │ ├── report_gen_html.py
|
||||||
|
│ │ ├── report_template.html
|
||||||
|
│ │ └── report_template_v2.html
|
||||||
|
│ ├── plugin_template
|
||||||
|
│ │ └── plugin_class_template.py
|
||||||
|
│ ├── prompts
|
||||||
|
│ │ └── internet.py
|
||||||
|
│ ├── rag_fns
|
||||||
|
│ │ ├── llama_index_worker.py
|
||||||
|
│ │ ├── milvus_worker.py
|
||||||
|
│ │ ├── rag_file_support.py
|
||||||
|
│ │ └── vector_store_index.py
|
||||||
|
│ ├── vector_fns
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── general_file_loader.py
|
||||||
|
│ │ └── vector_database.py
|
||||||
|
│ ├── vt_fns
|
||||||
|
│ │ ├── vt_call_plugin.py
|
||||||
|
│ │ ├── vt_modify_config.py
|
||||||
|
│ │ └── vt_state.py
|
||||||
|
│ ├── 下载arxiv论文翻译摘要.py
|
||||||
|
│ ├── 互动小游戏.py
|
||||||
|
│ ├── 交互功能函数模板.py
|
||||||
|
│ ├── 函数动态生成.py
|
||||||
|
│ ├── 命令行助手.py
|
||||||
|
│ ├── 多智能体.py
|
||||||
|
│ ├── 总结word文档.py
|
||||||
|
│ ├── 总结音视频.py
|
||||||
|
│ ├── 批量总结PDF文档.py
|
||||||
|
│ ├── 批量总结PDF文档pdfminer.py
|
||||||
|
│ ├── 批量文件询问.py
|
||||||
|
│ ├── 批量翻译PDF文档_NOUGAT.py
|
||||||
|
│ ├── 数学动画生成manim.py
|
||||||
|
│ ├── 理解PDF文档内容.py
|
||||||
|
│ ├── 生成函数注释.py
|
||||||
|
│ ├── 生成多种Mermaid图表.py
|
||||||
|
│ ├── 知识库问答.py
|
||||||
|
│ ├── 联网的ChatGPT.py
|
||||||
|
│ ├── 联网的ChatGPT_bing版.py
|
||||||
|
│ ├── 虚空终端.py
|
||||||
|
│ ├── 解析JupyterNotebook.py
|
||||||
|
│ ├── 询问多个大语言模型.py
|
||||||
|
│ ├── 语音助手.py
|
||||||
|
│ ├── 读文章写摘要.py
|
||||||
|
│ ├── 谷歌检索小助手.py
|
||||||
|
│ ├── 辅助功能.py
|
||||||
|
│ └── 高级功能函数模板.py
|
||||||
|
├── docker-compose.yml
|
||||||
|
├── instruction.txt
|
||||||
|
├── main.py
|
||||||
|
├── multi_language.py
|
||||||
|
├── requirements.txt
|
||||||
|
├── shared_utils
|
||||||
|
│ ├── advanced_markdown_format.py
|
||||||
|
│ ├── char_visual_effect.py
|
||||||
|
│ ├── colorful.py
|
||||||
|
│ ├── config_loader.py
|
||||||
|
│ ├── connect_void_terminal.py
|
||||||
|
│ ├── cookie_manager.py
|
||||||
|
│ ├── fastapi_server.py
|
||||||
|
│ ├── handle_upload.py
|
||||||
|
│ ├── key_pattern_manager.py
|
||||||
|
│ ├── logging.py
|
||||||
|
│ ├── map_names.py
|
||||||
|
│ └── text_mask.py
|
||||||
|
├── toolbox.py
|
||||||
|
└── version
|
||||||
|
|
||||||
|
2、light_rag的实现方案路径为crazy_functions/rag_fns/LightRAG,主要功能实现文件为operate.py,rag使用到的其他文件为prompt.py、base.py、storage.py、utils.py,请参考实现方案实现插件功能。light_rag的使用案例可以参考crazy_functions/rag_fns/LightRAG/examples路径下的lightrag_hf_demo.py、lightrag_lmdeploy_demo.py:
|
||||||
|
路径目录结构为
|
||||||
|
|
||||||
|
├── README.md
|
||||||
|
├── examples
|
||||||
|
│ ├── batch_eval.py
|
||||||
|
│ ├── generate_query.py
|
||||||
|
│ ├── graph_visual_with_html.py
|
||||||
|
│ ├── graph_visual_with_neo4j.py
|
||||||
|
│ ├── lightrag_azure_openai_demo.py
|
||||||
|
│ ├── lightrag_bedrock_demo.py
|
||||||
|
│ ├── lightrag_hf_demo.py
|
||||||
|
│ ├── lightrag_ollama_demo.py
|
||||||
|
│ ├── lightrag_openai_compatible_demo.py
|
||||||
|
│ ├── lightrag_openai_demo.py
|
||||||
|
│ └── vram_management_demo.py
|
||||||
|
├── lightrag
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── base.py
|
||||||
|
│ ├── lightrag.py
|
||||||
|
│ ├── llm.py
|
||||||
|
│ ├── operate.py
|
||||||
|
│ ├── prompt.py
|
||||||
|
│ ├── storage.py
|
||||||
|
│ └── utils.py
|
||||||
|
├── reproduce
|
||||||
|
│ ├── Step_0.py
|
||||||
|
│ ├── Step_1.py
|
||||||
|
│ ├── Step_1_openai_compatible.py
|
||||||
|
│ ├── Step_2.py
|
||||||
|
│ ├── Step_3.py
|
||||||
|
│ └── Step_3_openai_compatible.py
|
||||||
|
├── requirements.txt
|
||||||
|
└── setup.py
|
||||||
|
|
||||||
|
|
||||||
|
3、我需要开发一个rag插件,请帮我实现一个插件,插件的名称是rag论文总结,插件主入口在crazy_functions/Arxiv_论文对话.py中的Rag论文对话函数,插件的功能步骤分为文件处理和RAG两个步骤,以下是具体的一些要求:
|
||||||
|
I. 函数头如下:
|
||||||
|
@CatchException
|
||||||
|
def rag论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||||
|
history: List, system_prompt: str, user_request: str):
|
||||||
|
II. 函数返回可参考crazy_functions/批量文件询问.py中的“批量文件询问”函数,主要采用yield方式
|
||||||
|
|
||||||
|
3、对于RAG,我希望采用light_rag的方案,参考已有方案其主要的功能实现是:
|
||||||
|
主要功能包括:
|
||||||
|
a. 分别为project和arxiv创建rag_handler,project类的fragment类内容为
|
||||||
|
@dataclass
|
||||||
|
class DocFragment:
|
||||||
|
"""文本片段数据类"""
|
||||||
|
file_path: str # 原始文件路径
|
||||||
|
content: str # 片段内容
|
||||||
|
segment_index: int # 片段序号
|
||||||
|
total_segments: int # 总片段数
|
||||||
|
rel_path: str # 相对路径
|
||||||
|
arxiv的fragment内容为:
|
||||||
|
@dataclass
|
||||||
|
class ArxivFragment:
|
||||||
|
"""Arxiv论文片段数据类"""
|
||||||
|
file_path: str
|
||||||
|
content: str
|
||||||
|
segment_index: int
|
||||||
|
total_segments: int
|
||||||
|
rel_path: str
|
||||||
|
segment_type: str
|
||||||
|
title: str
|
||||||
|
abstract: str
|
||||||
|
section: str
|
||||||
|
is_appendix: bool
|
||||||
|
b 如果目录下不存在抽取好的实体或关系的摘要,利用`_handle_entity_relation_summary`函数对d步骤生成的文本块进行实体或关系的摘要,并将其存储在project或者arxiv的路径下,路径为获取fragment.file_path的前三级目录(按照“/”区分每一级),如果原目录存在抽取好的,请直接使用,不再重复抽取。
|
||||||
|
f 利用`_handle_single_entity_extraction` 和 `_handle_single_relationship_extraction`:从记录中提取单个实体或关系信息。
|
||||||
|
g `_merge_nodes_then_upsert` 和 `_merge_edges_then_upsert`:合并并插入节点或边。
|
||||||
|
h `extract_entities`:处理多个文本块,提取实体和关系,并存储在知识图谱和向量数据库中。
|
||||||
|
i `local_query`:根据查询提取关键词并生成响应。
|
||||||
|
|
||||||
@@ -385,6 +385,14 @@ model_info = {
|
|||||||
"tokenizer": tokenizer_gpt35,
|
"tokenizer": tokenizer_gpt35,
|
||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
|
"glm-4-plus":{
|
||||||
|
"fn_with_ui": zhipu_ui,
|
||||||
|
"fn_without_ui": zhipu_noui,
|
||||||
|
"endpoint": None,
|
||||||
|
"max_token": 10124 * 8,
|
||||||
|
"tokenizer": tokenizer_gpt35,
|
||||||
|
"token_cnt": get_token_num_gpt35,
|
||||||
|
},
|
||||||
|
|
||||||
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
|
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
|
||||||
"api2d-gpt-4": {
|
"api2d-gpt-4": {
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate
|
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate
|
||||||
|
|||||||
@@ -12,12 +12,14 @@ transformers>=4.27.1,<4.42
|
|||||||
scipdf_parser>=0.52
|
scipdf_parser>=0.52
|
||||||
spacy==3.7.4
|
spacy==3.7.4
|
||||||
anthropic>=0.18.1
|
anthropic>=0.18.1
|
||||||
|
sentence-transformers
|
||||||
python-markdown-math
|
python-markdown-math
|
||||||
pymdown-extensions
|
pymdown-extensions
|
||||||
websocket-client
|
websocket-client
|
||||||
beautifulsoup4
|
beautifulsoup4
|
||||||
prompt_toolkit
|
prompt_toolkit
|
||||||
latex2mathml
|
latex2mathml
|
||||||
|
scikit-learn
|
||||||
python-docx
|
python-docx
|
||||||
mdtex2html
|
mdtex2html
|
||||||
dashscope
|
dashscope
|
||||||
@@ -43,4 +45,4 @@ llama-index-embeddings-azure-openai==0.1.10
|
|||||||
llama-index-embeddings-openai==0.1.10
|
llama-index-embeddings-openai==0.1.10
|
||||||
llama-parse==0.4.9
|
llama-parse==0.4.9
|
||||||
mdit-py-plugins>=0.3.3
|
mdit-py-plugins>=0.3.3
|
||||||
linkify-it-py==2.0.3
|
linkify-it-py==2.0.3
|
||||||
|
|||||||
@@ -138,7 +138,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
app_block.is_sagemaker = False
|
app_block.is_sagemaker = False
|
||||||
|
|
||||||
gradio_app = App.create_app(app_block)
|
gradio_app = App.create_app(app_block)
|
||||||
|
for route in list(gradio_app.router.routes):
|
||||||
|
if route.path == "/proxy={url_path:path}":
|
||||||
|
gradio_app.router.routes.remove(route)
|
||||||
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
||||||
if len(AUTHENTICATION) > 0:
|
if len(AUTHENTICATION) > 0:
|
||||||
dependencies = []
|
dependencies = []
|
||||||
@@ -154,9 +156,13 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
async def file(path_or_url: str, request: fastapi.Request):
|
async def file(path_or_url: str, request: fastapi.Request):
|
||||||
if len(AUTHENTICATION) > 0:
|
if not _authorize_user(path_or_url, request, gradio_app):
|
||||||
if not _authorize_user(path_or_url, request, gradio_app):
|
return "越权访问!"
|
||||||
return "越权访问!"
|
stripped = path_or_url.lstrip().lower()
|
||||||
|
if stripped.startswith("https://") or stripped.startswith("http://"):
|
||||||
|
return "账户密码授权模式下, 禁止链接!"
|
||||||
|
if '../' in stripped:
|
||||||
|
return "非法路径!"
|
||||||
return await endpoint(path_or_url, request)
|
return await endpoint(path_or_url, request)
|
||||||
|
|
||||||
from fastapi import Request, status
|
from fastapi import Request, status
|
||||||
@@ -167,6 +173,26 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
response.delete_cookie('access-token')
|
response.delete_cookie('access-token')
|
||||||
response.delete_cookie('access-token-unsecure')
|
response.delete_cookie('access-token-unsecure')
|
||||||
return response
|
return response
|
||||||
|
else:
|
||||||
|
dependencies = []
|
||||||
|
endpoint = None
|
||||||
|
for route in list(gradio_app.router.routes):
|
||||||
|
if route.path == "/file/{path:path}":
|
||||||
|
gradio_app.router.routes.remove(route)
|
||||||
|
if route.path == "/file={path_or_url:path}":
|
||||||
|
dependencies = route.dependencies
|
||||||
|
endpoint = route.endpoint
|
||||||
|
gradio_app.router.routes.remove(route)
|
||||||
|
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
|
||||||
|
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
|
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
|
async def file(path_or_url: str, request: fastapi.Request):
|
||||||
|
stripped = path_or_url.lstrip().lower()
|
||||||
|
if stripped.startswith("https://") or stripped.startswith("http://"):
|
||||||
|
return "账户密码授权模式下, 禁止链接!"
|
||||||
|
if '../' in stripped:
|
||||||
|
return "非法路径!"
|
||||||
|
return await endpoint(path_or_url, request)
|
||||||
|
|
||||||
# --- --- enable TTS (text-to-speech) functionality --- ---
|
# --- --- enable TTS (text-to-speech) functionality --- ---
|
||||||
TTS_TYPE = get_conf("TTS_TYPE")
|
TTS_TYPE = get_conf("TTS_TYPE")
|
||||||
|
|||||||
7
tests/test_doc2x.py
普通文件
7
tests/test_doc2x.py
普通文件
@@ -0,0 +1,7 @@
|
|||||||
|
import init_test
|
||||||
|
|
||||||
|
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
|
||||||
|
|
||||||
|
# 解析PDF_DOC2X_转Latex("gpt_log/arxiv_cache_old/2410.10819/workfolder/merge.pdf")
|
||||||
|
# 解析PDF_DOC2X_转Latex("gpt_log/arxiv_cache_ooo/2410.07095/workfolder/merge.pdf")
|
||||||
|
解析PDF_DOC2X_转Latex("2410.11190v2.pdf")
|
||||||
在新工单中引用
屏蔽一个用户