镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 22:46:48 +00:00
比较提交
139 次代码提交
version3.8
...
boyin_rag
| 作者 | SHA1 | 提交日期 | |
|---|---|---|---|
|
|
2f946f3e6c | ||
|
|
51ea7f3b5e | ||
|
|
795a6a9333 | ||
|
|
3beb22a347 | ||
|
|
b3aef6b393 | ||
|
|
cf51d4b205 | ||
|
|
bd9c88e896 | ||
|
|
27958b9030 | ||
|
|
9b9d77eded | ||
|
|
50dbff3a14 | ||
|
|
e1dc600030 | ||
|
|
6557c3822a | ||
|
|
81ab9f91a4 | ||
|
|
241c9641bb | ||
|
|
b2d6536974 | ||
|
|
12be7c16e9 | ||
|
|
724940a9d8 | ||
|
|
ea4cd95645 | ||
|
|
f8b60870e9 | ||
|
|
cbef9a908c | ||
|
|
21626a44d5 | ||
|
|
dd902e9519 | ||
|
|
68aa846a89 | ||
|
|
b8617921f4 | ||
|
|
c6687646e4 | ||
|
|
bfa72fb4cf | ||
|
|
61676d0536 | ||
|
|
df2ef7940c | ||
|
|
0afd27deca | ||
|
|
91f5e6b8f7 | ||
|
|
c10f2b45e5 | ||
|
|
7e2ede2d12 | ||
|
|
ec10e2a3ac | ||
|
|
4f0851f703 | ||
|
|
2821f27756 | ||
|
|
7474d43433 | ||
|
|
83489f9acf | ||
|
|
36e50d490d | ||
|
|
9172337695 | ||
|
|
180550b8f0 | ||
|
|
7497dcb852 | ||
|
|
5dab7b2290 | ||
|
|
23ef2ffb22 | ||
|
|
848d0f65c7 | ||
|
|
f0b0364f74 | ||
|
|
89dc6c7265 | ||
|
|
69f3755682 | ||
|
|
4727113243 | ||
|
|
21111d3bd0 | ||
|
|
be9aead04a | ||
|
|
701018f48c | ||
|
|
8733c4e1e9 | ||
|
|
8498ddf6bf | ||
|
|
3c3293818d | ||
|
|
310122f5a7 | ||
|
|
c83bf214d0 | ||
|
|
e34c49dce5 | ||
|
|
3890467c84 | ||
|
|
074b3c9828 | ||
|
|
b8e8457a01 | ||
|
|
2c93a24d7e | ||
|
|
e9af6ef3a0 | ||
|
|
5ae8981dbb | ||
|
|
9adc0ade71 | ||
|
|
bbcdd9aa71 | ||
|
|
cdfe38d296 | ||
|
|
159f628dfe | ||
|
|
adbed044e4 | ||
|
|
2fe5febaf0 | ||
|
|
5888d038aa | ||
|
|
ee8213e936 | ||
|
|
a57dcbcaeb | ||
|
|
b812392a9d | ||
|
|
f54d8e559a | ||
|
|
fce4fa1ec7 | ||
|
|
d13f1e270c | ||
|
|
85cf3d08eb | ||
|
|
584e747565 | ||
|
|
02ba653c19 | ||
|
|
e68fc2bc69 | ||
|
|
f695d7f1da | ||
|
|
2d12b5b27d | ||
|
|
679352d896 | ||
|
|
12c9ab1e33 | ||
|
|
a4bcd262f9 | ||
|
|
da4a5efc49 | ||
|
|
9ac450cfb6 | ||
|
|
172f9e220b | ||
|
|
748e31102f | ||
|
|
a28b7d8475 | ||
|
|
7d3ed36899 | ||
|
|
a7bc5fa357 | ||
|
|
4f5dd9ebcf | ||
|
|
427feb99d8 | ||
|
|
a01ca93362 | ||
|
|
97eef45ab7 | ||
|
|
0c0e2acb9b | ||
|
|
9fba8e0142 | ||
|
|
7d7867fb64 | ||
|
|
7ea791d83a | ||
|
|
f9dbaa39fb | ||
|
|
bbc2288c5b | ||
|
|
64ab916838 | ||
|
|
8fe559da9f | ||
|
|
09fd22091a | ||
|
|
df717f8bba | ||
|
|
e296719b23 | ||
|
|
2f343179a2 | ||
|
|
4d9604f2e9 | ||
|
|
597c320808 | ||
|
|
18290fd138 | ||
|
|
bbf9e9f868 | ||
|
|
0d0575a639 | ||
|
|
aa1f967dd7 | ||
|
|
0d082327c8 | ||
|
|
80acd9c875 | ||
|
|
17cd4f8210 | ||
|
|
4e041e1d4e | ||
|
|
7ef39770c7 | ||
|
|
8222f638cf | ||
|
|
ab32c314ab | ||
|
|
dcfed97054 | ||
|
|
dd66ca26f7 | ||
|
|
8b91d2ac0a | ||
|
|
e4e00b713f | ||
|
|
710a65522c | ||
|
|
34784c1d40 | ||
|
|
80b1a6f99b | ||
|
|
08c3c56f53 | ||
|
|
294716c832 | ||
|
|
16f4fd636e | ||
|
|
e07caf7a69 | ||
|
|
a95b3daab9 | ||
|
|
4873e9dfdc | ||
|
|
a119ab36fe | ||
|
|
f9384e4e5f | ||
|
|
6fe5f6ee6e | ||
|
|
068d753426 | ||
|
|
5010537f3c |
44
.github/workflows/build-with-jittorllms.yml
vendored
44
.github/workflows/build-with-jittorllms.yml
vendored
@@ -1,44 +0,0 @@
|
||||
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
||||
name: build-with-jittorllms
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'master'
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}_jittorllms
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: docs/GithubAction+JittorLLMs
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -1,14 +1,14 @@
|
||||
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
||||
name: build-with-all-capacity-beta
|
||||
name: build-with-latex-arm
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'master'
|
||||
- "master"
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}_with_all_capacity_beta
|
||||
IMAGE_NAME: ${{ github.repository }}_with_latex_arm
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
@@ -18,11 +18,17 @@ jobs:
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
@@ -35,10 +41,11 @@ jobs:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: docs/GithubAction+AllCapacityBeta
|
||||
platforms: linux/arm64
|
||||
file: docs/GithubAction+NoLocal+Latex
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -160,3 +160,5 @@ test.*
|
||||
temp.*
|
||||
objdump*
|
||||
*.min.*.js
|
||||
TODO
|
||||
*.cursorrules
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
> [!IMPORTANT]
|
||||
> 2024.6.1: 版本3.80加入插件二级菜单功能(详见wiki)
|
||||
> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/file/d/19U_hsLoMrjOlQSzYS3pzWX9fTzyusArP/view?usp=sharing)的文件服务器
|
||||
> 2024.10.8: 版本3.90加入对llama-index的初步支持,版本3.80加入插件二级菜单功能(详见wiki)
|
||||
> 2024.5.1: 加入Doc2x翻译PDF论文的功能,[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
|
||||
> 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型! SoVits语音克隆模块,[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
|
||||
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
|
||||
|
||||
134
check_proxy.py
134
check_proxy.py
@@ -1,48 +1,77 @@
|
||||
from loguru import logger
|
||||
|
||||
def check_proxy(proxies, return_ip=False):
|
||||
"""
|
||||
检查代理配置并返回结果。
|
||||
|
||||
Args:
|
||||
proxies (dict): 包含http和https代理配置的字典。
|
||||
return_ip (bool, optional): 是否返回代理的IP地址。默认为False。
|
||||
|
||||
Returns:
|
||||
str or None: 检查的结果信息或代理的IP地址(如果`return_ip`为True)。
|
||||
"""
|
||||
import requests
|
||||
proxies_https = proxies['https'] if proxies is not None else '无'
|
||||
ip = None
|
||||
try:
|
||||
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
|
||||
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4) # ⭐ 执行GET请求以获取代理信息
|
||||
data = response.json()
|
||||
if 'country_name' in data:
|
||||
country = data['country_name']
|
||||
result = f"代理配置 {proxies_https}, 代理所在地:{country}"
|
||||
if 'ip' in data: ip = data['ip']
|
||||
if 'ip' in data:
|
||||
ip = data['ip']
|
||||
elif 'error' in data:
|
||||
alternative, ip = _check_with_backup_source(proxies)
|
||||
alternative, ip = _check_with_backup_source(proxies) # ⭐ 调用备用方法检查代理配置
|
||||
if alternative is None:
|
||||
result = f"代理配置 {proxies_https}, 代理所在地:未知,IP查询频率受限"
|
||||
else:
|
||||
result = f"代理配置 {proxies_https}, 代理所在地:{alternative}"
|
||||
else:
|
||||
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
|
||||
|
||||
if not return_ip:
|
||||
print(result)
|
||||
logger.warning(result)
|
||||
return result
|
||||
else:
|
||||
return ip
|
||||
except:
|
||||
result = f"代理配置 {proxies_https}, 代理所在地查询超时,代理可能无效"
|
||||
if not return_ip:
|
||||
print(result)
|
||||
logger.warning(result)
|
||||
return result
|
||||
else:
|
||||
return ip
|
||||
|
||||
def _check_with_backup_source(proxies):
|
||||
"""
|
||||
通过备份源检查代理,并获取相应信息。
|
||||
|
||||
Args:
|
||||
proxies (dict): 包含代理信息的字典。
|
||||
|
||||
Returns:
|
||||
tuple: 代理信息(geo)和IP地址(ip)的元组。
|
||||
"""
|
||||
import random, string, requests
|
||||
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
|
||||
try:
|
||||
res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json()
|
||||
res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json() # ⭐ 执行代理检查和备份源请求
|
||||
return res_json['dns']['geo'], res_json['dns']['ip']
|
||||
except:
|
||||
return None, None
|
||||
|
||||
def backup_and_download(current_version, remote_version):
|
||||
"""
|
||||
一键更新协议:备份和下载
|
||||
一键更新协议:备份当前版本,下载远程版本并解压缩。
|
||||
|
||||
Args:
|
||||
current_version (str): 当前版本号。
|
||||
remote_version (str): 远程版本号。
|
||||
|
||||
Returns:
|
||||
str: 新版本目录的路径。
|
||||
"""
|
||||
from toolbox import get_conf
|
||||
import shutil
|
||||
@@ -59,7 +88,7 @@ def backup_and_download(current_version, remote_version):
|
||||
proxies = get_conf('proxies')
|
||||
try: r = requests.get('https://github.com/binary-husky/chatgpt_academic/archive/refs/heads/master.zip', proxies=proxies, stream=True)
|
||||
except: r = requests.get('https://public.agent-matrix.com/publish/master.zip', proxies=proxies, stream=True)
|
||||
zip_file_path = backup_dir+'/master.zip'
|
||||
zip_file_path = backup_dir+'/master.zip' # ⭐ 保存备份文件的路径
|
||||
with open(zip_file_path, 'wb+') as f:
|
||||
f.write(r.content)
|
||||
dst_path = new_version_dir
|
||||
@@ -75,6 +104,17 @@ def backup_and_download(current_version, remote_version):
|
||||
def patch_and_restart(path):
|
||||
"""
|
||||
一键更新协议:覆盖和重启
|
||||
|
||||
Args:
|
||||
path (str): 新版本代码所在的路径
|
||||
|
||||
注意事项:
|
||||
如果您的程序没有使用config_private.py私密配置文件,则会将config.py重命名为config_private.py以避免配置丢失。
|
||||
|
||||
更新流程:
|
||||
- 复制最新版本代码到当前目录
|
||||
- 更新pip包依赖
|
||||
- 如果更新失败,则提示手动安装依赖库并重启
|
||||
"""
|
||||
from distutils import dir_util
|
||||
import shutil
|
||||
@@ -82,33 +122,44 @@ def patch_and_restart(path):
|
||||
import sys
|
||||
import time
|
||||
import glob
|
||||
from shared_utils.colorful import print亮黄, print亮绿, print亮红
|
||||
# if not using config_private, move origin config.py as config_private.py
|
||||
from shared_utils.colorful import log亮黄, log亮绿, log亮红
|
||||
|
||||
if not os.path.exists('config_private.py'):
|
||||
print亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
||||
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
||||
'另外您可以随时在history子文件夹下找回旧版的程序。')
|
||||
shutil.copyfile('config.py', 'config_private.py')
|
||||
|
||||
path_new_version = glob.glob(path + '/*-master')[0]
|
||||
dir_util.copy_tree(path_new_version, './')
|
||||
print亮绿('代码已经更新,即将更新pip包依赖……')
|
||||
for i in reversed(range(5)): time.sleep(1); print(i)
|
||||
dir_util.copy_tree(path_new_version, './') # ⭐ 将最新版本代码复制到当前目录
|
||||
|
||||
log亮绿('代码已经更新,即将更新pip包依赖……')
|
||||
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
|
||||
except:
|
||||
print亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||
print亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
||||
print亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||
print(' ------------------------------ -----------------------------------')
|
||||
for i in reversed(range(8)): time.sleep(1); print(i)
|
||||
os.execl(sys.executable, sys.executable, *sys.argv)
|
||||
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||
|
||||
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
||||
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||
log亮绿(' ------------------------------ -----------------------------------')
|
||||
|
||||
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
|
||||
os.execl(sys.executable, sys.executable, *sys.argv) # 重启程序
|
||||
|
||||
|
||||
def get_current_version():
|
||||
"""
|
||||
获取当前的版本号。
|
||||
|
||||
Returns:
|
||||
str: 当前的版本号。如果无法获取版本号,则返回空字符串。
|
||||
"""
|
||||
import json
|
||||
try:
|
||||
with open('./version', 'r', encoding='utf8') as f:
|
||||
current_version = json.loads(f.read())['version']
|
||||
current_version = json.loads(f.read())['version'] # ⭐ 从读取的json数据中提取版本号
|
||||
except:
|
||||
current_version = ""
|
||||
return current_version
|
||||
@@ -117,6 +168,12 @@ def get_current_version():
|
||||
def auto_update(raise_error=False):
|
||||
"""
|
||||
一键更新协议:查询版本和用户意见
|
||||
|
||||
Args:
|
||||
raise_error (bool, optional): 是否在出错时抛出错误。默认为 False。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
from toolbox import get_conf
|
||||
@@ -135,22 +192,22 @@ def auto_update(raise_error=False):
|
||||
current_version = f.read()
|
||||
current_version = json.loads(current_version)['version']
|
||||
if (remote_version - current_version) >= 0.01-1e-5:
|
||||
from shared_utils.colorful import print亮黄
|
||||
print亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
|
||||
print('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
||||
from shared_utils.colorful import log亮黄
|
||||
log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}') # ⭐ 在控制台打印新版本信息
|
||||
logger.info('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
||||
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
|
||||
if user_instruction in ['Y', 'y']:
|
||||
path = backup_and_download(current_version, remote_version)
|
||||
path = backup_and_download(current_version, remote_version) # ⭐ 备份并下载文件
|
||||
try:
|
||||
patch_and_restart(path)
|
||||
patch_and_restart(path) # ⭐ 执行覆盖并重启操作
|
||||
except:
|
||||
msg = '更新失败。'
|
||||
if raise_error:
|
||||
from toolbox import trimmed_format_exc
|
||||
msg += trimmed_format_exc()
|
||||
print(msg)
|
||||
logger.warning(msg)
|
||||
else:
|
||||
print('自动更新程序:已禁用')
|
||||
logger.info('自动更新程序:已禁用')
|
||||
return
|
||||
else:
|
||||
return
|
||||
@@ -159,10 +216,13 @@ def auto_update(raise_error=False):
|
||||
if raise_error:
|
||||
from toolbox import trimmed_format_exc
|
||||
msg += trimmed_format_exc()
|
||||
print(msg)
|
||||
logger.info(msg)
|
||||
|
||||
def warm_up_modules():
|
||||
print('正在执行一些模块的预热 ...')
|
||||
"""
|
||||
预热模块,加载特定模块并执行预热操作。
|
||||
"""
|
||||
logger.info('正在执行一些模块的预热 ...')
|
||||
from toolbox import ProxyNetworkActivate
|
||||
from request_llms.bridge_all import model_info
|
||||
with ProxyNetworkActivate("Warmup_Modules"):
|
||||
@@ -172,7 +232,17 @@ def warm_up_modules():
|
||||
enc.encode("模块预热", disallowed_special=())
|
||||
|
||||
def warm_up_vectordb():
|
||||
print('正在执行一些模块的预热 ...')
|
||||
"""
|
||||
执行一些模块的预热操作。
|
||||
|
||||
本函数主要用于执行一些模块的预热操作,确保在后续的流程中能够顺利运行。
|
||||
|
||||
⭐ 关键作用:预热模块
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
logger.info('正在执行一些模块的预热 ...')
|
||||
from toolbox import ProxyNetworkActivate
|
||||
with ProxyNetworkActivate("Warmup_Modules"):
|
||||
import nltk
|
||||
@@ -184,4 +254,4 @@ if __name__ == '__main__':
|
||||
os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
|
||||
from toolbox import get_conf
|
||||
proxies = get_conf('proxies')
|
||||
check_proxy(proxies)
|
||||
check_proxy(proxies)
|
||||
14
config.py
14
config.py
@@ -36,8 +36,11 @@ AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-p
|
||||
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
||||
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
||||
"gemini-pro", "chatglm3"
|
||||
"gemini-1.5-pro", "chatglm3"
|
||||
]
|
||||
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
# --- --- --- ---
|
||||
# P.S. 其他可用的模型还包括
|
||||
# AVAIL_LLM_MODELS = [
|
||||
@@ -50,12 +53,13 @@ AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-p
|
||||
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
|
||||
# "moss", "llama2", "chatglm_onnx", "internlm", "jittorllms_pangualpha", "jittorllms_llama",
|
||||
# "deepseek-chat" ,"deepseek-coder",
|
||||
# "gemini-1.5-flash",
|
||||
# "yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview",
|
||||
# ]
|
||||
# --- --- --- ---
|
||||
# 此外,您还可以在接入one-api/vllm/ollama时,
|
||||
# 使用"one-api-*","vllm-*","ollama-*"前缀直接使用非标准方式接入的模型,例如
|
||||
# AVAIL_LLM_MODELS = ["one-api-claude-3-sonnet-20240229(max_token=100000)", "ollama-phi3(max_token=4096)"]
|
||||
# 此外,您还可以在接入one-api/vllm/ollama/Openroute时,
|
||||
# 使用"one-api-*","vllm-*","ollama-*","openrouter-*"前缀直接使用非标准方式接入的模型,例如
|
||||
# AVAIL_LLM_MODELS = ["one-api-claude-3-sonnet-20240229(max_token=100000)", "ollama-phi3(max_token=4096)","openrouter-openai/gpt-4o-mini","openrouter-openai/chatgpt-4o-latest"]
|
||||
# --- --- --- ---
|
||||
|
||||
|
||||
@@ -295,7 +299,7 @@ ARXIV_CACHE_DIR = "gpt_log/arxiv_cache"
|
||||
|
||||
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
|
||||
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
|
||||
"Warmup_Modules", "Nougat_Download", "AutoGen"]
|
||||
"Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]
|
||||
|
||||
|
||||
# 启用插件热加载
|
||||
|
||||
@@ -17,7 +17,7 @@ def get_core_functions():
|
||||
text_show_english=
|
||||
r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, "
|
||||
r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. "
|
||||
r"Firstly, you should provide the polished paragraph. "
|
||||
r"Firstly, you should provide the polished paragraph (in English). "
|
||||
r"Secondly, you should list all your modification and explain the reasons to do so in markdown table.",
|
||||
text_show_chinese=
|
||||
r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
|
||||
from toolbox import trimmed_format_exc
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_crazy_functions():
|
||||
@@ -14,13 +15,13 @@ def get_crazy_functions():
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
|
||||
from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
|
||||
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
||||
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
||||
from crazy_functions.Latex全文润色 import Latex英文润色
|
||||
from crazy_functions.询问多个大语言模型 import 同时问询
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
||||
from crazy_functions.总结word文档 import 总结word文档
|
||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
||||
from crazy_functions.Conversation_To_File import 对话历史存档
|
||||
@@ -30,6 +31,8 @@ def get_crazy_functions():
|
||||
from crazy_functions.Markdown_Translate import Markdown英译中
|
||||
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
||||
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
||||
from crazy_functions.批量文件询问 import 批量文件询问
|
||||
|
||||
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
|
||||
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
|
||||
from crazy_functions.Latex全文润色 import Latex中文润色
|
||||
@@ -48,6 +51,7 @@ def get_crazy_functions():
|
||||
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
||||
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
|
||||
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||
from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
|
||||
|
||||
function_plugins = {
|
||||
"虚空终端": {
|
||||
@@ -57,33 +61,6 @@ def get_crazy_functions():
|
||||
"Info": "使用自然语言实现您的想法",
|
||||
"Function": HotReload(虚空终端),
|
||||
},
|
||||
"解析整个Python项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Python项目),
|
||||
},
|
||||
"注释Python项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
|
||||
"Function": HotReload(注释Python项目),
|
||||
},
|
||||
"载入对话历史存档(先上传存档或输入路径)": {
|
||||
"Group": "对话",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "载入对话历史存档 | 输入参数为路径",
|
||||
"Function": HotReload(载入对话历史存档),
|
||||
},
|
||||
"删除所有本地对话历史记录(谨慎操作)": {
|
||||
"Group": "对话",
|
||||
"AsButton": False,
|
||||
"Info": "删除所有本地对话历史记录,谨慎操作 | 不需要输入参数",
|
||||
"Function": HotReload(删除所有本地对话历史记录),
|
||||
},
|
||||
"清除所有缓存文件(谨慎操作)": {
|
||||
"Group": "对话",
|
||||
"Color": "stop",
|
||||
@@ -91,14 +68,6 @@ def get_crazy_functions():
|
||||
"Info": "清除所有缓存文件,谨慎操作 | 不需要输入参数",
|
||||
"Function": HotReload(清除缓存),
|
||||
},
|
||||
"生成多种Mermaid图表(从当前对话或路径(.pdf/.md/.docx)中生产图表)": {
|
||||
"Group": "对话",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info" : "基于当前对话或文件生成多种Mermaid图表,图表类型由模型判断",
|
||||
"Function": None,
|
||||
"Class": Mermaid_Gen
|
||||
},
|
||||
"Arxiv论文翻译": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
@@ -107,91 +76,25 @@ def get_crazy_functions():
|
||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
||||
},
|
||||
"批量总结Word文档": {
|
||||
"批量文件询问": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "批量总结word文档 | 输入参数为路径",
|
||||
"Function": HotReload(总结word文档),
|
||||
"AdvancedArgs": True,
|
||||
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
|
||||
"ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
|
||||
r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“请总结下面的内容:”,此时,文档内容将添加在“:”后 ",
|
||||
"Function": HotReload(批量文件询问),
|
||||
},
|
||||
"解析整个Matlab项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "解析一个Matlab项目的所有源文件(.m) | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Matlab项目),
|
||||
},
|
||||
"解析整个C++项目头文件": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个C++项目的所有头文件(.h/.hpp) | 输入参数为路径",
|
||||
"Function": HotReload(解析一个C项目的头文件),
|
||||
},
|
||||
"解析整个C++项目(.cpp/.hpp/.c/.h)": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个C++项目的所有源文件(.cpp/.hpp/.c/.h)| 输入参数为路径",
|
||||
"Function": HotReload(解析一个C项目),
|
||||
},
|
||||
"解析整个Go项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个Go项目的所有源文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Golang项目),
|
||||
},
|
||||
"解析整个Rust项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个Rust项目的所有源文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Rust项目),
|
||||
},
|
||||
"解析整个Java项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个Java项目的所有源文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Java项目),
|
||||
},
|
||||
"解析整个前端项目(js,ts,css等)": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个前端项目的所有源文件(js,ts,css等) | 输入参数为路径",
|
||||
"Function": HotReload(解析一个前端项目),
|
||||
},
|
||||
"解析整个Lua项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个Lua项目的所有源文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析一个Lua项目),
|
||||
},
|
||||
"解析整个CSharp项目": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "解析一个CSharp项目的所有源文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析一个CSharp项目),
|
||||
},
|
||||
"解析Jupyter Notebook文件": {
|
||||
"Group": "编程",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "解析Jupyter Notebook文件 | 输入参数为路径",
|
||||
"Function": HotReload(解析ipynb文件),
|
||||
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
||||
"ArgsReminder": "若输入0,则不解析notebook中的Markdown块", # 高级参数输入区的显示提示
|
||||
},
|
||||
"读Tex论文写摘要": {
|
||||
"Arxiv论文对话": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "读取Tex论文并写摘要 | 输入参数为路径",
|
||||
"Function": HotReload(读文章写摘要),
|
||||
"AdvancedArgs": True,
|
||||
"Info": "在输入区中输入论文ID,在高级参数区中输入问题",
|
||||
"ArgsReminder": r"1、请在输入区中输入arxiv ID。 "
|
||||
r"2、请在下方高级参数区中输入你的问题,示例:“这篇文章的方法是什么,请用中文回答我” ",
|
||||
"Function": HotReload(Arxiv论文对话),
|
||||
},
|
||||
"翻译README或MD": {
|
||||
"Group": "编程",
|
||||
@@ -421,8 +324,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
# try:
|
||||
# from crazy_functions.联网的ChatGPT import 连接网络回答问题
|
||||
@@ -452,8 +355,8 @@ def get_crazy_functions():
|
||||
# }
|
||||
# )
|
||||
# except:
|
||||
# print(trimmed_format_exc())
|
||||
# print("Load function plugin failed")
|
||||
# logger.error(trimmed_format_exc())
|
||||
# logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.SourceCode_Analyse import 解析任意code项目
|
||||
@@ -471,8 +374,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.询问多个大语言模型 import 同时问询_指定模型
|
||||
@@ -490,8 +393,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
|
||||
|
||||
@@ -512,8 +415,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.数学动画生成manim import 动画生成
|
||||
@@ -530,8 +433,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.Markdown_Translate import Markdown翻译指定语言
|
||||
@@ -549,8 +452,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.知识库问答 import 知识库文件注入
|
||||
@@ -568,8 +471,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.知识库问答 import 读取知识库作答
|
||||
@@ -587,8 +490,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.交互功能函数模板 import 交互功能模板函数
|
||||
@@ -604,8 +507,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
|
||||
try:
|
||||
@@ -627,8 +530,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.批量翻译PDF文档_NOUGAT import 批量翻译PDF文档
|
||||
@@ -644,8 +547,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.函数动态生成 import 函数动态生成
|
||||
@@ -661,8 +564,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.多智能体 import 多智能体终端
|
||||
@@ -678,8 +581,8 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.互动小游戏 import 随机小游戏
|
||||
@@ -695,8 +598,33 @@ def get_crazy_functions():
|
||||
}
|
||||
)
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
print("Load function plugin failed")
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.Rag_Interface import Rag问答
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
"Rag智能召回": {
|
||||
"Group": "对话",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "将问答数据记录到向量库中,作为长期参考。",
|
||||
"Function": HotReload(Rag问答),
|
||||
},
|
||||
}
|
||||
)
|
||||
except:
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# try:
|
||||
# from crazy_functions.高级功能函数模板 import 测试图表渲染
|
||||
@@ -709,7 +637,7 @@ def get_crazy_functions():
|
||||
# }
|
||||
# })
|
||||
# except:
|
||||
# print(trimmed_format_exc())
|
||||
# logger.error(trimmed_format_exc())
|
||||
# print('Load function plugin failed')
|
||||
|
||||
# try:
|
||||
|
||||
573
crazy_functions/Arxiv_论文对话.py
普通文件
573
crazy_functions/Arxiv_论文对话.py
普通文件
@@ -0,0 +1,573 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Lock as ThreadLock
|
||||
from typing import Generator
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
|
||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
|
||||
|
||||
# 全局常量配置
|
||||
MAX_HISTORY_ROUND = 5 # 最大历史对话轮数
|
||||
MAX_CONTEXT_TOKEN_LIMIT = 4096 # 上下文最大token数
|
||||
REMEMBER_PREVIEW = 1000 # 记忆预览长度
|
||||
VECTOR_STORE_TYPE = "Simple" # 向量存储类型:Simple或Milvus
|
||||
MAX_CONCURRENT_PAPERS = 20 # 最大并行处理论文数
|
||||
MAX_WORKERS = 3 # 最大工作线程数
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingTask:
|
||||
"""论文处理任务数据类"""
|
||||
arxiv_id: str
|
||||
status: str = "pending" # pending, processing, completed, failed
|
||||
error: Optional[str] = None
|
||||
fragments: List[Fragment] = None
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class ArxivRagWorker:
|
||||
def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
|
||||
"""初始化ArxivRagWorker"""
|
||||
self.user_name = user_name
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
|
||||
self.fragments = None
|
||||
|
||||
|
||||
# 初始化基础目录
|
||||
self.base_dir = Path(get_log_folder( plugin_name='arxiv_rag_cache'))
|
||||
self._setup_directories()
|
||||
|
||||
# 初始化处理状态
|
||||
|
||||
# 线程安全的计数器和集合
|
||||
self._processing_lock = ThreadLock()
|
||||
self._processed_fragments = set()
|
||||
self._processed_count = 0
|
||||
# 优化的线程池配置
|
||||
cpu_count = os.cpu_count() or 1
|
||||
self.thread_pool = ThreadPoolExecutor(
|
||||
max_workers=min(32, cpu_count * 4),
|
||||
thread_name_prefix="arxiv_worker"
|
||||
)
|
||||
|
||||
# 批处理配置
|
||||
self._batch_size = min(20, cpu_count * 2) # 动态设置批大小
|
||||
self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
|
||||
self._semaphore = None
|
||||
self._loop = None
|
||||
|
||||
# 初始化处理队列
|
||||
self.processing_queue = {}
|
||||
|
||||
# 初始化工作组件
|
||||
self._init_workers()
|
||||
|
||||
def _setup_directories(self):
|
||||
"""设置工作目录"""
|
||||
|
||||
if self.arxiv_id:
|
||||
self.checkpoint_dir = self.base_dir / self.arxiv_id
|
||||
self.vector_store_dir = self.checkpoint_dir / "vector_store"
|
||||
self.fragment_store_dir = self.checkpoint_dir / "fragments"
|
||||
else:
|
||||
self.checkpoint_dir = self.base_dir
|
||||
self.vector_store_dir = self.base_dir / "vector_store"
|
||||
self.fragment_store_dir = self.base_dir / "fragments"
|
||||
|
||||
self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
|
||||
self.loading = self.paper_path.exists()
|
||||
# 创建必要的目录
|
||||
for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created directory: {directory}")
|
||||
|
||||
def _init_workers(self):
|
||||
"""初始化工作组件"""
|
||||
try:
|
||||
self.rag_worker = LlamaIndexRagWorker(
|
||||
user_name=self.user_name,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
checkpoint_dir=str(self.vector_store_dir),
|
||||
auto_load_checkpoint=True
|
||||
)
|
||||
|
||||
self.arxiv_splitter = ArxivSplitter(
|
||||
root_dir=str(self.checkpoint_dir / "arxiv_cache")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing workers: {str(e)}")
|
||||
raise
|
||||
|
||||
def _ensure_loop(self):
|
||||
"""确保存在事件循环"""
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
else:
|
||||
try:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
return self._loop
|
||||
|
||||
@property
|
||||
def semaphore(self):
|
||||
"""延迟创建semaphore"""
|
||||
if self._semaphore is None:
|
||||
self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
|
||||
return self._semaphore
|
||||
|
||||
async def _process_fragments(self, fragments: List[Fragment]) -> None:
|
||||
"""优化的并行处理论文片段"""
|
||||
if not fragments:
|
||||
logger.warning("No fragments to process")
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
total_fragments = len(fragments)
|
||||
|
||||
try:
|
||||
# 1. 处理论文概述
|
||||
overview = self._create_overview(fragments[0])
|
||||
overview_success = self._safe_add_to_vector_store_sync(overview['text'])
|
||||
if not overview_success:
|
||||
raise RuntimeError("Failed to add overview to vector store")
|
||||
|
||||
# 2. 并行处理片段
|
||||
successful_fragments = await self._parallel_process_fragments(fragments)
|
||||
|
||||
# 3. 保存处理结果
|
||||
if successful_fragments > 0:
|
||||
await self._save_results(fragments, overview['arxiv_id'], successful_fragments)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fragment processing: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
self._log_processing_stats(start_time, total_fragments)
|
||||
|
||||
def _create_overview(self, first_fragment: Fragment) -> Dict:
|
||||
"""创建论文概述"""
|
||||
return {
|
||||
'arxiv_id': first_fragment.arxiv_id,
|
||||
'text': (
|
||||
f"Paper Title: {first_fragment.title}\n"
|
||||
f"ArXiv ID: {first_fragment.arxiv_id}\n"
|
||||
f"Abstract: {first_fragment.abstract}\n"
|
||||
f"Table of contents:{first_fragment.catalogs}\n"
|
||||
f"Type: OVERVIEW"
|
||||
)
|
||||
}
|
||||
|
||||
async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
|
||||
"""并行处理所有片段"""
|
||||
successful_count = 0
|
||||
loop = self._ensure_loop()
|
||||
|
||||
for i in range(0, len(fragments), self._batch_size):
|
||||
batch = fragments[i:i + self._batch_size]
|
||||
batch_futures = []
|
||||
|
||||
for j, fragment in enumerate(batch):
|
||||
if not self._is_fragment_processed(fragment, i + j):
|
||||
future = loop.run_in_executor(
|
||||
self.thread_pool,
|
||||
self._process_single_fragment_sync,
|
||||
fragment,
|
||||
i + j
|
||||
)
|
||||
batch_futures.append(future)
|
||||
|
||||
if batch_futures:
|
||||
results = await asyncio.gather(*batch_futures, return_exceptions=True)
|
||||
successful_count += sum(1 for r in results if isinstance(r, bool) and r)
|
||||
|
||||
return successful_count
|
||||
|
||||
def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
|
||||
"""检查片段是否已处理"""
|
||||
fragment_id = f"{fragment.arxiv_id}_{index}"
|
||||
with self._processing_lock:
|
||||
return fragment_id in self._processed_fragments
|
||||
|
||||
def _safe_add_to_vector_store_sync(self, text: str) -> bool:
|
||||
"""线程安全的向量存储添加"""
|
||||
with self._processing_lock:
|
||||
try:
|
||||
self.rag_worker.add_text_to_vector_store(text)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding to vector store: {str(e)}")
|
||||
return False
|
||||
|
||||
def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
|
||||
"""处理单个片段"""
|
||||
fragment_id = f"{fragment.arxiv_id}_{index}"
|
||||
try:
|
||||
text = self._build_fragment_text(fragment)
|
||||
if self._safe_add_to_vector_store_sync(text):
|
||||
with self._processing_lock:
|
||||
self._processed_fragments.add(fragment_id)
|
||||
self._processed_count += 1
|
||||
logger.info(f"Successfully processed fragment {index}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing fragment {index}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _build_fragment_text(self, fragment: Fragment) -> str:
|
||||
"""构建片段文本"""
|
||||
return "".join([
|
||||
f"Paper Title: {fragment.title}\n",
|
||||
f"Section: {fragment.current_section}\n",
|
||||
f"Content: {fragment.content}\n",
|
||||
f"Bibliography: {fragment.bibliography}\n",
|
||||
"Type: FRAGMENT"
|
||||
])
|
||||
|
||||
async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
|
||||
"""保存处理结果"""
|
||||
if successful_count > 0:
|
||||
loop = self._ensure_loop()
|
||||
await loop.run_in_executor(
|
||||
self.thread_pool,
|
||||
save_fragments_to_file,
|
||||
fragments,
|
||||
str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
|
||||
)
|
||||
|
||||
def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
|
||||
"""记录处理统计信息"""
|
||||
elapsed_time = time.time() - start_time
|
||||
processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
|
||||
logger.info(
|
||||
f"Processed {self._processed_count}/{total_fragments} fragments "
|
||||
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
|
||||
)
|
||||
|
||||
async def process_paper(self, fragments: List[Fragment]) -> bool:
|
||||
"""处理论文主函数"""
|
||||
try:
|
||||
|
||||
if self.paper_path.exists():
|
||||
logger.info(f"Paper {self.arxiv_id} already processed")
|
||||
return True
|
||||
|
||||
task = self._create_processing_task(self.arxiv_id)
|
||||
try:
|
||||
async with self.semaphore:
|
||||
await self._process_fragments(fragments)
|
||||
self._complete_task(task, fragments, self.paper_path)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._fail_task(task, str(e))
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
|
||||
"""创建处理任务"""
|
||||
task = ProcessingTask(arxiv_id=arxiv_id)
|
||||
with self._processing_lock:
|
||||
self.processing_queue[arxiv_id] = task
|
||||
task.status = "processing"
|
||||
return task
|
||||
|
||||
def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
|
||||
"""完成任务处理"""
|
||||
with self._processing_lock:
|
||||
task.status = "completed"
|
||||
task.fragments = fragments
|
||||
paper_path.touch()
|
||||
logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")
|
||||
|
||||
def _fail_task(self, task: ProcessingTask, error: str) -> None:
|
||||
"""任务失败处理"""
|
||||
with self._processing_lock:
|
||||
task.status = "failed"
|
||||
task.error = error
|
||||
|
||||
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||
"""规范化ArXiv ID"""
|
||||
if not input_str:
|
||||
return ""
|
||||
|
||||
input_str = input_str.strip().lower()
|
||||
if 'arxiv.org/' in input_str:
|
||||
if '/pdf/' in input_str:
|
||||
arxiv_id = input_str.split('/pdf/')[-1]
|
||||
else:
|
||||
arxiv_id = input_str.split('/abs/')[-1]
|
||||
return arxiv_id.split('v')[0].strip()
|
||||
return input_str.split('v')[0].strip()
|
||||
|
||||
async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
|
||||
"""等待论文处理完成"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
while True:
|
||||
with self._processing_lock:
|
||||
task = self.processing_queue.get(arxiv_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
if task.status == "completed":
|
||||
return True
|
||||
if task.status == "failed":
|
||||
return False
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
logger.error(f"Processing paper {arxiv_id} timed out")
|
||||
return False
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def retrieve_and_generate(self, query: str) -> str:
|
||||
"""检索相关内容并生成提示词"""
|
||||
try:
|
||||
nodes = self.rag_worker.retrieve_from_store_with_query(query)
|
||||
return self.rag_worker.build_prompt(query=query, nodes=nodes)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in retrieve and generate: {str(e)}")
|
||||
return ""
|
||||
|
||||
def remember_qa(self, question: str, answer: str) -> None:
|
||||
"""记忆问答对"""
|
||||
try:
|
||||
self.rag_worker.remember_qa(question, answer)
|
||||
except Exception as e:
|
||||
logger.error(f"Error remembering QA: {str(e)}")
|
||||
|
||||
async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None:
|
||||
"""自动分析论文的关键问题"""
|
||||
key_questions = [
|
||||
"What is the main research question or problem addressed in this paper?",
|
||||
"What methods or approaches did the authors use to investigate the problem?",
|
||||
"What are the key findings or results presented in the paper?",
|
||||
"How do the findings of this paper contribute to the broader field or topic of study?",
|
||||
"What are the limitations of this study, and what future research directions do the authors suggest?"
|
||||
]
|
||||
|
||||
results = []
|
||||
for question in key_questions:
|
||||
try:
|
||||
prompt = self.retrieve_and_generate(question)
|
||||
if prompt:
|
||||
response = await request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=question,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
sys_prompt=system_prompt
|
||||
)
|
||||
results.append(f"Q: {question}\nA: {response}\n")
|
||||
self.remember_qa(question, response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto analysis: {str(e)}")
|
||||
|
||||
# 合并所有结果
|
||||
summary = "\n\n".join(results)
|
||||
chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。")
|
||||
|
||||
@CatchException
|
||||
def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, web_port: str) -> Generator:
|
||||
"""
|
||||
Arxiv论文对话主函数
|
||||
Args:
|
||||
txt: arxiv ID/URL
|
||||
llm_kwargs: LLM配置参数
|
||||
plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令
|
||||
chatbot: 对话历史
|
||||
history: 聊天历史
|
||||
system_prompt: 系统提示词
|
||||
web_port: Web端口
|
||||
"""
|
||||
# 初始化时,提示用户需要 arxiv ID/URL
|
||||
from toolbox import promote_file_to_downloadzone
|
||||
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
|
||||
chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
user_name = chatbot.get_user()
|
||||
arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
|
||||
arxiv_id = arxiv_worker.arxiv_id
|
||||
|
||||
# 处理新论文的情况
|
||||
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
|
||||
chatbot.append((txt, "正在处理论文,请稍等..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
|
||||
for file in fragment_output_files:
|
||||
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||
chatbot.append(["论文文字内容已保存至下载区,接下来将进行论文编码,请耐心等待三分钟,论文的文字内容为:", formatted_content])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
try:
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 设置超时时间为5分钟
|
||||
success = loop.run_until_complete(
|
||||
asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
|
||||
)
|
||||
if success:
|
||||
success = loop.run_until_complete(
|
||||
asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
|
||||
)
|
||||
if success:
|
||||
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
|
||||
else:
|
||||
chatbot[-1] = (txt, "论文处理超时,请重试。")
|
||||
else:
|
||||
chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
|
||||
except asyncio.TimeoutError:
|
||||
chatbot[-1] = (txt, "论文处理超时,请重试。")
|
||||
success = False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if not success:
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main process: {str(e)}")
|
||||
chatbot[-1] = (txt, f"处理过程中发生错误: {str(e)}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
# 处理用户询问的情况
|
||||
# 获取用户询问指令
|
||||
user_query = plugin_kwargs.get("advanced_arg",
|
||||
"What is the main research question or problem addressed in this paper?")
|
||||
if len(history)<2:
|
||||
fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
|
||||
for file in fragment_output_files:
|
||||
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||
chatbot.append(["论文文字内容已保存至下载区,论文的文字内容为:", formatted_content])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
if not user_query:
|
||||
user_query = "What is the main research question or problem addressed in this paper?"
|
||||
# chatbot.append((txt, "请提供您的问题。"))
|
||||
# yield from update_ui(chatbot=chatbot, history=history)
|
||||
# return
|
||||
|
||||
# 处理历史对话长度
|
||||
if len(history) > MAX_HISTORY_ROUND * 2:
|
||||
history = history[-(MAX_HISTORY_ROUND * 2):]
|
||||
|
||||
# 处理询问指令
|
||||
query_clip, history, flags = input_clipping(
|
||||
user_query,
|
||||
history,
|
||||
max_token_limit=MAX_CONTEXT_TOKEN_LIMIT,
|
||||
return_clip_flags=True
|
||||
)
|
||||
|
||||
if flags["original_input_len"] != flags["clipped_input_len"]:
|
||||
yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
|
||||
if len(user_query) > REMEMBER_PREVIEW:
|
||||
HALF = REMEMBER_PREVIEW // 2
|
||||
query_to_remember = user_query[
|
||||
:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[
|
||||
-HALF:]
|
||||
else:
|
||||
query_to_remember = query_clip
|
||||
else:
|
||||
query_to_remember = query_clip
|
||||
|
||||
chatbot.append((user_query, "正在思考中..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 生成提示词
|
||||
prompt = arxiv_worker.retrieve_and_generate(query_clip)
|
||||
if not prompt:
|
||||
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 获取回答
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=query_clip,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
sys_prompt=system_prompt
|
||||
)
|
||||
|
||||
# 记忆问答对
|
||||
# worker.remember_qa(query_to_remember, response)
|
||||
history.extend([user_query, response])
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
llm_kwargs = {
|
||||
'api_key': os.getenv("one_api_key"),
|
||||
'client_ip': '127.0.0.1',
|
||||
'embed_model': 'text-embedding-3-small',
|
||||
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
|
||||
'max_length': 4096,
|
||||
'most_recent_uploaded': None,
|
||||
'temperature': 1,
|
||||
'top_p': 1
|
||||
}
|
||||
plugin_kwargs = {}
|
||||
chatbot = []
|
||||
history = []
|
||||
system_prompt = "You are a helpful assistant."
|
||||
web_port = "8080"
|
||||
|
||||
# 测试论文导入
|
||||
arxiv_url = "https://arxiv.org/abs/2312.12345"
|
||||
for response in Arxiv论文对话(
|
||||
arxiv_url, llm_kwargs, plugin_kwargs,
|
||||
chatbot, history, system_prompt, web_port
|
||||
):
|
||||
print(response)
|
||||
|
||||
# 测试问答
|
||||
question = "这篇论文的主要贡献是什么?"
|
||||
for response in Arxiv论文对话(
|
||||
question, llm_kwargs, plugin_kwargs,
|
||||
chatbot, history, system_prompt, web_port
|
||||
):
|
||||
print(response)
|
||||
@@ -152,8 +152,6 @@ class Conversation_To_File_Wrap(GptAcademicPluginTemplate):
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def hide_cwd(str):
|
||||
import os
|
||||
current_path = os.getcwd()
|
||||
@@ -171,7 +169,7 @@ def 载入对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
||||
system_prompt 给gpt的静默提醒
|
||||
user_request 当前用户的请求信息(IP地址等)
|
||||
"""
|
||||
from .crazy_utils import get_files_from_everything
|
||||
from crazy_functions.crazy_utils import get_files_from_everything
|
||||
success, file_manifest, _ = get_files_from_everything(txt, type='.html')
|
||||
|
||||
if not success:
|
||||
|
||||
@@ -30,7 +30,7 @@ def gen_image(llm_kwargs, prompt, resolution="1024x1024", model="dall-e-2", qual
|
||||
if style is not None:
|
||||
data['style'] = style
|
||||
response = requests.post(url, headers=headers, json=data, proxies=proxies)
|
||||
print(response.content)
|
||||
# logger.info(response.content)
|
||||
try:
|
||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||
except:
|
||||
@@ -76,7 +76,7 @@ def edit_image(llm_kwargs, prompt, image_path, resolution="1024x1024", model="da
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, files=files, proxies=proxies)
|
||||
print(response.content)
|
||||
# logger.info(response.content)
|
||||
try:
|
||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||
except:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone, check_repeat_upload, map_file_to_sha256
|
||||
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
|
||||
from functools import partial
|
||||
import glob, os, requests, time, json, tarfile
|
||||
from loguru import logger
|
||||
|
||||
import glob, os, requests, time, json, tarfile, threading
|
||||
|
||||
pj = os.path.join
|
||||
ARXIV_CACHE_DIR = get_conf("ARXIV_CACHE_DIR")
|
||||
@@ -136,25 +138,43 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
|
||||
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
|
||||
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
|
||||
|
||||
url_tar = url_.replace('/abs/', '/e-print/')
|
||||
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
|
||||
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
|
||||
os.makedirs(translation_dir, exist_ok=True)
|
||||
|
||||
# <-------------- download arxiv source file ------------->
|
||||
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
|
||||
dst = pj(translation_dir, arxiv_id + '.tar')
|
||||
if os.path.exists(dst):
|
||||
yield from update_ui_lastest_msg("调用缓存", chatbot=chatbot, history=history) # 刷新界面
|
||||
os.makedirs(translation_dir, exist_ok=True)
|
||||
# <-------------- download arxiv source file ------------->
|
||||
|
||||
def fix_url_and_download():
|
||||
# for url_tar in [url_.replace('/abs/', '/e-print/'), url_.replace('/abs/', '/src/')]:
|
||||
for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
|
||||
proxies = get_conf('proxies')
|
||||
r = requests.get(url_tar, proxies=proxies)
|
||||
if r.status_code == 200:
|
||||
with open(dst, 'wb+') as f:
|
||||
f.write(r.content)
|
||||
return True
|
||||
return False
|
||||
|
||||
if os.path.exists(dst) and allow_cache:
|
||||
yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
success = True
|
||||
else:
|
||||
yield from update_ui_lastest_msg("开始下载", chatbot=chatbot, history=history) # 刷新界面
|
||||
proxies = get_conf('proxies')
|
||||
r = requests.get(url_tar, proxies=proxies)
|
||||
with open(dst, 'wb+') as f:
|
||||
f.write(r.content)
|
||||
yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
success = fix_url_and_download()
|
||||
yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
if not success:
|
||||
yield from update_ui_lastest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
|
||||
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
|
||||
|
||||
# <-------------- extract file ------------->
|
||||
yield from update_ui_lastest_msg("下载完成", chatbot=chatbot, history=history) # 刷新界面
|
||||
from toolbox import extract_archive
|
||||
extract_archive(file_path=dst, dest_dir=extract_dst)
|
||||
try:
|
||||
extract_archive(file_path=dst, dest_dir=extract_dst)
|
||||
except tarfile.ReadError:
|
||||
os.remove(dst)
|
||||
raise tarfile.ReadError(f"论文下载失败")
|
||||
return extract_dst, arxiv_id
|
||||
|
||||
|
||||
@@ -178,7 +198,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
||||
|
||||
if response.ok:
|
||||
pdf_id = response.json()["pdf_id"]
|
||||
print(f"PDF processing initiated. PDF ID: {pdf_id}")
|
||||
logger.info(f"PDF processing initiated. PDF ID: {pdf_id}")
|
||||
|
||||
# Step 2: Check processing status
|
||||
while True:
|
||||
@@ -186,12 +206,12 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
||||
conversion_data = conversion_response.json()
|
||||
|
||||
if conversion_data["status"] == "completed":
|
||||
print("PDF processing completed.")
|
||||
logger.info("PDF processing completed.")
|
||||
break
|
||||
elif conversion_data["status"] == "error":
|
||||
print("Error occurred during processing.")
|
||||
logger.info("Error occurred during processing.")
|
||||
else:
|
||||
print(f"Processing status: {conversion_data['status']}")
|
||||
logger.info(f"Processing status: {conversion_data['status']}")
|
||||
time.sleep(5) # wait for a few seconds before checking again
|
||||
|
||||
# Step 3: Save results to local files
|
||||
@@ -206,7 +226,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
||||
output_path = os.path.join(output_dir, output_name)
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_file.write(response.content)
|
||||
print(f"tex.zip file saved at: {output_path}")
|
||||
logger.info(f"tex.zip file saved at: {output_path}")
|
||||
|
||||
import zipfile
|
||||
unzip_dir = os.path.join(output_dir, file_name_wo_dot)
|
||||
@@ -216,7 +236,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
||||
return unzip_dir
|
||||
|
||||
else:
|
||||
print(f"Error sending PDF for processing. Status code: {response.status_code}")
|
||||
logger.error(f"Error sending PDF for processing. Status code: {response.status_code}")
|
||||
return None
|
||||
else:
|
||||
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
|
||||
@@ -318,11 +338,17 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
||||
# <-------------- more requirements ------------->
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
more_req = plugin_kwargs.get("advanced_arg", "")
|
||||
no_cache = more_req.startswith("--no-cache")
|
||||
if no_cache: more_req.lstrip("--no-cache")
|
||||
|
||||
no_cache = ("--no-cache" in more_req)
|
||||
if no_cache: more_req = more_req.replace("--no-cache", "").strip()
|
||||
|
||||
allow_gptac_cloud_io = ("--allow-cloudio" in more_req) # 从云端下载翻译结果,以及上传翻译结果到云端
|
||||
if allow_gptac_cloud_io: more_req = more_req.replace("--allow-cloudio", "").strip()
|
||||
|
||||
allow_cache = not no_cache
|
||||
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
|
||||
|
||||
|
||||
# <-------------- check deps ------------->
|
||||
try:
|
||||
import glob, os, time, subprocess
|
||||
@@ -349,6 +375,20 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
# #################################################################
|
||||
if allow_gptac_cloud_io and arxiv_id:
|
||||
# 访问 GPTAC学术云,查询云端是否存在该论文的翻译版本
|
||||
from crazy_functions.latex_fns.latex_actions import check_gptac_cloud
|
||||
success, downloaded = check_gptac_cloud(arxiv_id, chatbot)
|
||||
if success:
|
||||
chatbot.append([
|
||||
f"检测到GPTAC云端存在翻译版本, 如果不满意翻译结果, 请禁用云端分享, 然后重新执行。",
|
||||
None
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
#################################################################
|
||||
|
||||
if os.path.exists(txt):
|
||||
project_folder = txt
|
||||
else:
|
||||
@@ -386,14 +426,21 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
||||
# <-------------- zip PDF ------------->
|
||||
zip_res = zip_result(project_folder)
|
||||
if success:
|
||||
if allow_gptac_cloud_io and arxiv_id:
|
||||
# 如果用户允许,我们将翻译好的arxiv论文PDF上传到GPTAC学术云
|
||||
from crazy_functions.latex_fns.latex_actions import upload_to_gptac_cloud_if_user_allow
|
||||
threading.Thread(target=upload_to_gptac_cloud_if_user_allow,
|
||||
args=(chatbot, arxiv_id), daemon=True).start()
|
||||
|
||||
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
|
||||
yield from update_ui(chatbot=chatbot, history=history);
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
time.sleep(1) # 刷新界面
|
||||
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
||||
|
||||
else:
|
||||
chatbot.append((f"失败了",
|
||||
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体(见Github wiki) ...'))
|
||||
yield from update_ui(chatbot=chatbot, history=history);
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
time.sleep(1) # 刷新界面
|
||||
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
|
||||
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
|
||||
"allow_cache":
|
||||
ArgProperty(title="是否允许从缓存中调取结果", options=["允许缓存", "从头执行"], default_value="允许缓存", description="无", type="dropdown").model_dump_json(),
|
||||
"allow_cloudio":
|
||||
ArgProperty(title="是否允许从GPTAC学术云下载(或者上传)翻译结果(仅针对Arxiv论文)", options=["允许", "禁止"], default_value="禁止", description="共享文献,互助互利", type="dropdown").model_dump_json(),
|
||||
}
|
||||
return gui_definition
|
||||
|
||||
@@ -38,9 +40,14 @@ class Arxiv_Localize(GptAcademicPluginTemplate):
|
||||
执行插件
|
||||
"""
|
||||
allow_cache = plugin_kwargs["allow_cache"]
|
||||
allow_cloudio = plugin_kwargs["allow_cloudio"]
|
||||
advanced_arg = plugin_kwargs["advanced_arg"]
|
||||
|
||||
if allow_cache == "从头执行": plugin_kwargs["advanced_arg"] = "--no-cache " + plugin_kwargs["advanced_arg"]
|
||||
|
||||
# 从云端下载翻译结果,以及上传翻译结果到云端;人人为我,我为人人。
|
||||
if allow_cloudio == "允许": plugin_kwargs["advanced_arg"] = "--allow-cloudio " + plugin_kwargs["advanced_arg"]
|
||||
|
||||
yield from Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
|
||||
from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
|
||||
|
||||
from loguru import logger
|
||||
|
||||
class PaperFileGroup():
|
||||
def __init__(self):
|
||||
@@ -33,7 +33,7 @@ class PaperFileGroup():
|
||||
self.sp_file_index.append(index)
|
||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
||||
|
||||
print('Segmentation: done')
|
||||
logger.info('Segmentation: done')
|
||||
def merge_result(self):
|
||||
self.file_result = ["" for _ in range(len(self.file_paths))]
|
||||
for r, k in zip(self.sp_file_result, self.sp_file_index):
|
||||
@@ -56,7 +56,7 @@ class PaperFileGroup():
|
||||
|
||||
def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en', mode='polish'):
|
||||
import time, os, re
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
|
||||
|
||||
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
||||
@@ -122,7 +122,7 @@ def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
pfg.write_result()
|
||||
pfg.zip_result()
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
|
||||
# <-------- 整理结果,退出 ---------->
|
||||
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from toolbox import update_ui, promote_file_to_downloadzone
|
||||
from toolbox import CatchException, report_exception, write_history_to_file
|
||||
fast_debug = False
|
||||
from loguru import logger
|
||||
|
||||
class PaperFileGroup():
|
||||
def __init__(self):
|
||||
@@ -33,11 +33,11 @@ class PaperFileGroup():
|
||||
self.sp_file_index.append(index)
|
||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
||||
|
||||
print('Segmentation: done')
|
||||
logger.info('Segmentation: done')
|
||||
|
||||
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
||||
import time, os, re
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
|
||||
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
||||
pfg = PaperFileGroup()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import glob, shutil, os, re, logging
|
||||
import glob, shutil, os, re
|
||||
from loguru import logger
|
||||
from toolbox import update_ui, trimmed_format_exc, gen_time_str
|
||||
from toolbox import CatchException, report_exception, get_log_folder
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
@@ -34,7 +35,7 @@ class PaperFileGroup():
|
||||
self.sp_file_contents.append(segment)
|
||||
self.sp_file_index.append(index)
|
||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.md")
|
||||
logging.info('Segmentation: done')
|
||||
logger.info('Segmentation: done')
|
||||
|
||||
def merge_result(self):
|
||||
self.file_result = ["" for _ in range(len(self.file_paths))]
|
||||
@@ -51,7 +52,7 @@ class PaperFileGroup():
|
||||
return manifest
|
||||
|
||||
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
|
||||
# <-------- 读取Markdown文件,删除其中的所有注释 ---------->
|
||||
pfg = PaperFileGroup()
|
||||
@@ -64,7 +65,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
pfg.file_contents.append(file_content)
|
||||
|
||||
# <-------- 拆分过长的Markdown文件 ---------->
|
||||
pfg.run_file_split(max_token_limit=2048)
|
||||
pfg.run_file_split(max_token_limit=1024)
|
||||
n_split = len(pfg.sp_file_contents)
|
||||
|
||||
# <-------- 多线程翻译开始 ---------->
|
||||
@@ -106,7 +107,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
expected_f_name = plugin_kwargs['markdown_expected_output_path']
|
||||
shutil.copyfile(output_file, expected_f_name)
|
||||
except:
|
||||
logging.error(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
|
||||
# <-------- 整理结果,退出 ---------->
|
||||
create_report_file_name = gen_time_str() + f"-chatgpt.md"
|
||||
@@ -126,7 +127,7 @@ def get_files_from_everything(txt, preference=''):
|
||||
proxies = get_conf('proxies')
|
||||
# 网络的远程文件
|
||||
if preference == 'Github':
|
||||
logging.info('正在从github下载资源 ...')
|
||||
logger.info('正在从github下载资源 ...')
|
||||
if not txt.endswith('.md'):
|
||||
# Make a request to the GitHub API to retrieve the repository information
|
||||
url = txt.replace("https://github.com/", "https://api.github.com/repos/") + '/readme'
|
||||
|
||||
154
crazy_functions/Rag_Interface.py
普通文件
154
crazy_functions/Rag_Interface.py
普通文件
@@ -0,0 +1,154 @@
|
||||
import os,glob
|
||||
from typing import List
|
||||
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
|
||||
from toolbox import report_exception
|
||||
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
RAG_WORKER_REGISTER = {}
|
||||
MAX_HISTORY_ROUND = 5
|
||||
MAX_CONTEXT_TOKEN_LIMIT = 4096
|
||||
REMEMBER_PREVIEW = 1000
|
||||
|
||||
@CatchException
|
||||
def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker):
|
||||
"""
|
||||
Handles document uploads by extracting text and adding it to the vector store.
|
||||
"""
|
||||
from llama_index.core import Document
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
|
||||
user_name = chatbot.get_user()
|
||||
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
|
||||
|
||||
for file_path in files:
|
||||
try:
|
||||
validate_path_safety(file_path, user_name)
|
||||
text = extract_text(file_path)
|
||||
if text is None:
|
||||
chatbot.append(
|
||||
[f"上传文件: {os.path.basename(file_path)}", f"文件解析失败,无法提取文本内容,请更换文件。失败原因可能为:1.文档格式过于复杂;2. 不支持的文件格式,支持的文件格式后缀有:" + ", ".join(supports_format)])
|
||||
else:
|
||||
chatbot.append(
|
||||
[f"上传文件: {os.path.basename(file_path)}", f"上传文件前50个字符为:{text[:50]}。"])
|
||||
document = Document(text=text, metadata={"source": file_path})
|
||||
rag_worker.add_documents_to_vector_store([document])
|
||||
chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"])
|
||||
except Exception as e:
|
||||
report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e))
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
|
||||
# Main Q&A function with document upload support
|
||||
@CatchException
|
||||
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
|
||||
# import vector store lib
|
||||
VECTOR_STORE_TYPE = "Milvus"
|
||||
if VECTOR_STORE_TYPE == "Milvus":
|
||||
try:
|
||||
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
|
||||
except:
|
||||
VECTOR_STORE_TYPE = "Simple"
|
||||
if VECTOR_STORE_TYPE == "Simple":
|
||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||
|
||||
# 1. we retrieve rag worker from global context
|
||||
user_name = chatbot.get_user()
|
||||
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
|
||||
if user_name in RAG_WORKER_REGISTER:
|
||||
rag_worker = RAG_WORKER_REGISTER[user_name]
|
||||
else:
|
||||
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
|
||||
user_name,
|
||||
llm_kwargs,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
auto_load_checkpoint=True
|
||||
)
|
||||
|
||||
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
|
||||
tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
|
||||
|
||||
# 2. Handle special commands
|
||||
if os.path.exists(txt) and os.path.isdir(txt):
|
||||
project_folder = txt
|
||||
validate_path_safety(project_folder, chatbot.get_user())
|
||||
# Extract file paths from the user input
|
||||
# Assuming the user inputs file paths separated by commas after the command
|
||||
file_paths = [f for f in glob.glob(f'{project_folder}/**/*', recursive=True)]
|
||||
chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker)
|
||||
return
|
||||
|
||||
elif txt == "清空向量数据库":
|
||||
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
rag_worker.purge_vector_store()
|
||||
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
|
||||
return
|
||||
|
||||
else:
|
||||
report_exception(chatbot, history, a=f"上传文件路径错误: {txt}", b="请检查并提供正确路径。")
|
||||
|
||||
# 3. Normal Q&A processing
|
||||
chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 4. Clip history to reduce token consumption
|
||||
txt_origin = txt
|
||||
|
||||
if len(history) > MAX_HISTORY_ROUND * 2:
|
||||
history = history[-(MAX_HISTORY_ROUND * 2):]
|
||||
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
|
||||
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
|
||||
|
||||
# 5. If input is clipped, add input to vector store before retrieve
|
||||
if input_is_clipped_flag:
|
||||
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
|
||||
# Save input to vector store
|
||||
rag_worker.add_text_to_vector_store(txt_origin)
|
||||
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
|
||||
|
||||
if len(txt_origin) > REMEMBER_PREVIEW:
|
||||
HALF = REMEMBER_PREVIEW // 2
|
||||
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
|
||||
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
|
||||
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
|
||||
else:
|
||||
i_say_to_remember = i_say = txt_clip
|
||||
else:
|
||||
i_say_to_remember = i_say = txt_clip
|
||||
|
||||
# 6. Search vector store and build prompts
|
||||
nodes = rag_worker.retrieve_from_store_with_query(i_say)
|
||||
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
|
||||
# 7. Query language model
|
||||
if len(chatbot) != 0:
|
||||
chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
|
||||
|
||||
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=i_say,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
sys_prompt=system_prompt,
|
||||
retry_times_at_unknown_error=0
|
||||
)
|
||||
|
||||
# 8. Remember Q&A
|
||||
yield from update_ui_lastest_msg(
|
||||
model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...',
|
||||
chatbot, history, delay=0.5
|
||||
)
|
||||
rag_worker.remember_qa(i_say_to_remember, model_say)
|
||||
history.extend([i_say, model_say])
|
||||
|
||||
# 9. Final UI Update
|
||||
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip)
|
||||
167
crazy_functions/Social_Helper.py
普通文件
167
crazy_functions/Social_Helper.py
普通文件
@@ -0,0 +1,167 @@
|
||||
import pickle, os, random
|
||||
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||
from crazy_functions.json_fns.select_tool import structure_output, select_tool
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
from typing import List
|
||||
|
||||
|
||||
SOCIAL_NETWOK_WORKER_REGISTER = {}
|
||||
|
||||
class SocialNetwork():
|
||||
def __init__(self):
|
||||
self.people = []
|
||||
|
||||
class SaveAndLoad():
|
||||
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||
self.user_name = user_name
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
if auto_load_checkpoint:
|
||||
self.social_network = self.load_from_checkpoint(checkpoint_dir)
|
||||
else:
|
||||
self.social_network = SocialNetwork()
|
||||
|
||||
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||
import os, glob
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if not os.path.exists(checkpoint_dir): return False
|
||||
if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
|
||||
return True
|
||||
|
||||
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
|
||||
pickle.dump(self.social_network, f)
|
||||
return
|
||||
|
||||
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
|
||||
social_network = pickle.load(f)
|
||||
return social_network
|
||||
else:
|
||||
return SocialNetwork()
|
||||
|
||||
|
||||
class Friend(BaseModel):
|
||||
friend_name: str = Field(description="name of a friend")
|
||||
friend_description: str = Field(description="description of a friend (everything about this friend)")
|
||||
friend_relationship: str = Field(description="The relationship with a friend (e.g. friend, family, colleague)")
|
||||
|
||||
class FriendList(BaseModel):
|
||||
friends_list: List[Friend] = Field(description="The list of friends")
|
||||
|
||||
|
||||
class SocialNetworkWorker(SaveAndLoad):
|
||||
def ai_socail_advice(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||
pass
|
||||
|
||||
def ai_remove_friend(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||
pass
|
||||
|
||||
def ai_list_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||
pass
|
||||
|
||||
def ai_add_multi_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||
friend, err_msg = structure_output(
|
||||
txt=prompt,
|
||||
prompt="根据提示, 解析多个联系人的身份信息\n\n",
|
||||
err_msg=f"不能理解该联系人",
|
||||
run_gpt_fn=run_gpt_fn,
|
||||
pydantic_cls=FriendList
|
||||
)
|
||||
if friend.friends_list:
|
||||
for f in friend.friends_list:
|
||||
self.add_friend(f)
|
||||
msg = f"成功添加{len(friend.friends_list)}个联系人: {str(friend.friends_list)}"
|
||||
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0)
|
||||
|
||||
|
||||
def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
prompt = txt
|
||||
run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
|
||||
self.tools_to_select = {
|
||||
"SocialAdvice":{
|
||||
"explain_to_llm": "如果用户希望获取社交指导,调用SocialAdvice生成一些社交建议",
|
||||
"callback": self.ai_socail_advice,
|
||||
},
|
||||
"AddFriends":{
|
||||
"explain_to_llm": "如果用户给出了联系人,调用AddMultiFriends把联系人添加到数据库",
|
||||
"callback": self.ai_add_multi_friends,
|
||||
},
|
||||
"RemoveFriend":{
|
||||
"explain_to_llm": "如果用户希望移除某个联系人,调用RemoveFriend",
|
||||
"callback": self.ai_remove_friend,
|
||||
},
|
||||
"ListFriends":{
|
||||
"explain_to_llm": "如果用户列举联系人,调用ListFriends",
|
||||
"callback": self.ai_list_friends,
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
Explaination = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()])
|
||||
class UserSociaIntention(BaseModel):
|
||||
intention_type: str = Field(
|
||||
description=
|
||||
f"The type of user intention. You must choose from {self.tools_to_select.keys()}.\n\n"
|
||||
f"Explaination:\n{Explaination}",
|
||||
default="SocialAdvice"
|
||||
)
|
||||
pydantic_cls_instance, err_msg = select_tool(
|
||||
prompt=txt,
|
||||
run_gpt_fn=run_gpt_fn,
|
||||
pydantic_cls=UserSociaIntention
|
||||
)
|
||||
except Exception as e:
|
||||
yield from update_ui_lastest_msg(
|
||||
lastmsg=f"无法理解用户意图 {err_msg}",
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
delay=0
|
||||
)
|
||||
return
|
||||
|
||||
intention_type = pydantic_cls_instance.intention_type
|
||||
intention_callback = self.tools_to_select[pydantic_cls_instance.intention_type]['callback']
|
||||
yield from intention_callback(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type)
|
||||
|
||||
|
||||
def add_friend(self, friend):
|
||||
# check whether the friend is already in the social network
|
||||
for f in self.social_network.people:
|
||||
if f.friend_name == friend.friend_name:
|
||||
f.friend_description = friend.friend_description
|
||||
f.friend_relationship = friend.friend_relationship
|
||||
logger.info(f"Repeated friend, update info: {friend}")
|
||||
return
|
||||
logger.info(f"Add a new friend: {friend}")
|
||||
self.social_network.people.append(friend)
|
||||
return
|
||||
|
||||
|
||||
@CatchException
|
||||
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
|
||||
# 1. we retrieve worker from global context
|
||||
user_name = chatbot.get_user()
|
||||
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag')
|
||||
if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
|
||||
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
|
||||
else:
|
||||
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
|
||||
user_name,
|
||||
llm_kwargs,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
auto_load_checkpoint=True
|
||||
)
|
||||
|
||||
# 2. save
|
||||
yield from social_network_worker.run(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
social_network_worker.save_to_checkpoint(checkpoint_dir)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
@@ -5,8 +5,8 @@ from crazy_functions.crazy_utils import input_clipping
|
||||
|
||||
def 解析源代码新(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import os, copy
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
summary_batch_isolation = True
|
||||
inputs_array = []
|
||||
|
||||
@@ -6,7 +6,10 @@ from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_ver
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
|
||||
from crazy_functions.diagram_fns.file_tree import FileNode
|
||||
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
|
||||
@@ -24,12 +27,13 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
file_tree_struct.add_file(file_path, file_path)
|
||||
|
||||
# <第一步,逐个文件分析,多线程>
|
||||
lang = "" if not plugin_kwargs["use_chinese"] else " (you must use Chinese)"
|
||||
for index, fp in enumerate(file_manifest):
|
||||
# 读取文件
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||
file_content = f.read()
|
||||
prefix = ""
|
||||
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence, the code is:\n```{file_content}```'
|
||||
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence{lang}, the code is:\n```{file_content}```'
|
||||
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
|
||||
# 装载请求内容
|
||||
MAX_TOKEN_SINGLE_FILE = 2560
|
||||
@@ -37,7 +41,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
inputs_array.append(i_say)
|
||||
inputs_show_user_array.append(i_say_show_user)
|
||||
history_array.append([])
|
||||
sys_prompt_array.append("You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear.")
|
||||
sys_prompt_array.append(f"You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear{lang}.")
|
||||
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
|
||||
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array = inputs_array,
|
||||
@@ -50,10 +54,20 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
)
|
||||
|
||||
# <第二步,逐个文件分析,生成带注释文件>
|
||||
tasks = ["" for _ in range(len(file_manifest))]
|
||||
def bark_fn(tasks):
|
||||
for i in range(len(tasks)): tasks[i] = "watchdog is dead"
|
||||
wd = WatchDog(timeout=10, bark_fn=lambda: bark_fn(tasks), interval=3, msg="ThreadWatcher timeout")
|
||||
wd.begin_watch()
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
|
||||
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct):
|
||||
pcc = PythonCodeComment(llm_kwargs, language='English')
|
||||
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct, index):
|
||||
language = 'Chinese' if plugin_kwargs["use_chinese"] else 'English'
|
||||
def observe_window_update(x):
|
||||
if tasks[index] == "watchdog is dead":
|
||||
raise TimeoutError("ThreadWatcher: watchdog is dead")
|
||||
tasks[index] = x
|
||||
pcc = PythonCodeComment(llm_kwargs, plugin_kwargs, language=language, observe_window_update=observe_window_update)
|
||||
pcc.read_file(path=fp, brief=gpt_say)
|
||||
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
|
||||
file_tree_struct.manifest[fp].revised_path = revised_path
|
||||
@@ -65,7 +79,8 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
|
||||
html_template = f.read()
|
||||
warp = lambda x: "```python\n\n" + x + "\n\n```"
|
||||
from themes.theme import advanced_css
|
||||
from themes.theme import load_dynamic_theme
|
||||
_, advanced_css, _, _ = load_dynamic_theme("Default")
|
||||
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
|
||||
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
|
||||
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
|
||||
@@ -73,17 +88,21 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
file_tree_struct.manifest[fp].compare_html = compare_html_path
|
||||
with open(compare_html_path, 'w', encoding='utf-8') as f:
|
||||
f.write(html_template)
|
||||
print('done 1')
|
||||
tasks[index] = ""
|
||||
|
||||
chatbot.append([None, f"正在处理:"])
|
||||
futures = []
|
||||
index = 0
|
||||
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
|
||||
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct)
|
||||
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct, index)
|
||||
index += 1
|
||||
futures.append(future)
|
||||
|
||||
# <第三步,等待任务完成>
|
||||
cnt = 0
|
||||
while True:
|
||||
cnt += 1
|
||||
wd.feed()
|
||||
time.sleep(3)
|
||||
worker_done = [h.done() for h in futures]
|
||||
remain = len(worker_done) - sum(worker_done)
|
||||
@@ -92,14 +111,18 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
preview_html_list = []
|
||||
for done, fp in zip(worker_done, file_manifest):
|
||||
if not done: continue
|
||||
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
|
||||
if hasattr(file_tree_struct.manifest[fp], 'compare_html'):
|
||||
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
|
||||
else:
|
||||
logger.error(f"文件: {fp} 的注释结果未能成功")
|
||||
file_links = generate_file_link(preview_html_list)
|
||||
|
||||
yield from update_ui_lastest_msg(
|
||||
f"剩余源文件数量: {remain}.\n\n" +
|
||||
f"已完成的文件: {sum(worker_done)}.\n\n" +
|
||||
f"当前任务: <br/>{'<br/>'.join(tasks)}.<br/>" +
|
||||
f"剩余源文件数量: {remain}.<br/>" +
|
||||
f"已完成的文件: {sum(worker_done)}.<br/>" +
|
||||
file_links +
|
||||
"\n\n" +
|
||||
"<br/>" +
|
||||
''.join(['.']*(cnt % 10 + 1)
|
||||
), chatbot=chatbot, history=history, delay=0)
|
||||
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
|
||||
@@ -120,6 +143,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
@CatchException
|
||||
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
plugin_kwargs["use_chinese"] = plugin_kwargs.get("use_chinese", False)
|
||||
import glob, os
|
||||
if os.path.exists(txt):
|
||||
project_folder = txt
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
|
||||
from toolbox import get_conf, update_ui
|
||||
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
|
||||
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||
|
||||
class SourceCodeComment_Wrap(GptAcademicPluginTemplate):
|
||||
def __init__(self):
|
||||
"""
|
||||
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
|
||||
"""
|
||||
pass
|
||||
|
||||
def define_arg_selection_menu(self):
|
||||
"""
|
||||
定义插件的二级选项菜单
|
||||
"""
|
||||
gui_definition = {
|
||||
"main_input":
|
||||
ArgProperty(title="路径", description="程序路径(上传文件后自动填写)", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||
"use_chinese":
|
||||
ArgProperty(title="注释语言", options=["英文", "中文"], default_value="英文", description="无", type="dropdown").model_dump_json(),
|
||||
# "use_emoji":
|
||||
# ArgProperty(title="在注释中使用emoji", options=["禁止", "允许"], default_value="禁止", description="无", type="dropdown").model_dump_json(),
|
||||
}
|
||||
return gui_definition
|
||||
|
||||
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
执行插件
|
||||
"""
|
||||
if plugin_kwargs["use_chinese"] == "中文":
|
||||
plugin_kwargs["use_chinese"] = True
|
||||
else:
|
||||
plugin_kwargs["use_chinese"] = False
|
||||
|
||||
yield from 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
@@ -1,4 +1,5 @@
|
||||
from crazy_functions.agent_fns.pipe import PluginMultiprocessManager, PipeCom
|
||||
from loguru import logger
|
||||
|
||||
class EchoDemo(PluginMultiprocessManager):
|
||||
def subprocess_worker(self, child_conn):
|
||||
@@ -16,4 +17,4 @@ class EchoDemo(PluginMultiprocessManager):
|
||||
elif msg.cmd == "terminate":
|
||||
self.child_conn.send(PipeCom("done", ""))
|
||||
break
|
||||
print('[debug] subprocess_worker terminated')
|
||||
logger.info('[debug] subprocess_worker terminated')
|
||||
@@ -1,5 +1,6 @@
|
||||
from toolbox import get_log_folder, update_ui, gen_time_str, get_conf, promote_file_to_downloadzone
|
||||
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||
from loguru import logger
|
||||
import time, os
|
||||
|
||||
class PipeCom:
|
||||
@@ -47,7 +48,7 @@ class PluginMultiprocessManager:
|
||||
def terminate(self):
|
||||
self.p.terminate()
|
||||
self.alive = False
|
||||
print("[debug] instance terminated")
|
||||
logger.info("[debug] instance terminated")
|
||||
|
||||
def subprocess_worker(self, child_conn):
|
||||
# ⭐⭐ run in subprocess
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||
import datetime
|
||||
import re
|
||||
import os
|
||||
from loguru import logger
|
||||
from textwrap import dedent
|
||||
from toolbox import CatchException, update_ui
|
||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
# TODO: 解决缩进问题
|
||||
|
||||
find_function_end_prompt = '''
|
||||
@@ -66,6 +68,7 @@ Be aware:
|
||||
1. You must NOT modify the indent of code.
|
||||
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
|
||||
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
|
||||
4. Besides adding a docstring, use the ⭐ symbol to annotate the most core and important line of code within the function, explaining its role.
|
||||
|
||||
------------------ Example ------------------
|
||||
INPUT:
|
||||
@@ -114,10 +117,66 @@ def zip_result(folder):
|
||||
'''
|
||||
|
||||
|
||||
revise_funtion_prompt_chinese = '''
|
||||
您需要阅读以下代码,并根据以下说明修订源代码({FILE_BASENAME}):
|
||||
1. 如果源代码中包含函数的话, 你应该分析给定函数实现了什么功能
|
||||
2. 如果源代码中包含函数的话, 你需要为函数添加docstring, docstring必须使用中文
|
||||
|
||||
请注意:
|
||||
1. 你不得修改代码的缩进
|
||||
2. 你无权更改或翻译代码中的非注释部分,也不允许添加空行
|
||||
3. 使用 {LANG} 添加注释和文档字符串。不要翻译代码中已有的中文
|
||||
4. 除了添加docstring之外, 使用⭐符号给该函数中最核心、最重要的一行代码添加注释,并说明其作用
|
||||
|
||||
------------------ 示例 ------------------
|
||||
INPUT:
|
||||
```
|
||||
L0000 |
|
||||
L0001 |def zip_result(folder):
|
||||
L0002 | t = gen_time_str()
|
||||
L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
|
||||
L0004 | return os.path.join(get_log_folder(), f"result.zip")
|
||||
L0005 |
|
||||
L0006 |
|
||||
```
|
||||
|
||||
OUTPUT:
|
||||
|
||||
<instruction_1_purpose>
|
||||
该函数用于压缩指定文件夹,并返回生成的`zip`文件的路径。
|
||||
</instruction_1_purpose>
|
||||
<instruction_2_revised_code>
|
||||
```
|
||||
def zip_result(folder):
|
||||
"""
|
||||
该函数将指定的文件夹压缩成ZIP文件, 并将其存储在日志文件夹中。
|
||||
|
||||
输入参数:
|
||||
folder (str): 需要压缩的文件夹的路径。
|
||||
返回值:
|
||||
str: 日志文件夹中创建的ZIP文件的路径。
|
||||
"""
|
||||
t = gen_time_str()
|
||||
zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ 执行文件夹的压缩
|
||||
return os.path.join(get_log_folder(), f"result.zip")
|
||||
```
|
||||
</instruction_2_revised_code>
|
||||
------------------ End of Example ------------------
|
||||
|
||||
|
||||
------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
|
||||
```
|
||||
{THE_CODE}
|
||||
```
|
||||
{INDENT_REMINDER}
|
||||
{BRIEF_REMINDER}
|
||||
{HINT_REMINDER}
|
||||
'''
|
||||
|
||||
|
||||
class PythonCodeComment():
|
||||
|
||||
def __init__(self, llm_kwargs, language) -> None:
|
||||
def __init__(self, llm_kwargs, plugin_kwargs, language, observe_window_update) -> None:
|
||||
self.original_content = ""
|
||||
self.full_context = []
|
||||
self.full_context_with_line_no = []
|
||||
@@ -125,7 +184,13 @@ class PythonCodeComment():
|
||||
self.page_limit = 100 # 100 lines of code each page
|
||||
self.ignore_limit = 20
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.language = language
|
||||
self.observe_window_update = observe_window_update
|
||||
if self.language == "chinese":
|
||||
self.core_prompt = revise_funtion_prompt_chinese
|
||||
else:
|
||||
self.core_prompt = revise_funtion_prompt
|
||||
self.path = None
|
||||
self.file_basename = None
|
||||
self.file_brief = ""
|
||||
@@ -256,7 +321,7 @@ class PythonCodeComment():
|
||||
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
|
||||
self.llm_kwargs['temperature'] = 0
|
||||
result = predict_no_ui_long_connection(
|
||||
inputs=revise_funtion_prompt.format(
|
||||
inputs=self.core_prompt.format(
|
||||
LANG=self.language,
|
||||
FILE_BASENAME=self.file_basename,
|
||||
THE_CODE=code,
|
||||
@@ -346,6 +411,7 @@ class PythonCodeComment():
|
||||
try:
|
||||
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
|
||||
next_batch, line_no_start, line_no_end = self.get_next_batch()
|
||||
self.observe_window_update(f"正在处理{self.file_basename} - {line_no_start}/{len(self.full_context)}\n")
|
||||
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
|
||||
|
||||
hint = None
|
||||
@@ -355,7 +421,7 @@ class PythonCodeComment():
|
||||
try:
|
||||
successful, hint = self.verify_successful(next_batch, result)
|
||||
except Exception as e:
|
||||
print('ignored exception:\n' + str(e))
|
||||
logger.error('ignored exception:\n' + str(e))
|
||||
break
|
||||
if successful:
|
||||
break
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import threading, time
|
||||
from loguru import logger
|
||||
|
||||
class WatchDog():
|
||||
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
|
||||
@@ -13,7 +14,7 @@ class WatchDog():
|
||||
while True:
|
||||
if self.kill_dog: break
|
||||
if time.time() - self.last_feed > self.timeout:
|
||||
if len(self.msg) > 0: print(self.msg)
|
||||
if len(self.msg) > 0: logger.info(self.msg)
|
||||
self.bark_fn()
|
||||
break
|
||||
time.sleep(self.interval)
|
||||
|
||||
@@ -1,39 +1,47 @@
|
||||
import ast
|
||||
import token
|
||||
import tokenize
|
||||
import copy
|
||||
import io
|
||||
|
||||
class CommentRemover(ast.NodeTransformer):
|
||||
def visit_FunctionDef(self, node):
|
||||
# 移除函数的文档字符串
|
||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||
isinstance(node.body[0].value, ast.Str)):
|
||||
node.body = node.body[1:]
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
# 移除类的文档字符串
|
||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||
isinstance(node.body[0].value, ast.Str)):
|
||||
node.body = node.body[1:]
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
def remove_python_comments(input_source: str) -> str:
|
||||
source_flag = copy.copy(input_source)
|
||||
source = io.StringIO(input_source)
|
||||
ls = input_source.split('\n')
|
||||
prev_toktype = token.INDENT
|
||||
readline = source.readline
|
||||
|
||||
def visit_Module(self, node):
|
||||
# 移除模块的文档字符串
|
||||
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||
isinstance(node.body[0].value, ast.Str)):
|
||||
node.body = node.body[1:]
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def get_char_index(lineno, col):
|
||||
# find the index of the char in the source code
|
||||
if lineno == 1:
|
||||
return len('\n'.join(ls[:(lineno-1)])) + col
|
||||
else:
|
||||
return len('\n'.join(ls[:(lineno-1)])) + col + 1
|
||||
|
||||
def replace_char_between(start_lineno, start_col, end_lineno, end_col, source, replace_char, ls):
|
||||
# replace char between start_lineno, start_col and end_lineno, end_col with replace_char, but keep '\n' and ' '
|
||||
b = get_char_index(start_lineno, start_col)
|
||||
e = get_char_index(end_lineno, end_col)
|
||||
for i in range(b, e):
|
||||
if source[i] == '\n':
|
||||
source = source[:i] + '\n' + source[i+1:]
|
||||
elif source[i] == ' ':
|
||||
source = source[:i] + ' ' + source[i+1:]
|
||||
else:
|
||||
source = source[:i] + replace_char + source[i+1:]
|
||||
return source
|
||||
|
||||
tokgen = tokenize.generate_tokens(readline)
|
||||
for toktype, ttext, (slineno, scol), (elineno, ecol), ltext in tokgen:
|
||||
if toktype == token.STRING and (prev_toktype == token.INDENT):
|
||||
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||
elif toktype == token.STRING and (prev_toktype == token.NEWLINE):
|
||||
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||
elif toktype == tokenize.COMMENT:
|
||||
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
|
||||
prev_toktype = toktype
|
||||
return source_flag
|
||||
|
||||
def remove_python_comments(source_code):
|
||||
# 解析源代码为 AST
|
||||
tree = ast.parse(source_code)
|
||||
# 移除注释
|
||||
transformer = CommentRemover()
|
||||
tree = transformer.visit(tree)
|
||||
# 将处理后的 AST 转换回源代码
|
||||
return ast.unparse(tree)
|
||||
|
||||
# 示例使用
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
import datetime, json
|
||||
|
||||
def fetch_items(list_of_items, batch_size):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
|
||||
from shared_utils.char_visual_effect import scolling_visual_effect
|
||||
import threading
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from loguru import logger
|
||||
from shared_utils.char_visual_effect import scolling_visual_effect
|
||||
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
|
||||
|
||||
def input_clipping(inputs, history, max_token_limit):
|
||||
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
|
||||
"""
|
||||
当输入文本 + 历史文本超出最大限制时,采取措施丢弃一部分文本。
|
||||
输入:
|
||||
@@ -20,17 +20,20 @@ def input_clipping(inputs, history, max_token_limit):
|
||||
enc = model_info["gpt-3.5-turbo"]['tokenizer']
|
||||
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||
|
||||
|
||||
mode = 'input-and-history'
|
||||
# 当 输入部分的token占比 小于 全文的一半时,只裁剪历史
|
||||
input_token_num = get_token_num(inputs)
|
||||
original_input_len = len(inputs)
|
||||
if input_token_num < max_token_limit//2:
|
||||
mode = 'only-history'
|
||||
max_token_limit = max_token_limit - input_token_num
|
||||
|
||||
everything = [inputs] if mode == 'input-and-history' else ['']
|
||||
everything.extend(history)
|
||||
n_token = get_token_num('\n'.join(everything))
|
||||
full_token_num = n_token = get_token_num('\n'.join(everything))
|
||||
everything_token = [get_token_num(e) for e in everything]
|
||||
everything_token_num = sum(everything_token)
|
||||
delta = max(everything_token) // 16 # 截断时的颗粒度
|
||||
|
||||
while n_token > max_token_limit:
|
||||
@@ -43,10 +46,24 @@ def input_clipping(inputs, history, max_token_limit):
|
||||
|
||||
if mode == 'input-and-history':
|
||||
inputs = everything[0]
|
||||
full_token_num = everything_token_num
|
||||
else:
|
||||
pass
|
||||
full_token_num = everything_token_num + input_token_num
|
||||
|
||||
history = everything[1:]
|
||||
return inputs, history
|
||||
|
||||
flags = {
|
||||
"mode": mode,
|
||||
"original_input_token_num": input_token_num,
|
||||
"original_full_token_num": full_token_num,
|
||||
"original_input_len": original_input_len,
|
||||
"clipped_input_len": len(inputs),
|
||||
}
|
||||
|
||||
if not return_clip_flags:
|
||||
return inputs, history
|
||||
else:
|
||||
return inputs, history, flags
|
||||
|
||||
def request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs, inputs_show_user, llm_kwargs,
|
||||
@@ -116,7 +133,7 @@ def request_gpt_model_in_new_thread_with_ui_alive(
|
||||
except:
|
||||
# 【第三种情况】:其他错误:重试几次
|
||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||
print(tb_str)
|
||||
logger.error(tb_str)
|
||||
mutable[0] += f"[Local Message] 警告,在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
||||
if retry_op > 0:
|
||||
retry_op -= 1
|
||||
@@ -266,7 +283,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
# 【第三种情况】:其他错误
|
||||
if detect_timeout(): raise RuntimeError("检测到程序终止。")
|
||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||
print(tb_str)
|
||||
logger.error(tb_str)
|
||||
gpt_say += f"[Local Message] 警告,线程{index}在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
||||
if len(mutable[index][0]) > 0: gpt_say += "此线程失败前收到的回答:\n\n" + mutable[index][0]
|
||||
if retry_op > 0:
|
||||
@@ -361,7 +378,7 @@ def read_and_clean_pdf_text(fp):
|
||||
import fitz, copy
|
||||
import re
|
||||
import numpy as np
|
||||
from shared_utils.colorful import print亮黄, print亮绿
|
||||
# from shared_utils.colorful import print亮黄, print亮绿
|
||||
fc = 0 # Index 0 文本
|
||||
fs = 1 # Index 1 字体
|
||||
fb = 2 # Index 2 框框
|
||||
@@ -578,7 +595,7 @@ class nougat_interface():
|
||||
def nougat_with_timeout(self, command, cwd, timeout=3600):
|
||||
import subprocess
|
||||
from toolbox import ProxyNetworkActivate
|
||||
logging.info(f'正在执行命令 {command}')
|
||||
logger.info(f'正在执行命令 {command}')
|
||||
with ProxyNetworkActivate("Nougat_Download"):
|
||||
process = subprocess.Popen(command, shell=False, cwd=cwd, env=os.environ)
|
||||
try:
|
||||
@@ -586,7 +603,7 @@ class nougat_interface():
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
stdout, stderr = process.communicate()
|
||||
print("Process timed out!")
|
||||
logger.error("Process timed out!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from textwrap import indent
|
||||
from loguru import logger
|
||||
|
||||
class FileNode:
|
||||
def __init__(self, name, build_manifest=False):
|
||||
@@ -60,7 +61,7 @@ class FileNode:
|
||||
current_node.children.append(term)
|
||||
|
||||
def print_files_recursively(self, level=0, code="R0"):
|
||||
print(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
|
||||
logger.info(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
|
||||
for j, child in enumerate(self.children):
|
||||
child.print_files_recursively(level=level+1, code=code+str(j))
|
||||
self.parenting_ship.extend(child.parenting_ship)
|
||||
@@ -123,4 +124,4 @@ if __name__ == "__main__":
|
||||
"用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器",
|
||||
"包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类",
|
||||
]
|
||||
print(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))
|
||||
logger.info(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))
|
||||
@@ -0,0 +1,450 @@
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from docx import Document
|
||||
from docx.enum.style import WD_STYLE_TYPE
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||
from docx.oxml.ns import qn
|
||||
from docx.shared import Inches, Cm
|
||||
from docx.shared import Pt, RGBColor, Inches
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class DocumentFormatter(ABC):
|
||||
"""文档格式化基类,定义文档格式化的基本接口"""
|
||||
|
||||
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
|
||||
self.final_summary = final_summary
|
||||
self.file_summaries_map = file_summaries_map
|
||||
self.failed_files = failed_files
|
||||
|
||||
@abstractmethod
|
||||
def format_failed_files(self) -> str:
|
||||
"""格式化失败文件列表"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def format_file_summaries(self) -> str:
|
||||
"""格式化文件总结内容"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_document(self) -> str:
|
||||
"""创建完整文档"""
|
||||
pass
|
||||
|
||||
|
||||
class WordFormatter(DocumentFormatter):
|
||||
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.doc = Document()
|
||||
self._setup_document()
|
||||
self._create_styles()
|
||||
# 初始化三级标题编号系统
|
||||
self.numbers = {
|
||||
1: 0, # 一级标题编号
|
||||
2: 0, # 二级标题编号
|
||||
3: 0 # 三级标题编号
|
||||
}
|
||||
|
||||
def _setup_document(self):
|
||||
"""设置文档基本格式,包括页面设置和页眉"""
|
||||
sections = self.doc.sections
|
||||
for section in sections:
|
||||
# 设置页面大小为A4
|
||||
section.page_width = Cm(21)
|
||||
section.page_height = Cm(29.7)
|
||||
# 设置页边距
|
||||
section.top_margin = Cm(3.7) # 上边距37mm
|
||||
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||
section.left_margin = Cm(2.8) # 左边距28mm
|
||||
section.right_margin = Cm(2.6) # 右边距26mm
|
||||
# 设置页眉页脚距离
|
||||
section.header_distance = Cm(2.0)
|
||||
section.footer_distance = Cm(2.0)
|
||||
|
||||
# 添加页眉
|
||||
header = section.header
|
||||
header_para = header.paragraphs[0]
|
||||
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
|
||||
header_run = header_para.add_run("该文档由GPT-academic生成")
|
||||
header_run.font.name = '仿宋'
|
||||
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
header_run.font.size = Pt(9)
|
||||
|
||||
def _create_styles(self):
|
||||
"""创建文档样式"""
|
||||
# 创建正文样式
|
||||
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
style.font.name = '仿宋'
|
||||
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
style.font.size = Pt(14)
|
||||
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
style.paragraph_format.space_after = Pt(0)
|
||||
style.paragraph_format.first_line_indent = Pt(28)
|
||||
|
||||
# 创建各级标题样式
|
||||
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
|
||||
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||
|
||||
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
|
||||
"""创建标题样式"""
|
||||
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
|
||||
style.font.name = font_name
|
||||
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||
style.font.size = Pt(font_size)
|
||||
style.font.bold = True
|
||||
style.paragraph_format.alignment = alignment
|
||||
style.paragraph_format.space_before = Pt(12)
|
||||
style.paragraph_format.space_after = Pt(12)
|
||||
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
return style
|
||||
|
||||
def _get_heading_number(self, level: int) -> str:
|
||||
"""
|
||||
生成标题编号
|
||||
|
||||
Args:
|
||||
level: 标题级别 (0-3)
|
||||
|
||||
Returns:
|
||||
str: 格式化的标题编号
|
||||
"""
|
||||
if level == 0: # 主标题不需要编号
|
||||
return ""
|
||||
|
||||
self.numbers[level] += 1 # 增加当前级别的编号
|
||||
|
||||
# 重置下级标题编号
|
||||
for i in range(level + 1, 4):
|
||||
self.numbers[i] = 0
|
||||
|
||||
# 根据级别返回不同格式的编号
|
||||
if level == 1:
|
||||
return f"{self.numbers[1]}. "
|
||||
elif level == 2:
|
||||
return f"{self.numbers[1]}.{self.numbers[2]} "
|
||||
elif level == 3:
|
||||
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
|
||||
return ""
|
||||
|
||||
def _add_heading(self, text: str, level: int):
|
||||
"""
|
||||
添加带编号的标题
|
||||
|
||||
Args:
|
||||
text: 标题文本
|
||||
level: 标题级别 (0-3)
|
||||
"""
|
||||
style_map = {
|
||||
0: 'Title_Custom',
|
||||
1: 'Heading1_Custom',
|
||||
2: 'Heading2_Custom',
|
||||
3: 'Heading3_Custom'
|
||||
}
|
||||
|
||||
number = self._get_heading_number(level)
|
||||
paragraph = self.doc.add_paragraph(style=style_map[level])
|
||||
|
||||
if number:
|
||||
number_run = paragraph.add_run(number)
|
||||
font_size = 22 if level == 1 else (18 if level == 2 else 16)
|
||||
self._get_run_style(number_run, '黑体', font_size, True)
|
||||
|
||||
text_run = paragraph.add_run(text)
|
||||
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
|
||||
self._get_run_style(text_run, '黑体', font_size, True)
|
||||
|
||||
# 主标题添加日期
|
||||
if level == 0:
|
||||
date_paragraph = self.doc.add_paragraph()
|
||||
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||
self._get_run_style(date_run, '仿宋', 16, False)
|
||||
|
||||
return paragraph
|
||||
|
||||
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
|
||||
"""设置文本运行对象的样式"""
|
||||
run.font.name = font_name
|
||||
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||
run.font.size = Pt(font_size)
|
||||
run.font.bold = bold
|
||||
|
||||
def format_failed_files(self) -> str:
|
||||
"""格式化失败文件列表"""
|
||||
result = []
|
||||
if not self.failed_files:
|
||||
return "\n".join(result)
|
||||
|
||||
result.append("处理失败文件:")
|
||||
for fp, reason in self.failed_files:
|
||||
result.append(f"• {os.path.basename(fp)}: {reason}")
|
||||
|
||||
self._add_heading("处理失败文件", 1)
|
||||
for fp, reason in self.failed_files:
|
||||
self._add_content(f"• {os.path.basename(fp)}: {reason}", indent=False)
|
||||
self.doc.add_paragraph()
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def _add_content(self, text: str, indent: bool = True):
|
||||
"""添加正文内容"""
|
||||
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
|
||||
if not indent:
|
||||
paragraph.paragraph_format.first_line_indent = Pt(0)
|
||||
return paragraph
|
||||
|
||||
def format_file_summaries(self) -> str:
|
||||
"""
|
||||
格式化文件总结内容,确保正确的标题层级
|
||||
|
||||
返回:
|
||||
str: 格式化后的文件总结字符串
|
||||
|
||||
标题层级规则:
|
||||
1. 一级标题为"各文件详细总结"
|
||||
2. 如果文件有目录路径:
|
||||
- 目录路径作为二级标题 (2.1, 2.2 等)
|
||||
- 该目录下所有文件作为三级标题 (2.1.1, 2.1.2 等)
|
||||
3. 如果文件没有目录路径:
|
||||
- 文件直接作为二级标题 (2.1, 2.2 等)
|
||||
"""
|
||||
result = []
|
||||
# 首先对文件路径进行分组整理
|
||||
file_groups = {}
|
||||
for path in sorted(self.file_summaries_map.keys()):
|
||||
dir_path = os.path.dirname(path)
|
||||
if dir_path not in file_groups:
|
||||
file_groups[dir_path] = []
|
||||
file_groups[dir_path].append(path)
|
||||
|
||||
# 处理没有目录的文件
|
||||
root_files = file_groups.get("", [])
|
||||
if root_files:
|
||||
for path in sorted(root_files):
|
||||
file_name = os.path.basename(path)
|
||||
result.append(f"\n📄 {file_name}")
|
||||
result.append(self.file_summaries_map[path])
|
||||
# 无目录的文件作为二级标题
|
||||
self._add_heading(f"📄 {file_name}", 2)
|
||||
self._add_content(self.file_summaries_map[path])
|
||||
self.doc.add_paragraph()
|
||||
|
||||
# 处理有目录的文件
|
||||
for dir_path in sorted(file_groups.keys()):
|
||||
if dir_path == "": # 跳过已处理的根目录文件
|
||||
continue
|
||||
|
||||
# 添加目录作为二级标题
|
||||
result.append(f"\n📁 {dir_path}")
|
||||
self._add_heading(f"📁 {dir_path}", 2)
|
||||
|
||||
# 该目录下的所有文件作为三级标题
|
||||
for path in sorted(file_groups[dir_path]):
|
||||
file_name = os.path.basename(path)
|
||||
result.append(f"\n📄 {file_name}")
|
||||
result.append(self.file_summaries_map[path])
|
||||
|
||||
# 添加文件名作为三级标题
|
||||
self._add_heading(f"📄 {file_name}", 3)
|
||||
self._add_content(self.file_summaries_map[path])
|
||||
self.doc.add_paragraph()
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def create_document(self):
|
||||
"""创建完整Word文档并返回文档对象"""
|
||||
# 重置所有编号
|
||||
for level in self.numbers:
|
||||
self.numbers[level] = 0
|
||||
|
||||
# 添加主标题
|
||||
self._add_heading("文档总结报告", 0)
|
||||
self.doc.add_paragraph()
|
||||
|
||||
# 添加总体摘要
|
||||
self._add_heading("总体摘要", 1)
|
||||
self._add_content(self.final_summary)
|
||||
self.doc.add_paragraph()
|
||||
|
||||
# 添加失败文件列表(如果有)
|
||||
if self.failed_files:
|
||||
self.format_failed_files()
|
||||
|
||||
# 添加文件详细总结
|
||||
self._add_heading("各文件详细总结", 1)
|
||||
self.format_file_summaries()
|
||||
|
||||
return self.doc
|
||||
|
||||
|
||||
class MarkdownFormatter(DocumentFormatter):
|
||||
"""Markdown格式文档生成器"""
|
||||
|
||||
def format_failed_files(self) -> str:
|
||||
if not self.failed_files:
|
||||
return ""
|
||||
|
||||
formatted_text = ["\n## ⚠️ 处理失败的文件"]
|
||||
for fp, reason in self.failed_files:
|
||||
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
|
||||
formatted_text.append("\n---")
|
||||
return "\n".join(formatted_text)
|
||||
|
||||
def format_file_summaries(self) -> str:
|
||||
formatted_text = []
|
||||
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||
current_dir = ""
|
||||
|
||||
for path in sorted_paths:
|
||||
dir_path = os.path.dirname(path)
|
||||
if dir_path != current_dir:
|
||||
if dir_path:
|
||||
formatted_text.append(f"\n## 📁 {dir_path}")
|
||||
current_dir = dir_path
|
||||
|
||||
file_name = os.path.basename(path)
|
||||
formatted_text.append(f"\n### 📄 {file_name}")
|
||||
formatted_text.append(self.file_summaries_map[path])
|
||||
formatted_text.append("\n---")
|
||||
|
||||
return "\n".join(formatted_text)
|
||||
|
||||
def create_document(self) -> str:
|
||||
document = [
|
||||
"# 📑 文档总结报告",
|
||||
"\n## 总体摘要",
|
||||
self.final_summary
|
||||
]
|
||||
|
||||
if self.failed_files:
|
||||
document.append(self.format_failed_files())
|
||||
|
||||
document.extend([
|
||||
"\n# 📚 各文件详细总结",
|
||||
self.format_file_summaries()
|
||||
])
|
||||
|
||||
return "\n".join(document)
|
||||
|
||||
|
||||
class HtmlFormatter(DocumentFormatter):
|
||||
"""HTML格式文档生成器"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.css_styles = """
|
||||
body {
|
||||
font-family: "Microsoft YaHei", Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
max-width: 1000px;
|
||||
margin: 0 auto;
|
||||
padding: 20px;
|
||||
color: #333;
|
||||
}
|
||||
h1 {
|
||||
color: #2c3e50;
|
||||
border-bottom: 2px solid #eee;
|
||||
padding-bottom: 10px;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
h2 {
|
||||
color: #34495e;
|
||||
margin-top: 30px;
|
||||
font-size: 20px;
|
||||
border-left: 4px solid #3498db;
|
||||
padding-left: 10px;
|
||||
}
|
||||
h3 {
|
||||
color: #2c3e50;
|
||||
font-size: 18px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.summary {
|
||||
background-color: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 5px;
|
||||
margin: 20px 0;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
.details {
|
||||
margin-top: 40px;
|
||||
}
|
||||
.failed-files {
|
||||
background-color: #fff3f3;
|
||||
padding: 15px;
|
||||
border-left: 4px solid #e74c3c;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.file-summary {
|
||||
background-color: #fff;
|
||||
padding: 15px;
|
||||
margin: 15px 0;
|
||||
border-radius: 4px;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||
}
|
||||
"""
|
||||
|
||||
def format_failed_files(self) -> str:
|
||||
if not self.failed_files:
|
||||
return ""
|
||||
|
||||
failed_files_html = ['<div class="failed-files">']
|
||||
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
|
||||
failed_files_html.append("<ul>")
|
||||
for fp, reason in self.failed_files:
|
||||
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
|
||||
failed_files_html.append("</ul></div>")
|
||||
return "\n".join(failed_files_html)
|
||||
|
||||
def format_file_summaries(self) -> str:
|
||||
formatted_html = []
|
||||
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||
current_dir = ""
|
||||
|
||||
for path in sorted_paths:
|
||||
dir_path = os.path.dirname(path)
|
||||
if dir_path != current_dir:
|
||||
if dir_path:
|
||||
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
|
||||
current_dir = dir_path
|
||||
|
||||
file_name = os.path.basename(path)
|
||||
formatted_html.append('<div class="file-summary">')
|
||||
formatted_html.append(f'<h3>📄 {file_name}</h3>')
|
||||
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
|
||||
formatted_html.append('</div>')
|
||||
|
||||
return "\n".join(formatted_html)
|
||||
|
||||
def create_document(self) -> str:
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset='utf-8'>
|
||||
<title>文档总结报告</title>
|
||||
<style>{self.css_styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>📑 文档总结报告</h1>
|
||||
<h2>总体摘要</h2>
|
||||
<div class="summary">{self.final_summary}</div>
|
||||
{self.format_failed_files()}
|
||||
<div class="details">
|
||||
<h2>📚 各文件详细总结</h2>
|
||||
{self.format_file_summaries()}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
|
||||
|
||||
# 设置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 自定义异常类定义
|
||||
class FoldingError(Exception):
|
||||
"""折叠相关的自定义异常基类"""
|
||||
pass
|
||||
|
||||
|
||||
class FormattingError(FoldingError):
|
||||
"""格式化过程中的错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MetadataError(FoldingError):
|
||||
"""元数据相关的错误"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(FoldingError):
|
||||
"""验证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class FoldingStyle(Enum):
|
||||
"""折叠样式枚举"""
|
||||
SIMPLE = auto() # 简单折叠
|
||||
DETAILED = auto() # 详细折叠(带有额外信息)
|
||||
NESTED = auto() # 嵌套折叠
|
||||
|
||||
|
||||
@dataclass
|
||||
class FoldingOptions:
|
||||
"""折叠选项配置"""
|
||||
style: FoldingStyle = FoldingStyle.DETAILED
|
||||
code_language: Optional[str] = None # 代码块的语言
|
||||
show_timestamp: bool = False # 是否显示时间戳
|
||||
indent_level: int = 0 # 缩进级别
|
||||
custom_css: Optional[str] = None # 自定义CSS类
|
||||
|
||||
|
||||
T = TypeVar('T') # 用于泛型类型
|
||||
|
||||
|
||||
class BaseMetadata(ABC):
|
||||
"""元数据基类"""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self) -> bool:
|
||||
"""验证元数据的有效性"""
|
||||
pass
|
||||
|
||||
def _validate_non_empty_str(self, value: Optional[str]) -> bool:
|
||||
"""验证字符串非空"""
|
||||
return bool(value and value.strip())
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata(BaseMetadata):
|
||||
"""文件元数据"""
|
||||
rel_path: str
|
||||
size: float
|
||||
last_modified: Optional[datetime] = None
|
||||
mime_type: Optional[str] = None
|
||||
encoding: str = 'utf-8'
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证文件元数据的有效性"""
|
||||
try:
|
||||
if not self._validate_non_empty_str(self.rel_path):
|
||||
return False
|
||||
if self.size < 0:
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"File metadata validation error: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
class ContentFormatter(ABC, Generic[T]):
|
||||
"""内容格式化抽象基类
|
||||
|
||||
支持泛型类型参数,可以指定具体的元数据类型。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def format(self,
|
||||
content: str,
|
||||
metadata: T,
|
||||
options: Optional[FoldingOptions] = None) -> str:
|
||||
"""格式化内容
|
||||
|
||||
Args:
|
||||
content: 需要格式化的内容
|
||||
metadata: 类型化的元数据
|
||||
options: 折叠选项
|
||||
|
||||
Returns:
|
||||
str: 格式化后的内容
|
||||
|
||||
Raises:
|
||||
FormattingError: 格式化过程中的错误
|
||||
"""
|
||||
pass
|
||||
|
||||
def _create_summary(self, metadata: T) -> str:
|
||||
"""创建折叠摘要,可被子类重写"""
|
||||
return str(metadata)
|
||||
|
||||
def _format_content_block(self,
|
||||
content: str,
|
||||
options: Optional[FoldingOptions]) -> str:
|
||||
"""格式化内容块,处理代码块等特殊格式"""
|
||||
if not options:
|
||||
return content
|
||||
|
||||
if options.code_language:
|
||||
return f"```{options.code_language}\n{content}\n```"
|
||||
return content
|
||||
|
||||
def _add_indent(self, text: str, level: int) -> str:
|
||||
"""添加缩进"""
|
||||
if level <= 0:
|
||||
return text
|
||||
indent = " " * level
|
||||
return "\n".join(indent + line for line in text.splitlines())
|
||||
|
||||
|
||||
class FileContentFormatter(ContentFormatter[FileMetadata]):
|
||||
"""文件内容格式化器"""
|
||||
|
||||
def format(self,
|
||||
content: str,
|
||||
metadata: FileMetadata,
|
||||
options: Optional[FoldingOptions] = None) -> str:
|
||||
"""格式化文件内容"""
|
||||
if not metadata.validate():
|
||||
raise MetadataError("Invalid file metadata")
|
||||
|
||||
try:
|
||||
options = options or FoldingOptions()
|
||||
|
||||
# 构建摘要信息
|
||||
summary_parts = [
|
||||
f"{metadata.rel_path} ({metadata.size:.2f}MB)",
|
||||
f"Type: {metadata.mime_type}" if metadata.mime_type else None,
|
||||
(f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
if metadata.last_modified and options.show_timestamp else None)
|
||||
]
|
||||
summary = " | ".join(filter(None, summary_parts))
|
||||
|
||||
# 构建HTML类
|
||||
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
|
||||
|
||||
# 格式化内容
|
||||
formatted_content = self._format_content_block(content, options)
|
||||
|
||||
# 组装最终结果
|
||||
result = (
|
||||
f'<details{css_class}><summary>{summary}</summary>\n\n'
|
||||
f'{formatted_content}\n\n'
|
||||
f'</details>\n\n'
|
||||
)
|
||||
|
||||
return self._add_indent(result, options.indent_level)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting file content: {str(e)}")
|
||||
raise FormattingError(f"Failed to format file content: {str(e)}")
|
||||
|
||||
|
||||
class ContentFoldingManager:
|
||||
"""内容折叠管理器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化折叠管理器"""
|
||||
self._formatters: Dict[str, ContentFormatter] = {}
|
||||
self._register_default_formatters()
|
||||
|
||||
def _register_default_formatters(self) -> None:
|
||||
"""注册默认的格式化器"""
|
||||
self.register_formatter('file', FileContentFormatter())
|
||||
|
||||
def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
|
||||
"""注册新的格式化器"""
|
||||
if not isinstance(formatter, ContentFormatter):
|
||||
raise TypeError("Formatter must implement ContentFormatter interface")
|
||||
self._formatters[name] = formatter
|
||||
|
||||
def _guess_language(self, extension: str) -> Optional[str]:
|
||||
"""根据文件扩展名猜测编程语言"""
|
||||
extension = extension.lower().lstrip('.')
|
||||
language_map = {
|
||||
'py': 'python',
|
||||
'js': 'javascript',
|
||||
'java': 'java',
|
||||
'cpp': 'cpp',
|
||||
'cs': 'csharp',
|
||||
'html': 'html',
|
||||
'css': 'css',
|
||||
'md': 'markdown',
|
||||
'json': 'json',
|
||||
'xml': 'xml',
|
||||
'sql': 'sql',
|
||||
'sh': 'bash',
|
||||
'yaml': 'yaml',
|
||||
'yml': 'yaml',
|
||||
'txt': None # 纯文本不需要语言标识
|
||||
}
|
||||
return language_map.get(extension)
|
||||
|
||||
def format_content(self,
|
||||
content: str,
|
||||
formatter_type: str,
|
||||
metadata: Union[FileMetadata],
|
||||
options: Optional[FoldingOptions] = None) -> str:
|
||||
"""格式化内容"""
|
||||
formatter = self._formatters.get(formatter_type)
|
||||
if not formatter:
|
||||
raise KeyError(f"No formatter registered for type: {formatter_type}")
|
||||
|
||||
if not isinstance(metadata, FileMetadata):
|
||||
raise TypeError("Invalid metadata type")
|
||||
|
||||
return formatter.format(content, metadata, options)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaperMetadata(BaseMetadata):
|
||||
"""论文元数据"""
|
||||
title: str
|
||||
authors: str
|
||||
abstract: str
|
||||
catalogs: str
|
||||
arxiv_id: str = ""
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证论文元数据的有效性"""
|
||||
try:
|
||||
if not self._validate_non_empty_str(self.title):
|
||||
return False
|
||||
if not self._validate_non_empty_str(self.authors):
|
||||
return False
|
||||
if not self._validate_non_empty_str(self.abstract):
|
||||
return False
|
||||
if not self._validate_non_empty_str(self.catalogs):
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Paper metadata validation error: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class PaperContentFormatter(ContentFormatter[PaperMetadata]):
|
||||
"""论文内容格式化器"""
|
||||
|
||||
def format(self,
|
||||
fragments: list[SectionFragment],
|
||||
metadata: PaperMetadata,
|
||||
options: Optional[FoldingOptions] = None) -> str:
|
||||
"""格式化论文内容
|
||||
|
||||
Args:
|
||||
fragments: 论文片段列表
|
||||
metadata: 论文元数据
|
||||
options: 折叠选项
|
||||
|
||||
Returns:
|
||||
str: 格式化后的论文内容
|
||||
"""
|
||||
if not metadata.validate():
|
||||
raise MetadataError("Invalid paper metadata")
|
||||
|
||||
try:
|
||||
options = options or FoldingOptions()
|
||||
|
||||
# 1. 生成标题部分(不折叠)
|
||||
result = [f"# {metadata.title}\n"]
|
||||
|
||||
# 2. 生成作者信息(折叠)
|
||||
result.append(self._create_folded_section(
|
||||
"Authors",
|
||||
metadata.authors,
|
||||
options
|
||||
))
|
||||
|
||||
# 3. 生成摘要(折叠)
|
||||
result.append(self._create_folded_section(
|
||||
"Abstract",
|
||||
metadata.abstract,
|
||||
options
|
||||
))
|
||||
|
||||
# 4. 生成目录树(折叠)
|
||||
result.append(self._create_folded_section(
|
||||
"Table of Contents",
|
||||
f"```\n{metadata.catalogs}\n```",
|
||||
options
|
||||
))
|
||||
|
||||
# 5. 按章节组织并生成内容
|
||||
sections = self._organize_sections(fragments)
|
||||
for section, section_fragments in sections.items():
|
||||
# 拼接该章节的所有内容
|
||||
section_content = "\n\n".join(
|
||||
fragment.content for fragment in section_fragments
|
||||
)
|
||||
|
||||
result.append(self._create_folded_section(
|
||||
section,
|
||||
section_content,
|
||||
options
|
||||
))
|
||||
|
||||
# 6. 生成参考文献(折叠)
|
||||
# 收集所有非空的参考文献
|
||||
all_refs = "\n".join(filter(None,
|
||||
(fragment.bibliography for fragment in fragments)
|
||||
))
|
||||
if all_refs:
|
||||
result.append(self._create_folded_section(
|
||||
"Bibliography",
|
||||
f"```bibtex\n{all_refs}\n```",
|
||||
options
|
||||
))
|
||||
|
||||
return "\n\n".join(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting paper content: {str(e)}")
|
||||
raise FormattingError(f"Failed to format paper content: {str(e)}")
|
||||
|
||||
def _create_folded_section(self,
|
||||
title: str,
|
||||
content: str,
|
||||
options: FoldingOptions) -> str:
|
||||
"""创建折叠区块
|
||||
|
||||
Args:
|
||||
title: 区块标题
|
||||
content: 区块内容
|
||||
options: 折叠选项
|
||||
|
||||
Returns:
|
||||
str: 格式化后的折叠区块
|
||||
"""
|
||||
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
|
||||
|
||||
result = (
|
||||
f'<details{css_class}><summary>{title}</summary>\n\n'
|
||||
f'{content}\n\n'
|
||||
f'</details>'
|
||||
)
|
||||
|
||||
return self._add_indent(result, options.indent_level)
|
||||
|
||||
def _organize_sections(self,
|
||||
fragments: list[SectionFragment]
|
||||
) -> Dict[str, list[SectionFragment]]:
|
||||
"""将片段按章节分组
|
||||
|
||||
Args:
|
||||
fragments: 论文片段列表
|
||||
|
||||
Returns:
|
||||
Dict[str, list[SectionFragment]]: 按章节分组的片段字典
|
||||
"""
|
||||
sections: Dict[str, list[SectionFragment]] = {}
|
||||
|
||||
for fragment in fragments:
|
||||
section = fragment.current_section or "Uncategorized"
|
||||
if section not in sections:
|
||||
sections[section] = []
|
||||
sections[section].append(fragment)
|
||||
|
||||
return sections
|
||||
@@ -0,0 +1,354 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectionFragment:
|
||||
"""Arxiv论文片段数据类"""
|
||||
title: str
|
||||
authors: str
|
||||
abstract: str
|
||||
catalogs: str
|
||||
arxiv_id: str = ""
|
||||
current_section: str = "Introduction"
|
||||
content: str = ''
|
||||
bibliography: str = ''
|
||||
|
||||
|
||||
class PaperHtmlFormatter:
|
||||
"""HTML格式论文文档生成器"""
|
||||
|
||||
def __init__(self, fragments: List[SectionFragment], output_dir: Path):
|
||||
self.fragments = fragments
|
||||
self.output_dir = output_dir
|
||||
self.css_styles = """
|
||||
:root {
|
||||
--primary-color: #1a73e8;
|
||||
--secondary-color: #34495e;
|
||||
--background-color: #f8f9fa;
|
||||
--text-color: #2c3e50;
|
||||
--border-color: #e0e0e0;
|
||||
--code-bg-color: #f6f8fa;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: "Source Serif Pro", "Times New Roman", serif;
|
||||
line-height: 1.8;
|
||||
max-width: 1000px;
|
||||
margin: 0 auto;
|
||||
padding: 2rem;
|
||||
color: var(--text-color);
|
||||
background-color: var(--background-color);
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.container {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 12px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: var(--primary-color);
|
||||
font-size: 2.2em;
|
||||
text-align: center;
|
||||
margin: 1.5rem 0;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: 3px solid var(--primary-color);
|
||||
}
|
||||
|
||||
h2 {
|
||||
color: var(--secondary-color);
|
||||
font-size: 1.8em;
|
||||
margin-top: 2rem;
|
||||
padding-left: 1rem;
|
||||
border-left: 4px solid var(--primary-color);
|
||||
}
|
||||
|
||||
h3 {
|
||||
color: var(--text-color);
|
||||
font-size: 1.5em;
|
||||
margin-top: 1.5rem;
|
||||
border-bottom: 2px solid var(--border-color);
|
||||
padding-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.authors {
|
||||
text-align: center;
|
||||
color: var(--secondary-color);
|
||||
font-size: 1.1em;
|
||||
margin: 1rem 0 2rem;
|
||||
}
|
||||
|
||||
.abstract-container {
|
||||
background: var(--background-color);
|
||||
padding: 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin: 2rem 0;
|
||||
}
|
||||
|
||||
.abstract-title {
|
||||
font-weight: bold;
|
||||
color: var(--primary-color);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.abstract-content {
|
||||
font-style: italic;
|
||||
line-height: 1.7;
|
||||
}
|
||||
|
||||
.toc {
|
||||
background: white;
|
||||
padding: 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin: 2rem 0;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
.toc-title {
|
||||
color: var(--primary-color);
|
||||
font-size: 1.4em;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.section-content {
|
||||
background: white;
|
||||
padding: 1.5rem;
|
||||
border-radius: 6px;
|
||||
margin: 1.5rem 0;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
.fragment {
|
||||
margin: 2rem 0;
|
||||
padding-left: 1rem;
|
||||
border-left: 3px solid var(--border-color);
|
||||
}
|
||||
|
||||
.fragment:hover {
|
||||
border-left-color: var(--primary-color);
|
||||
}
|
||||
|
||||
.bibliography {
|
||||
background: var(--code-bg-color);
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
font-family: "Source Code Pro", monospace;
|
||||
font-size: 0.9em;
|
||||
white-space: pre-wrap;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
pre {
|
||||
background: var(--code-bg-color);
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
font-family: "Source Code Pro", monospace;
|
||||
}
|
||||
|
||||
.paper-info {
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 8px;
|
||||
margin: 2rem 0;
|
||||
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.arxiv-id {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-size: 0.9em;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
color: var(--secondary-color);
|
||||
}
|
||||
|
||||
.section-icon {
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
@media print {
|
||||
body {
|
||||
background: white;
|
||||
}
|
||||
.container {
|
||||
box-shadow: none;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _sanitize_html(self, text: str) -> str:
|
||||
"""清理HTML特殊字符"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
replacements = {
|
||||
"&": "&",
|
||||
"<": "<",
|
||||
">": ">",
|
||||
'"': """,
|
||||
"'": "'"
|
||||
}
|
||||
|
||||
for old, new in replacements.items():
|
||||
text = text.replace(old, new)
|
||||
return text
|
||||
|
||||
def _create_section_id(self, section: str) -> str:
|
||||
"""创建section的ID"""
|
||||
section = section.strip() or "uncategorized"
|
||||
# 移除特殊字符,转换为小写并用连字符替换空格
|
||||
section_id = re.sub(r'[^\w\s-]', '', section.lower())
|
||||
return section_id.replace(' ', '-')
|
||||
|
||||
def format_paper_info(self) -> str:
|
||||
"""格式化论文基本信息"""
|
||||
if not self.fragments:
|
||||
return ""
|
||||
|
||||
first_fragment = self.fragments[0]
|
||||
paper_info = ['<div class="paper-info">']
|
||||
|
||||
# 添加标题
|
||||
if first_fragment.title:
|
||||
paper_info.append(f'<h1>{self._sanitize_html(first_fragment.title)}</h1>')
|
||||
|
||||
# 添加arXiv ID
|
||||
if first_fragment.arxiv_id:
|
||||
paper_info.append(f'<div class="arxiv-id">arXiv: {self._sanitize_html(first_fragment.arxiv_id)}</div>')
|
||||
|
||||
# 添加作者
|
||||
if first_fragment.authors:
|
||||
paper_info.append(f'<div class="authors">{self._sanitize_html(first_fragment.authors)}</div>')
|
||||
|
||||
# 添加摘要
|
||||
if first_fragment.abstract:
|
||||
paper_info.append('<div class="abstract-container">')
|
||||
paper_info.append('<div class="abstract-title">Abstract</div>')
|
||||
paper_info.append(f'<div class="abstract-content">{self._sanitize_html(first_fragment.abstract)}</div>')
|
||||
paper_info.append('</div>')
|
||||
|
||||
# 添加目录结构
|
||||
if first_fragment.catalogs:
|
||||
paper_info.append('<h2>Document Structure</h2>')
|
||||
paper_info.append('<pre>')
|
||||
paper_info.append(self._sanitize_html(first_fragment.catalogs))
|
||||
paper_info.append('</pre>')
|
||||
|
||||
paper_info.append('</div>')
|
||||
return '\n'.join(paper_info)
|
||||
|
||||
def format_table_of_contents(self, sections: Dict[str, List[SectionFragment]]) -> str:
|
||||
"""生成目录"""
|
||||
toc = ['<div class="toc">']
|
||||
toc.append('<div class="toc-title">Table of Contents</div>')
|
||||
toc.append('<nav>')
|
||||
|
||||
for section in sections:
|
||||
section_id = self._create_section_id(section)
|
||||
clean_section = section.strip() or "Uncategorized"
|
||||
toc.append(f'<div><a href="#{section_id}">{self._sanitize_html(clean_section)} '
|
||||
f'</a></div>')
|
||||
|
||||
toc.append('</nav>')
|
||||
toc.append('</div>')
|
||||
return '\n'.join(toc)
|
||||
|
||||
def format_sections(self) -> str:
|
||||
"""格式化论文各部分内容"""
|
||||
sections = {}
|
||||
for fragment in self.fragments:
|
||||
section = fragment.current_section or "Uncategorized"
|
||||
if section not in sections:
|
||||
sections[section] = []
|
||||
sections[section].append(fragment)
|
||||
|
||||
formatted_html = ['<div class="content">']
|
||||
formatted_html.append(self.format_table_of_contents(sections))
|
||||
|
||||
# 生成各部分内容
|
||||
for section, fragments in sections.items():
|
||||
section_id = self._create_section_id(section)
|
||||
formatted_html.append(f'<h2 id="{section_id}">')
|
||||
formatted_html.append(f'<span class="section-title">')
|
||||
formatted_html.append(f'<span class="section-icon">§</span>')
|
||||
formatted_html.append(f'{self._sanitize_html(section)}')
|
||||
formatted_html.append('</span>')
|
||||
formatted_html.append('</h2>')
|
||||
|
||||
formatted_html.append('<div class="section-content">')
|
||||
|
||||
for i, fragment in enumerate(fragments, 1):
|
||||
formatted_html.append('<div class="fragment">')
|
||||
|
||||
# 添加内容
|
||||
if fragment.content:
|
||||
formatted_html.append(
|
||||
f'<div class="fragment-content">{self._sanitize_html(fragment.content)}</div>'
|
||||
)
|
||||
|
||||
# 添加参考文献
|
||||
if fragment.bibliography:
|
||||
formatted_html.append('<div class="bibliography">')
|
||||
formatted_html.append(f'{self._sanitize_html(fragment.bibliography)}')
|
||||
formatted_html.append('</div>')
|
||||
|
||||
formatted_html.append('</div>')
|
||||
|
||||
formatted_html.append('</div>')
|
||||
|
||||
formatted_html.append('</div>')
|
||||
return '\n'.join(formatted_html)
|
||||
|
||||
def save_html(self) -> Path:
|
||||
"""保存HTML文档"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"paper_content_{timestamp}.html"
|
||||
file_path = self.output_dir / filename
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>{self._sanitize_html(self.fragments[0].title if self.fragments else 'Paper Content')}</title>
|
||||
<style>
|
||||
{self.css_styles}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
{self.format_paper_info()}
|
||||
{self.format_sections()}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(html_content)
|
||||
|
||||
print(f"HTML document saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving HTML document: {str(e)}")
|
||||
raise
|
||||
|
||||
# 使用示例:
|
||||
# formatter = PaperHtmlFormatter(fragments, output_dir)
|
||||
# output_path = formatter.save_html()
|
||||
@@ -24,8 +24,8 @@ class Actor(BaseModel):
|
||||
film_names: List[str] = Field(description="list of names of films they starred in")
|
||||
"""
|
||||
|
||||
import json, re, logging
|
||||
|
||||
import json, re
|
||||
from loguru import logger as logging
|
||||
|
||||
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError
|
||||
|
||||
def structure_output(txt, prompt, err_msg, run_gpt_fn, pydantic_cls):
|
||||
gpt_json_io = GptJsonIO(pydantic_cls)
|
||||
analyze_res = run_gpt_fn(
|
||||
txt,
|
||||
sys_prompt=prompt + gpt_json_io.format_instructions
|
||||
)
|
||||
try:
|
||||
friend = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)
|
||||
except JsonStringError as e:
|
||||
return None, err_msg
|
||||
|
||||
err_msg = ""
|
||||
return friend, err_msg
|
||||
|
||||
|
||||
def select_tool(prompt, run_gpt_fn, pydantic_cls):
|
||||
pydantic_cls_instance, err_msg = structure_output(
|
||||
txt=prompt,
|
||||
prompt="根据提示, 分析应该调用哪个工具函数\n\n",
|
||||
err_msg=f"不能理解该联系人",
|
||||
run_gpt_fn=run_gpt_fn,
|
||||
pydantic_cls=pydantic_cls
|
||||
)
|
||||
return pydantic_cls_instance, err_msg
|
||||
@@ -1,15 +1,17 @@
|
||||
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
|
||||
from toolbox import get_conf, promote_file_to_downloadzone
|
||||
from .latex_toolbox import PRESERVE, TRANSFORM
|
||||
from .latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
||||
from .latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
|
||||
from .latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
|
||||
from .latex_toolbox import find_title_and_abs
|
||||
from .latex_pickle_io import objdump, objload
|
||||
|
||||
import os, shutil
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder, gen_time_str
|
||||
from toolbox import get_conf, promote_file_to_downloadzone
|
||||
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
|
||||
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
||||
from crazy_functions.latex_fns.latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
|
||||
from crazy_functions.latex_fns.latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
|
||||
from crazy_functions.latex_fns.latex_toolbox import find_title_and_abs
|
||||
from crazy_functions.latex_fns.latex_pickle_io import objdump, objload
|
||||
|
||||
|
||||
pj = os.path.join
|
||||
|
||||
@@ -323,7 +325,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
|
||||
buggy_lines = [int(l) for l in buggy_lines]
|
||||
buggy_lines = sorted(buggy_lines)
|
||||
buggy_line = buggy_lines[0]-1
|
||||
print("reversing tex line that has errors", buggy_line)
|
||||
logger.warning("reversing tex line that has errors", buggy_line)
|
||||
|
||||
# 重组,逆转出错的段落
|
||||
if buggy_line not in fixed_line:
|
||||
@@ -337,7 +339,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
|
||||
|
||||
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
|
||||
except:
|
||||
print("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
|
||||
logger.error("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
|
||||
return False, -1, [-1]
|
||||
|
||||
|
||||
@@ -380,7 +382,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
|
||||
if mode!='translate_zh':
|
||||
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
|
||||
print( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
|
||||
logger.info( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
|
||||
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
|
||||
|
||||
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
|
||||
@@ -419,7 +421,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
shutil.copyfile(concat_pdf, pj(work_folder, '..', 'translation', 'comparison.pdf'))
|
||||
promote_file_to_downloadzone(concat_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
pass
|
||||
return True # 成功啦
|
||||
else:
|
||||
@@ -465,4 +467,71 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
|
||||
promote_file_to_downloadzone(file=res, chatbot=chatbot)
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
print('writing html result failed:', trimmed_format_exc())
|
||||
logger.error('writing html result failed:', trimmed_format_exc())
|
||||
|
||||
|
||||
def upload_to_gptac_cloud_if_user_allow(chatbot, arxiv_id):
|
||||
try:
|
||||
# 如果用户允许,我们将arxiv论文PDF上传到GPTAC学术云
|
||||
from toolbox import map_file_to_sha256
|
||||
# 检查是否顺利,如果没有生成预期的文件,则跳过
|
||||
is_result_good = False
|
||||
for file_path in chatbot._cookies.get("files_to_promote", []):
|
||||
if file_path.endswith('translate_zh.pdf'):
|
||||
is_result_good = True
|
||||
if not is_result_good:
|
||||
return
|
||||
# 上传文件
|
||||
for file_path in chatbot._cookies.get("files_to_promote", []):
|
||||
align_name = None
|
||||
# normalized name
|
||||
for name in ['translate_zh.pdf', 'comparison.pdf']:
|
||||
if file_path.endswith(name): align_name = name
|
||||
# if match any align name
|
||||
if align_name:
|
||||
logger.info(f'Uploading to GPTAC cloud as the user has set `allow_cloud_io`: {file_path}')
|
||||
with open(file_path, 'rb') as f:
|
||||
import requests
|
||||
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_upload'
|
||||
files = {'file': (align_name, f, 'application/octet-stream')}
|
||||
data = {
|
||||
'arxiv_id': arxiv_id,
|
||||
'file_hash': map_file_to_sha256(file_path),
|
||||
'language': 'zh',
|
||||
'trans_prompt': 'to_be_implemented',
|
||||
'llm_model': 'to_be_implemented',
|
||||
'llm_model_param': 'to_be_implemented',
|
||||
}
|
||||
resp = requests.post(url=url, files=files, data=data, timeout=30)
|
||||
logger.info(f'Uploading terminate ({resp.status_code})`: {file_path}')
|
||||
except:
|
||||
# 如果上传失败,不会中断程序,因为这是次要功能
|
||||
pass
|
||||
|
||||
def check_gptac_cloud(arxiv_id, chatbot):
|
||||
import requests
|
||||
success = False
|
||||
downloaded = []
|
||||
try:
|
||||
for pdf_target in ['translate_zh.pdf', 'comparison.pdf']:
|
||||
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_exist'
|
||||
data = {
|
||||
'arxiv_id': arxiv_id,
|
||||
'name': pdf_target,
|
||||
}
|
||||
resp = requests.post(url=url, data=data)
|
||||
cache_hit_result = resp.text.strip('"')
|
||||
if cache_hit_result.startswith("http"):
|
||||
url = cache_hit_result
|
||||
logger.info(f'Downloading from GPTAC cloud: {url}')
|
||||
resp = requests.get(url=url, timeout=30)
|
||||
target = os.path.join(get_log_folder(plugin_name='gptac_cloud'), gen_time_str(), pdf_target)
|
||||
os.makedirs(os.path.dirname(target), exist_ok=True)
|
||||
with open(target, 'wb') as f:
|
||||
f.write(resp.content)
|
||||
new_path = promote_file_to_downloadzone(target, chatbot=chatbot)
|
||||
success = True
|
||||
downloaded.append(new_path)
|
||||
except:
|
||||
pass
|
||||
return success, downloaded
|
||||
|
||||
@@ -6,12 +6,16 @@ class SafeUnpickler(pickle.Unpickler):
|
||||
def get_safe_classes(self):
|
||||
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
||||
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
|
||||
from numpy.core.multiarray import scalar
|
||||
from numpy import dtype
|
||||
# 定义允许的安全类
|
||||
safe_classes = {
|
||||
# 在这里添加其他安全的类
|
||||
'LatexPaperFileGroup': LatexPaperFileGroup,
|
||||
'LatexPaperSplit': LatexPaperSplit,
|
||||
'LinkedListNode': LinkedListNode,
|
||||
'scalar': scalar,
|
||||
'dtype': dtype,
|
||||
}
|
||||
return safe_classes
|
||||
|
||||
@@ -22,8 +26,6 @@ class SafeUnpickler(pickle.Unpickler):
|
||||
for class_name in self.safe_classes.keys():
|
||||
if (class_name in f'{module}.{name}'):
|
||||
match_class_name = class_name
|
||||
if module == 'numpy' or module.startswith('numpy.'):
|
||||
return super().find_class(module, name)
|
||||
if match_class_name is not None:
|
||||
return self.safe_classes[match_class_name]
|
||||
# 如果尝试加载未授权的类,则抛出异常
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os, shutil
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
PRESERVE = 0
|
||||
TRANSFORM = 1
|
||||
@@ -55,7 +57,7 @@ def post_process(root):
|
||||
str_stack.append("{")
|
||||
elif c == "}":
|
||||
if len(str_stack) == 1:
|
||||
print("stack fix")
|
||||
logger.warning("fixing brace error")
|
||||
return i
|
||||
str_stack.pop(-1)
|
||||
else:
|
||||
@@ -601,7 +603,7 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
stdout, stderr = process.communicate()
|
||||
print("Process timed out!")
|
||||
logger.error("Process timed out (compile_latex_with_timeout)!")
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -642,6 +644,216 @@ def run_in_subprocess(func):
|
||||
|
||||
|
||||
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
||||
try:
|
||||
logger.info("Merging PDFs using _merge_pdfs_ng")
|
||||
_merge_pdfs_ng(pdf1_path, pdf2_path, output_path)
|
||||
except:
|
||||
logger.info("Merging PDFs using _merge_pdfs_legacy")
|
||||
_merge_pdfs_legacy(pdf1_path, pdf2_path, output_path)
|
||||
|
||||
|
||||
def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
|
||||
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||
from PyPDF2.generic import NameObject, TextStringObject, ArrayObject, FloatObject, NumberObject
|
||||
|
||||
Percent = 1
|
||||
# raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
|
||||
# Open the first PDF file
|
||||
with open(pdf1_path, "rb") as pdf1_file:
|
||||
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
|
||||
# Open the second PDF file
|
||||
with open(pdf2_path, "rb") as pdf2_file:
|
||||
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
|
||||
# Create a new PDF file to store the merged pages
|
||||
output_writer = PyPDF2.PdfFileWriter()
|
||||
# Determine the number of pages in each PDF file
|
||||
num_pages = max(pdf1_reader.numPages, pdf2_reader.numPages)
|
||||
# Merge the pages from the two PDF files
|
||||
for page_num in range(num_pages):
|
||||
# Add the page from the first PDF file
|
||||
if page_num < pdf1_reader.numPages:
|
||||
page1 = pdf1_reader.getPage(page_num)
|
||||
else:
|
||||
page1 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||
# Add the page from the second PDF file
|
||||
if page_num < pdf2_reader.numPages:
|
||||
page2 = pdf2_reader.getPage(page_num)
|
||||
else:
|
||||
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||
# Create a new empty page with double width
|
||||
new_page = PyPDF2.PageObject.createBlankPage(
|
||||
width=int(
|
||||
int(page1.mediaBox.getWidth())
|
||||
+ int(page2.mediaBox.getWidth()) * Percent
|
||||
),
|
||||
height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
|
||||
)
|
||||
new_page.mergeTranslatedPage(page1, 0, 0)
|
||||
new_page.mergeTranslatedPage(
|
||||
page2,
|
||||
int(
|
||||
int(page1.mediaBox.getWidth())
|
||||
- int(page2.mediaBox.getWidth()) * (1 - Percent)
|
||||
),
|
||||
0,
|
||||
)
|
||||
if "/Annots" in new_page:
|
||||
annotations = new_page["/Annots"]
|
||||
for i, annot in enumerate(annotations):
|
||||
annot_obj = annot.get_object()
|
||||
|
||||
# 检查注释类型是否是链接(/Link)
|
||||
if annot_obj.get("/Subtype") == "/Link":
|
||||
# 检查是否为内部链接跳转(/GoTo)或外部URI链接(/URI)
|
||||
action = annot_obj.get("/A")
|
||||
if action:
|
||||
|
||||
if "/S" in action and action["/S"] == "/GoTo":
|
||||
# 内部链接:跳转到文档中的某个页面
|
||||
dest = action.get("/D") # 目标页或目标位置
|
||||
# if dest and annot.idnum in page2_annot_id:
|
||||
# if dest in pdf2_reader.named_destinations:
|
||||
if dest and page2.annotations:
|
||||
if annot in page2.annotations:
|
||||
# 获取原始文件中跳转信息,包括跳转页面
|
||||
destination = pdf2_reader.named_destinations[
|
||||
dest
|
||||
]
|
||||
page_number = (
|
||||
pdf2_reader.get_destination_page_number(
|
||||
destination
|
||||
)
|
||||
)
|
||||
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||
# “/D”:[10,'/XYZ',100,100,0]
|
||||
if destination.dest_array[1] == "/XYZ":
|
||||
annot_obj["/A"].update(
|
||||
{
|
||||
NameObject("/D"): ArrayObject(
|
||||
[
|
||||
NumberObject(page_number),
|
||||
destination.dest_array[1],
|
||||
FloatObject(
|
||||
destination.dest_array[
|
||||
2
|
||||
]
|
||||
+ int(
|
||||
page1.mediaBox.getWidth()
|
||||
)
|
||||
),
|
||||
destination.dest_array[3],
|
||||
destination.dest_array[4],
|
||||
]
|
||||
) # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
else:
|
||||
annot_obj["/A"].update(
|
||||
{
|
||||
NameObject("/D"): ArrayObject(
|
||||
[
|
||||
NumberObject(page_number),
|
||||
destination.dest_array[1],
|
||||
]
|
||||
) # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
|
||||
rect = annot_obj.get("/Rect")
|
||||
# 更新点击坐标
|
||||
rect = ArrayObject(
|
||||
[
|
||||
FloatObject(
|
||||
rect[0]
|
||||
+ int(page1.mediaBox.getWidth())
|
||||
),
|
||||
rect[1],
|
||||
FloatObject(
|
||||
rect[2]
|
||||
+ int(page1.mediaBox.getWidth())
|
||||
),
|
||||
rect[3],
|
||||
]
|
||||
)
|
||||
annot_obj.update(
|
||||
{
|
||||
NameObject(
|
||||
"/Rect"
|
||||
): rect # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
# if dest and annot.idnum in page1_annot_id:
|
||||
# if dest in pdf1_reader.named_destinations:
|
||||
if dest and page1.annotations:
|
||||
if annot in page1.annotations:
|
||||
# 获取原始文件中跳转信息,包括跳转页面
|
||||
destination = pdf1_reader.named_destinations[
|
||||
dest
|
||||
]
|
||||
page_number = (
|
||||
pdf1_reader.get_destination_page_number(
|
||||
destination
|
||||
)
|
||||
)
|
||||
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||
# “/D”:[10,'/XYZ',100,100,0]
|
||||
if destination.dest_array[1] == "/XYZ":
|
||||
annot_obj["/A"].update(
|
||||
{
|
||||
NameObject("/D"): ArrayObject(
|
||||
[
|
||||
NumberObject(page_number),
|
||||
destination.dest_array[1],
|
||||
FloatObject(
|
||||
destination.dest_array[
|
||||
2
|
||||
]
|
||||
),
|
||||
destination.dest_array[3],
|
||||
destination.dest_array[4],
|
||||
]
|
||||
) # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
else:
|
||||
annot_obj["/A"].update(
|
||||
{
|
||||
NameObject("/D"): ArrayObject(
|
||||
[
|
||||
NumberObject(page_number),
|
||||
destination.dest_array[1],
|
||||
]
|
||||
) # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
|
||||
rect = annot_obj.get("/Rect")
|
||||
rect = ArrayObject(
|
||||
[
|
||||
FloatObject(rect[0]),
|
||||
rect[1],
|
||||
FloatObject(rect[2]),
|
||||
rect[3],
|
||||
]
|
||||
)
|
||||
annot_obj.update(
|
||||
{
|
||||
NameObject(
|
||||
"/Rect"
|
||||
): rect # 确保键和值是 PdfObject
|
||||
}
|
||||
)
|
||||
|
||||
elif "/S" in action and action["/S"] == "/URI":
|
||||
# 外部链接:跳转到某个URI
|
||||
uri = action.get("/URI")
|
||||
output_writer.addPage(new_page)
|
||||
# Save the merged PDF file
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_writer.write(output_file)
|
||||
|
||||
|
||||
def _merge_pdfs_legacy(pdf1_path, pdf2_path, output_path):
|
||||
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||
|
||||
Percent = 0.95
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time, logging, json, sys, struct
|
||||
import time, json, sys, struct
|
||||
import numpy as np
|
||||
from loguru import logger as logging
|
||||
from scipy.io.wavfile import WAVE_FORMAT
|
||||
|
||||
def write_numpy_to_wave(filename, rate, data, add_header=False):
|
||||
@@ -106,18 +107,14 @@ def is_speaker_speaking(vad, data, sample_rate):
|
||||
class AliyunASR():
|
||||
|
||||
def test_on_sentence_begin(self, message, *args):
|
||||
# print("test_on_sentence_begin:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_sentence_end(self, message, *args):
|
||||
# print("test_on_sentence_end:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_sentence = message['payload']['result']
|
||||
self.event_on_entence_end.set()
|
||||
# print(self.parsed_sentence)
|
||||
|
||||
def test_on_start(self, message, *args):
|
||||
# print("test_on_start:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_error(self, message, *args):
|
||||
@@ -129,13 +126,11 @@ class AliyunASR():
|
||||
pass
|
||||
|
||||
def test_on_result_chg(self, message, *args):
|
||||
# print("test_on_chg:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_text = message['payload']['result']
|
||||
self.event_on_result_chg.set()
|
||||
|
||||
def test_on_completed(self, message, *args):
|
||||
# print("on_completed:args=>{} message=>{}".format(args, message))
|
||||
pass
|
||||
|
||||
def audio_convertion_thread(self, uuid):
|
||||
@@ -248,14 +243,14 @@ class AliyunASR():
|
||||
|
||||
try:
|
||||
response = client.do_action_with_exception(request)
|
||||
print(response)
|
||||
logging.info(response)
|
||||
jss = json.loads(response)
|
||||
if 'Token' in jss and 'Id' in jss['Token']:
|
||||
token = jss['Token']['Id']
|
||||
expireTime = jss['Token']['ExpireTime']
|
||||
print("token = " + token)
|
||||
print("expireTime = " + str(expireTime))
|
||||
logging.info("token = " + token)
|
||||
logging.info("expireTime = " + str(expireTime))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logging.error(e)
|
||||
|
||||
return token
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
|
||||
from loguru import logger
|
||||
|
||||
def force_breakdown(txt, limit, get_token_fn):
|
||||
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
|
||||
@@ -76,7 +77,7 @@ def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=F
|
||||
remain_txt_to_cut = post
|
||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||
process = fin_len/total_len
|
||||
print(f'正在文本切分 {int(process*100)}%')
|
||||
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||
if len(remain_txt_to_cut.strip()) == 0:
|
||||
break
|
||||
return res
|
||||
@@ -119,7 +120,7 @@ if __name__ == '__main__':
|
||||
for i in range(5):
|
||||
file_content += file_content
|
||||
|
||||
print(len(file_content))
|
||||
logger.info(len(file_content))
|
||||
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
||||
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||
from shared_utils.colorful import *
|
||||
from loguru import logger
|
||||
import os
|
||||
|
||||
def 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
@@ -93,7 +94,7 @@ def 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwa
|
||||
generated_html_files.append(ch.save_file(create_report_file_name))
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
print('writing html result failed:', trimmed_format_exc())
|
||||
logger.error('writing html result failed:', trimmed_format_exc())
|
||||
|
||||
# 准备文件的下载
|
||||
for pdf_path in generated_conclusion_files:
|
||||
|
||||
@@ -4,7 +4,9 @@ from toolbox import promote_file_to_downloadzone, extract_archive
|
||||
from toolbox import generate_file_link, zip_folder
|
||||
from crazy_functions.crazy_utils import get_files_from_everything
|
||||
from shared_utils.colorful import *
|
||||
from loguru import logger
|
||||
import os
|
||||
import time
|
||||
|
||||
def refresh_key(doc2x_api_key):
|
||||
import requests, json
|
||||
@@ -22,105 +24,140 @@ def refresh_key(doc2x_api_key):
|
||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||
return doc2x_api_key
|
||||
|
||||
|
||||
|
||||
def 解析PDF_DOC2X_转Latex(pdf_file_path):
|
||||
zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format='tex')
|
||||
return unzipped_folder
|
||||
|
||||
|
||||
def 解析PDF_DOC2X(pdf_file_path, format='tex'):
|
||||
"""
|
||||
format: 'tex', 'md', 'docx'
|
||||
"""
|
||||
import requests, json, os
|
||||
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
|
||||
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
|
||||
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
|
||||
doc2x_api_key = DOC2X_API_KEY
|
||||
if doc2x_api_key.startswith('sk-'):
|
||||
url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
|
||||
else:
|
||||
doc2x_api_key = refresh_key(doc2x_api_key)
|
||||
url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
|
||||
|
||||
|
||||
# < ------ 第1步:上传 ------ >
|
||||
logger.info("Doc2x 第1步:上传")
|
||||
with open(pdf_file_path, 'rb') as file:
|
||||
res = requests.post(
|
||||
"https://v2.doc2x.noedgeai.com/api/v2/parse/pdf",
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||
data=file
|
||||
)
|
||||
# res_json = []
|
||||
if res.status_code == 200:
|
||||
res_json = res.json()
|
||||
else:
|
||||
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||
uuid = res_json['data']['uid']
|
||||
|
||||
# < ------ 第2步:轮询等待 ------ >
|
||||
logger.info("Doc2x 第2步:轮询等待")
|
||||
params = {'uid': uuid}
|
||||
while True:
|
||||
res = requests.get(
|
||||
'https://v2.doc2x.noedgeai.com/api/v2/parse/status',
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||
params=params
|
||||
)
|
||||
res_json = res.json()
|
||||
if res_json['data']['status'] == "success":
|
||||
break
|
||||
elif res_json['data']['status'] == "processing":
|
||||
time.sleep(3)
|
||||
logger.info(f"Doc2x is processing at {res_json['data']['progress']}%")
|
||||
elif res_json['data']['status'] == "failed":
|
||||
raise RuntimeError(f"Doc2x return an error: {res_json}")
|
||||
|
||||
|
||||
# < ------ 第3步:提交转化 ------ >
|
||||
logger.info("Doc2x 第3步:提交转化")
|
||||
data = {
|
||||
"uid": uuid,
|
||||
"to": format,
|
||||
"formula_mode": "dollar",
|
||||
"filename": "output"
|
||||
}
|
||||
res = requests.post(
|
||||
url,
|
||||
files={"file": open(pdf_file_path, "rb")},
|
||||
data={"ocr": "1"},
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key}
|
||||
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse',
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||
json=data
|
||||
)
|
||||
res_json = []
|
||||
if res.status_code == 200:
|
||||
decoded = res.content.decode("utf-8")
|
||||
for z_decoded in decoded.split('\n'):
|
||||
if len(z_decoded) == 0: continue
|
||||
assert z_decoded.startswith("data: ")
|
||||
z_decoded = z_decoded[len("data: "):]
|
||||
decoded_json = json.loads(z_decoded)
|
||||
res_json.append(decoded_json)
|
||||
res_json = res.json()
|
||||
else:
|
||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||
|
||||
uuid = res_json[0]['uuid']
|
||||
to = "latex" # latex, md, docx
|
||||
url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
|
||||
|
||||
res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
|
||||
latex_zip_path = os.path.join(latex_dir, gen_time_str() + '.zip')
|
||||
latex_unzip_path = os.path.join(latex_dir, gen_time_str())
|
||||
if res.status_code == 200:
|
||||
with open(latex_zip_path, "wb") as f: f.write(res.content)
|
||||
else:
|
||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||
# < ------ 第4步:等待结果 ------ >
|
||||
logger.info("Doc2x 第4步:等待结果")
|
||||
params = {'uid': uuid}
|
||||
while True:
|
||||
res = requests.get(
|
||||
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result',
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key},
|
||||
params=params
|
||||
)
|
||||
res_json = res.json()
|
||||
if res_json['data']['status'] == "success":
|
||||
break
|
||||
elif res_json['data']['status'] == "processing":
|
||||
time.sleep(3)
|
||||
logger.info(f"Doc2x still processing")
|
||||
elif res_json['data']['status'] == "failed":
|
||||
raise RuntimeError(f"Doc2x return an error: {res_json}")
|
||||
|
||||
|
||||
# < ------ 第5步:最后的处理 ------ >
|
||||
logger.info("Doc2x 第5步:最后的处理")
|
||||
|
||||
if format=='tex':
|
||||
target_path = latex_dir
|
||||
if format=='md':
|
||||
target_path = markdown_dir
|
||||
os.makedirs(target_path, exist_ok=True)
|
||||
|
||||
max_attempt = 3
|
||||
# < ------ 下载 ------ >
|
||||
for attempt in range(max_attempt):
|
||||
try:
|
||||
result_url = res_json['data']['url']
|
||||
res = requests.get(result_url)
|
||||
zip_path = os.path.join(target_path, gen_time_str() + '.zip')
|
||||
unzip_path = os.path.join(target_path, gen_time_str())
|
||||
if res.status_code == 200:
|
||||
with open(zip_path, "wb") as f: f.write(res.content)
|
||||
else:
|
||||
raise RuntimeError(f"Doc2x return an error: {res.json()}")
|
||||
except Exception as e:
|
||||
if attempt < max_attempt - 1:
|
||||
logger.error(f"Failed to download latex file, retrying... {e}")
|
||||
time.sleep(3)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
# < ------ 解压 ------ >
|
||||
import zipfile
|
||||
with zipfile.ZipFile(latex_zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(latex_unzip_path)
|
||||
|
||||
|
||||
return latex_unzip_path
|
||||
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(unzip_path)
|
||||
return zip_path, unzip_path
|
||||
|
||||
|
||||
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
|
||||
|
||||
|
||||
def pdf2markdown(filepath):
|
||||
import requests, json, os
|
||||
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
|
||||
doc2x_api_key = DOC2X_API_KEY
|
||||
if doc2x_api_key.startswith('sk-'):
|
||||
url = "https://api.doc2x.noedgeai.com/api/v1/pdf"
|
||||
else:
|
||||
doc2x_api_key = refresh_key(doc2x_api_key)
|
||||
url = "https://api.doc2x.noedgeai.com/api/platform/pdf"
|
||||
|
||||
chatbot.append((None, "加载PDF文件,发送至DOC2X解析..."))
|
||||
chatbot.append((None, f"Doc2x 解析中"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
res = requests.post(
|
||||
url,
|
||||
files={"file": open(filepath, "rb")},
|
||||
data={"ocr": "1"},
|
||||
headers={"Authorization": "Bearer " + doc2x_api_key}
|
||||
)
|
||||
res_json = []
|
||||
if res.status_code == 200:
|
||||
decoded = res.content.decode("utf-8")
|
||||
for z_decoded in decoded.split('\n'):
|
||||
if len(z_decoded) == 0: continue
|
||||
assert z_decoded.startswith("data: ")
|
||||
z_decoded = z_decoded[len("data: "):]
|
||||
decoded_json = json.loads(z_decoded)
|
||||
res_json.append(decoded_json)
|
||||
if 'limit exceeded' in decoded_json.get('status', ''):
|
||||
raise RuntimeError("Doc2x API 页数受限,请联系 Doc2x 方面,并更换新的 API 秘钥。")
|
||||
else:
|
||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||
uuid = res_json[0]['uuid']
|
||||
to = "md" # latex, md, docx
|
||||
url = "https://api.doc2x.noedgeai.com/api/export"+"?request_id="+uuid+"&to="+to
|
||||
md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format='md')
|
||||
|
||||
chatbot.append((None, f"读取解析: {url} ..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
res = requests.get(url, headers={"Authorization": "Bearer " + doc2x_api_key})
|
||||
md_zip_path = os.path.join(markdown_dir, gen_time_str() + '.zip')
|
||||
if res.status_code == 200:
|
||||
with open(md_zip_path, "wb") as f: f.write(res.content)
|
||||
else:
|
||||
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
|
||||
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
|
||||
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class ArxivDownloader:
|
||||
"""用于下载arXiv论文源码的下载器"""
|
||||
|
||||
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
初始化下载器
|
||||
|
||||
Args:
|
||||
root_dir: 保存下载文件的根目录
|
||||
proxies: 代理服务器设置,例如 {"http": "http://proxy:port", "https": "https://proxy:port"}
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(exist_ok=True)
|
||||
self.proxies = proxies
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def _download_and_extract(self, arxiv_id: str) -> str:
|
||||
"""
|
||||
下载并解压arxiv论文源码
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv论文ID,例如"2103.00020"
|
||||
|
||||
Returns:
|
||||
str: 解压后的文件目录路径
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当下载失败时抛出
|
||||
"""
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 检查缓存
|
||||
if paper_dir.exists() and any(paper_dir.iterdir()):
|
||||
logging.info(f"Using cached version for {arxiv_id}")
|
||||
return str(paper_dir)
|
||||
|
||||
paper_dir.mkdir(exist_ok=True)
|
||||
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
logging.info(f"Downloading from {url}")
|
||||
response = requests.get(url, proxies=self.proxies)
|
||||
if response.status_code == 200:
|
||||
tar_path.write_bytes(response.content)
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||
tar.extractall(path=paper_dir)
|
||||
return str(paper_dir)
|
||||
except Exception as e:
|
||||
logging.warning(f"Download failed for {url}: {e}")
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||
|
||||
def download_paper(self, arxiv_id: str) -> str:
|
||||
"""
|
||||
下载指定的arXiv论文
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv论文ID
|
||||
|
||||
Returns:
|
||||
str: 论文文件所在的目录路径
|
||||
"""
|
||||
return self._download_and_extract(arxiv_id)
|
||||
|
||||
|
||||
def main():
|
||||
"""测试下载功能"""
|
||||
# 配置代理(如果需要)
|
||||
proxies = {
|
||||
"http": "http://your-proxy:port",
|
||||
"https": "https://your-proxy:port"
|
||||
}
|
||||
|
||||
# 创建下载器实例(如果不需要代理,可以不传入proxies参数)
|
||||
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
|
||||
|
||||
# 测试下载一篇论文(这里使用一个示例ID)
|
||||
try:
|
||||
paper_id = "2103.00020" # 这是一个示例ID
|
||||
paper_dir = downloader.download_paper(paper_id)
|
||||
print(f"Successfully downloaded paper to: {paper_dir}")
|
||||
|
||||
# 检查下载的文件
|
||||
paper_path = Path(paper_dir)
|
||||
if paper_path.exists():
|
||||
print("Downloaded files:")
|
||||
for file in paper_path.rglob("*"):
|
||||
if file.is_file():
|
||||
print(f"- {file.relative_to(paper_path)}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading paper: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,836 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import tarfile
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Set
|
||||
|
||||
import aiohttp
|
||||
|
||||
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
|
||||
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
|
||||
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
|
||||
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
|
||||
|
||||
|
||||
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path ) -> Path:
|
||||
"""
|
||||
Save all fragments to a single structured markdown file.
|
||||
|
||||
Args:
|
||||
fragments: List of SectionFragment objects
|
||||
output_dir: Output directory path
|
||||
|
||||
Returns:
|
||||
Path: Path to the generated markdown file
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Create output directory
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
filename = f"paper_latex_content_{timestamp}.md"
|
||||
file_path = output_path/ filename
|
||||
|
||||
# Group fragments by section
|
||||
sections = {}
|
||||
for fragment in fragments:
|
||||
section = fragment.current_section or "Uncategorized"
|
||||
if section not in sections:
|
||||
sections[section] = []
|
||||
sections[section].append(fragment)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
# Write document header
|
||||
f.write("# Document Fragments Analysis\n\n")
|
||||
f.write("## Overview\n")
|
||||
f.write(f"- Total Fragments: {len(fragments)}\n")
|
||||
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
|
||||
# Add paper information if available
|
||||
if fragments and (fragments[0].title or fragments[0].abstract):
|
||||
f.write("\n## Paper Information\n")
|
||||
if fragments[0].title:
|
||||
f.write(f"### Title\n{fragments[0].title}\n")
|
||||
if fragments[0].authors:
|
||||
f.write(f"\n### Authors\n{fragments[0].authors}\n")
|
||||
if fragments[0].abstract:
|
||||
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
|
||||
|
||||
# Write section tree if available
|
||||
if fragments and fragments[0].catalogs:
|
||||
f.write("\n## Section Tree\n")
|
||||
f.write("```\n") # 添加代码块开始标记
|
||||
f.write(fragments[0].catalogs)
|
||||
f.write("\n```") # 添加代码块结束标记
|
||||
|
||||
# Generate table of contents
|
||||
f.write("\n## Table of Contents\n")
|
||||
for section in sections:
|
||||
clean_section = section.strip() or "Uncategorized"
|
||||
fragment_count = len(sections[section])
|
||||
f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) "
|
||||
f"({fragment_count} fragments)\n")
|
||||
|
||||
# Write content sections
|
||||
f.write("\n## Content\n")
|
||||
for section, section_fragments in sections.items():
|
||||
clean_section = section.strip() or "Uncategorized"
|
||||
f.write(f"\n### {clean_section}\n")
|
||||
|
||||
# Write each fragment
|
||||
for i, fragment in enumerate(section_fragments, 1):
|
||||
f.write(f"\n#### Fragment {i}\n")
|
||||
|
||||
# Metadata
|
||||
f.write("**Metadata:**\n")
|
||||
metadata = [
|
||||
f"- Section: {fragment.current_section}",
|
||||
f"- Length: {len(fragment.content)} chars",
|
||||
f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
|
||||
]
|
||||
f.write("\n".join(filter(None, metadata)) + "\n")
|
||||
|
||||
# Content
|
||||
f.write("\n**Content:**\n")
|
||||
f.write("\n")
|
||||
f.write(fragment.content)
|
||||
f.write("\n")
|
||||
|
||||
# Bibliography if exists
|
||||
if fragment.bibliography:
|
||||
f.write("\n**Bibliography:**\n")
|
||||
f.write("```bibtex\n")
|
||||
f.write(fragment.bibliography)
|
||||
f.write("\n```\n")
|
||||
|
||||
# Add separator
|
||||
if i < len(section_fragments):
|
||||
f.write("\n---\n")
|
||||
|
||||
# Add statistics
|
||||
f.write("\n## Statistics\n")
|
||||
|
||||
# Length distribution
|
||||
lengths = [len(f.content) for f in fragments]
|
||||
f.write("\n### Length Distribution\n")
|
||||
f.write(f"- Minimum: {min(lengths)} chars\n")
|
||||
f.write(f"- Maximum: {max(lengths)} chars\n")
|
||||
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
|
||||
|
||||
# Section distribution
|
||||
f.write("\n### Section Distribution\n")
|
||||
for section, section_fragments in sections.items():
|
||||
percentage = (len(section_fragments) / len(fragments)) * 100
|
||||
f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")
|
||||
|
||||
print(f"Fragments saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
|
||||
# 定义各种引用命令的模式
|
||||
CITATION_PATTERNS = [
|
||||
# 基本的 \cite{} 格式
|
||||
r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
# natbib 格式
|
||||
r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
# biblatex 格式
|
||||
r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||
# 自定义 [cite:...] 格式
|
||||
r'\[cite:([^\]]+)\]',
|
||||
]
|
||||
|
||||
# 编译所有模式
|
||||
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
|
||||
|
||||
|
||||
class ArxivSplitter:
|
||||
"""Arxiv论文智能分割器"""
|
||||
|
||||
def __init__(self,
|
||||
root_dir: str = "gpt_log/arxiv_cache",
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
cache_ttl: int = 7 * 24 * 60 * 60):
|
||||
"""
|
||||
初始化分割器
|
||||
|
||||
Args:
|
||||
char_range: 字符数范围(最小值, 最大值)
|
||||
root_dir: 缓存根目录
|
||||
proxies: 代理设置
|
||||
cache_ttl: 缓存过期时间(秒)
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.proxies = proxies or {}
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
# 动态计算最优线程数
|
||||
import multiprocessing
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
# 根据CPU核心数动态设置,但设置上限防止过度并发
|
||||
self.document_structure = DocumentStructure()
|
||||
self.document_parser = EssayStructureParser()
|
||||
|
||||
self.max_workers = min(32, cpu_count * 2)
|
||||
|
||||
# 初始化TeX处理器
|
||||
self.tex_processor = TexUtils()
|
||||
|
||||
# 配置日志
|
||||
self._setup_logging()
|
||||
|
||||
def _setup_logging(self):
|
||||
"""配置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||
"""规范化ArXiv ID"""
|
||||
if 'arxiv.org/' in input_str.lower():
|
||||
# 处理URL格式
|
||||
if '/pdf/' in input_str:
|
||||
arxiv_id = input_str.split('/pdf/')[-1]
|
||||
else:
|
||||
arxiv_id = input_str.split('/abs/')[-1]
|
||||
# 移除版本号和其他后缀
|
||||
return arxiv_id.split('v')[0].strip()
|
||||
return input_str.split('v')[0].strip()
|
||||
|
||||
def _check_cache(self, paper_dir: Path) -> bool:
|
||||
"""
|
||||
检查缓存是否有效,包括文件完整性检查
|
||||
|
||||
Args:
|
||||
paper_dir: 论文目录路径
|
||||
|
||||
Returns:
|
||||
bool: 如果缓存有效返回True,否则返回False
|
||||
"""
|
||||
if not paper_dir.exists():
|
||||
return False
|
||||
|
||||
# 检查目录中是否存在必要文件
|
||||
has_tex_files = False
|
||||
has_main_tex = False
|
||||
|
||||
for file_path in paper_dir.rglob("*"):
|
||||
if file_path.suffix == '.tex':
|
||||
has_tex_files = True
|
||||
content = self.tex_processor.read_file(str(file_path))
|
||||
if content and r'\documentclass' in content:
|
||||
has_main_tex = True
|
||||
break
|
||||
|
||||
if not (has_tex_files and has_main_tex):
|
||||
return False
|
||||
|
||||
# 检查缓存时间
|
||||
cache_time = paper_dir.stat().st_mtime
|
||||
if (time.time() - cache_time) < self.cache_ttl:
|
||||
self.logger.info(f"Using valid cache for {paper_dir.name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
|
||||
"""
|
||||
异步下载论文,包含重试机制和临时文件处理
|
||||
|
||||
Args:
|
||||
arxiv_id: ArXiv论文ID
|
||||
paper_dir: 目标目录路径
|
||||
|
||||
Returns:
|
||||
bool: 下载成功返回True,否则返回False
|
||||
"""
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
|
||||
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
|
||||
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 确保目录存在
|
||||
paper_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 尝试使用 ArxivDownloader 下载
|
||||
try:
|
||||
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
|
||||
downloaded_dir = downloader.download_paper(arxiv_id)
|
||||
if downloaded_dir:
|
||||
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
|
||||
|
||||
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 初始重试延迟(秒)
|
||||
|
||||
for url in urls:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, proxy=self.proxies.get('http')) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# 写入临时文件
|
||||
temp_tar_path.write_bytes(content)
|
||||
|
||||
try:
|
||||
# 验证tar文件完整性并解压
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
|
||||
|
||||
# 下载成功后移动临时文件到最终位置
|
||||
temp_tar_path.rename(final_tar_path)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Invalid tar file: {str(e)}")
|
||||
if temp_tar_path.exists():
|
||||
temp_tar_path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _process_tar_file(self, tar_path: Path, extract_path: Path):
|
||||
"""处理tar文件的同步操作"""
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||
tar.testall() # 验证文件完整性
|
||||
tar.extractall(path=extract_path) # 解压文件
|
||||
|
||||
def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
|
||||
"""
|
||||
Process citations in document structure and add referenced literature for each section
|
||||
|
||||
Args:
|
||||
doc_structure: DocumentStructure object
|
||||
ref_bib: String containing references separated by newlines
|
||||
|
||||
Returns:
|
||||
Updated DocumentStructure object
|
||||
"""
|
||||
try:
|
||||
# Create a copy to avoid modifying the original
|
||||
doc = deepcopy(doc_structure)
|
||||
|
||||
# Parse references into a mapping
|
||||
ref_map = self._parse_references(ref_bib)
|
||||
if not ref_map:
|
||||
self.logger.warning("No valid references found in ref_bib")
|
||||
return doc
|
||||
|
||||
# Process all sections recursively
|
||||
self._process_section_references(doc.toc, ref_map)
|
||||
|
||||
return doc
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing references: {str(e)}")
|
||||
return doc_structure # Return original if processing fails
|
||||
|
||||
def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
|
||||
"""
|
||||
Recursively process sections to add references
|
||||
|
||||
Args:
|
||||
sections: List of Section objects
|
||||
ref_map: Mapping of citation keys to full references
|
||||
"""
|
||||
for section in sections:
|
||||
if section.content:
|
||||
# Find citations in current section
|
||||
cited_refs = self.find_citations(section.content)
|
||||
|
||||
if cited_refs:
|
||||
# Get full references for citations
|
||||
full_refs = []
|
||||
for ref_key in cited_refs:
|
||||
ref_text = ref_map.get(ref_key)
|
||||
if ref_text:
|
||||
full_refs.append(ref_text)
|
||||
else:
|
||||
self.logger.warning(f"Reference not found for citation key: {ref_key}")
|
||||
|
||||
# Add references to section content
|
||||
if full_refs:
|
||||
section.bibliography = "\n\n".join(full_refs)
|
||||
|
||||
# Process subsections recursively
|
||||
if section.subsections:
|
||||
self._process_section_references(section.subsections, ref_map)
|
||||
|
||||
def _parse_references(self, ref_bib: str) -> Dict[str, str]:
|
||||
"""
|
||||
Parse reference string into a mapping of citation keys to full references
|
||||
|
||||
Args:
|
||||
ref_bib: Reference string with references separated by newlines
|
||||
|
||||
Returns:
|
||||
Dict mapping citation keys to full reference text
|
||||
"""
|
||||
ref_map = {}
|
||||
current_ref = []
|
||||
current_key = None
|
||||
|
||||
try:
|
||||
for line in ref_bib.split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# New reference entry
|
||||
if line.startswith('@'):
|
||||
# Save previous reference if exists
|
||||
if current_key and current_ref:
|
||||
ref_map[current_key] = '\n'.join(current_ref)
|
||||
current_ref = []
|
||||
|
||||
# Extract key from new reference
|
||||
key_match = re.search(r'{(.*?),', line)
|
||||
if key_match:
|
||||
current_key = key_match.group(1)
|
||||
current_ref.append(line)
|
||||
else:
|
||||
if current_ref is not None:
|
||||
current_ref.append(line)
|
||||
|
||||
# Save last reference
|
||||
if current_key and current_ref:
|
||||
ref_map[current_key] = '\n'.join(current_ref)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing references: {str(e)}")
|
||||
|
||||
return ref_map
|
||||
|
||||
# 编译一次正则表达式以提高效率
|
||||
|
||||
@staticmethod
|
||||
def _clean_citation_key(key: str) -> str:
|
||||
"""Clean individual citation key."""
|
||||
return key.strip().strip(',').strip()
|
||||
|
||||
def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
|
||||
"""Extract and clean individual citation keys from a group."""
|
||||
try:
|
||||
# 分割多个引用键(支持逗号和分号分隔)
|
||||
separators = '[,;]'
|
||||
keys = re.split(separators, keys_str)
|
||||
# 清理并过滤空键
|
||||
return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)}
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
|
||||
return set()
|
||||
|
||||
def find_citations(self, content: str) -> Set[str]:
|
||||
"""
|
||||
Find citation keys in text content in various formats.
|
||||
|
||||
Args:
|
||||
content: Text content to search for citations
|
||||
|
||||
Returns:
|
||||
Set of unique citation keys
|
||||
|
||||
Examples:
|
||||
Supported formats include:
|
||||
- \cite{key1,key2}
|
||||
- \cite[p. 1]{key}
|
||||
- \citep{key}
|
||||
- \citet{key}
|
||||
- [cite:key1, key2]
|
||||
- And many other variants
|
||||
"""
|
||||
citations = set()
|
||||
|
||||
if not content:
|
||||
return citations
|
||||
|
||||
try:
|
||||
# 对每个编译好的模式进行搜索
|
||||
for pattern in COMPILED_PATTERNS:
|
||||
matches = pattern.finditer(content)
|
||||
for match in matches:
|
||||
# 获取捕获组中的引用键
|
||||
keys_str = match.group(1)
|
||||
if keys_str:
|
||||
# 提取并添加所有引用键
|
||||
new_keys = self._extract_keys_from_group(keys_str)
|
||||
citations.update(new_keys)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding citations: {str(e)}")
|
||||
|
||||
# 移除明显无效的键
|
||||
citations = {key for key in citations
|
||||
if key and not key.startswith(('\\', '{', '}', '[', ']'))}
|
||||
|
||||
return citations
|
||||
|
||||
def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
|
||||
"""
|
||||
Find citations and their surrounding context.
|
||||
|
||||
Args:
|
||||
content: Text content to search for citations
|
||||
context_chars: Number of characters of context to include before/after
|
||||
|
||||
Returns:
|
||||
Dict mapping citation keys to lists of context strings
|
||||
"""
|
||||
contexts = {}
|
||||
|
||||
if not content:
|
||||
return contexts
|
||||
|
||||
try:
|
||||
for pattern in COMPILED_PATTERNS:
|
||||
matches = pattern.finditer(content)
|
||||
for match in matches:
|
||||
# 获取匹配的位置
|
||||
start = max(0, match.start() - context_chars)
|
||||
end = min(len(content), match.end() + context_chars)
|
||||
|
||||
# 获取上下文
|
||||
context = content[start:end]
|
||||
|
||||
# 获取并处理引用键
|
||||
keys_str = match.group(1)
|
||||
keys = self._extract_keys_from_group(keys_str)
|
||||
|
||||
# 为每个键添加上下文
|
||||
for key in keys:
|
||||
if key not in contexts:
|
||||
contexts[key] = []
|
||||
contexts[key].append(context)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error finding citation contexts: {str(e)}")
|
||||
|
||||
return contexts
|
||||
|
||||
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
|
||||
"""
|
||||
Process ArXiv paper and convert to list of SectionFragments.
|
||||
Each fragment represents the smallest section unit.
|
||||
|
||||
Args:
|
||||
arxiv_id_or_url: ArXiv paper ID or URL
|
||||
|
||||
Returns:
|
||||
List[SectionFragment]: List of processed paper fragments
|
||||
"""
|
||||
try:
|
||||
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
|
||||
# Check if paper directory exists, if not, try to download
|
||||
if not paper_dir.exists():
|
||||
self.logger.info(f"Downloading paper {arxiv_id}")
|
||||
await self.download_paper(arxiv_id, paper_dir)
|
||||
|
||||
# Find main TeX file
|
||||
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
|
||||
if not main_tex:
|
||||
raise RuntimeError(f"No main TeX file found in {paper_dir}")
|
||||
|
||||
# 读取主 TeX 文件内容
|
||||
main_tex_content = read_tex_file(main_tex)
|
||||
|
||||
# Get all related TeX files and references
|
||||
tex_files = self.tex_processor.resolve_includes(main_tex)
|
||||
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
|
||||
|
||||
if not tex_files:
|
||||
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
|
||||
|
||||
# Reset document structure for new processing
|
||||
self.document_structure = DocumentStructure()
|
||||
|
||||
# 提取作者信息
|
||||
author_extractor = LatexAuthorExtractor()
|
||||
authors = author_extractor.extract_authors(main_tex_content)
|
||||
self.document_structure.authors = authors # 保存到文档结构中
|
||||
|
||||
# Process each TeX file
|
||||
for file_path in tex_files:
|
||||
self.logger.info(f"Processing TeX file: {file_path}")
|
||||
tex_content = read_tex_file(file_path)
|
||||
if tex_content:
|
||||
additional_doc = self.document_parser.parse(tex_content)
|
||||
self.document_structure = self.document_structure.merge(additional_doc)
|
||||
|
||||
# Process references if available
|
||||
if ref_bib:
|
||||
self.document_structure = self.process_references(self.document_structure, ref_bib)
|
||||
self.logger.info("Successfully processed references")
|
||||
else:
|
||||
self.logger.info("No references found to process")
|
||||
|
||||
# Generate table of contents once
|
||||
section_tree = self.document_structure.generate_toc_tree()
|
||||
|
||||
# Convert DocumentStructure to SectionFragments
|
||||
fragments = self._convert_to_fragments(
|
||||
doc_structure=self.document_structure,
|
||||
arxiv_id=arxiv_id,
|
||||
section_tree=section_tree
|
||||
)
|
||||
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
|
||||
raise
|
||||
|
||||
def _convert_to_fragments(self,
|
||||
doc_structure: DocumentStructure,
|
||||
arxiv_id: str,
|
||||
section_tree: str) -> List[SectionFragment]:
|
||||
"""
|
||||
Convert DocumentStructure to list of SectionFragments.
|
||||
Creates a fragment for each leaf section in the document hierarchy.
|
||||
|
||||
Args:
|
||||
doc_structure: Source DocumentStructure
|
||||
arxiv_id: ArXiv paper ID
|
||||
section_tree: Pre-generated table of contents tree
|
||||
|
||||
Returns:
|
||||
List[SectionFragment]: List of paper fragments
|
||||
"""
|
||||
fragments = []
|
||||
|
||||
# Create a base template for all fragments to avoid repetitive assignments
|
||||
base_fragment_template = {
|
||||
'title': doc_structure.title,
|
||||
'authors': doc_structure.authors,
|
||||
'abstract': doc_structure.abstract,
|
||||
'catalogs': section_tree,
|
||||
'arxiv_id': arxiv_id
|
||||
}
|
||||
|
||||
def get_leaf_sections(section: Section, path: List[str] = None) -> None:
|
||||
"""
|
||||
Recursively find all leaf sections and create fragments.
|
||||
A leaf section is one that has content but no subsections, or has neither.
|
||||
|
||||
Args:
|
||||
section: Current section being processed
|
||||
path: List of section titles forming the path to current section
|
||||
"""
|
||||
if path is None:
|
||||
path = []
|
||||
|
||||
current_path = path + [section.title]
|
||||
|
||||
if not section.subsections:
|
||||
# This is a leaf section, create a fragment if it has content
|
||||
if section.content or section.bibliography:
|
||||
fragment = SectionFragment(
|
||||
**base_fragment_template,
|
||||
current_section="/".join(current_path),
|
||||
content=self._clean_content(section.content),
|
||||
bibliography=section.bibliography
|
||||
)
|
||||
if self._validate_fragment(fragment):
|
||||
fragments.append(fragment)
|
||||
else:
|
||||
# Process each subsection
|
||||
for subsection in section.subsections:
|
||||
get_leaf_sections(subsection, current_path)
|
||||
|
||||
# Process all top-level sections
|
||||
for section in doc_structure.toc:
|
||||
get_leaf_sections(section)
|
||||
|
||||
# Add a fragment for the abstract if it exists
|
||||
if doc_structure.abstract:
|
||||
abstract_fragment = SectionFragment(
|
||||
**base_fragment_template,
|
||||
current_section="Abstract",
|
||||
content=self._clean_content(doc_structure.abstract)
|
||||
)
|
||||
if self._validate_fragment(abstract_fragment):
|
||||
fragments.insert(0, abstract_fragment)
|
||||
|
||||
self.logger.info(f"Created {len(fragments)} fragments")
|
||||
return fragments
|
||||
|
||||
def _validate_fragment(self, fragment: SectionFragment) -> bool:
|
||||
"""
|
||||
Validate if the fragment has all required fields with meaningful content.
|
||||
|
||||
Args:
|
||||
fragment: SectionFragment to validate
|
||||
|
||||
Returns:
|
||||
bool: True if fragment is valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
return all([
|
||||
fragment.title.strip(),
|
||||
fragment.catalogs.strip(),
|
||||
fragment.current_section.strip(),
|
||||
fragment.content.strip() or fragment.bibliography.strip()
|
||||
])
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
def _clean_content(self, content: str) -> str:
|
||||
"""
|
||||
Clean and normalize content text.
|
||||
|
||||
Args:
|
||||
content: Raw content text
|
||||
|
||||
Returns:
|
||||
str: Cleaned content text
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# Remove excessive whitespace
|
||||
content = re.sub(r'\s+', ' ', content)
|
||||
|
||||
# Remove remaining LaTeX artifacts
|
||||
content = re.sub(r'\\item\s*', '• ', content) # Convert \item to bullet points
|
||||
content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands
|
||||
|
||||
# Clean special characters
|
||||
content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines
|
||||
content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, List[Path]]:
|
||||
"""
|
||||
同步处理 ArXiv 文档并返回分割后的片段
|
||||
|
||||
Args:
|
||||
splitter: ArxivSplitter 实例
|
||||
arxiv_id: ArXiv 文档ID
|
||||
|
||||
Returns:
|
||||
list: 分割后的文档片段列表
|
||||
"""
|
||||
try:
|
||||
from crazy_functions.doc_fns.tex_html_formatter import PaperHtmlFormatter
|
||||
# 创建一个异步函数来执行异步操作
|
||||
async def _process():
|
||||
return await splitter.process(arxiv_id)
|
||||
|
||||
# 使用 asyncio.run() 运行异步函数
|
||||
output_files=[]
|
||||
fragments = asyncio.run(_process())
|
||||
file_save_path = splitter.root_dir / "arxiv_fragments"
|
||||
# 保存片段到文件
|
||||
try:
|
||||
md_output_dir = save_fragments_to_file(
|
||||
fragments,
|
||||
output_dir = file_save_path
|
||||
)
|
||||
output_files.append(md_output_dir)
|
||||
except:
|
||||
pass
|
||||
# 创建论文格式化器
|
||||
formatter = PaperContentFormatter()
|
||||
|
||||
# 准备元数据
|
||||
# 创建格式化选项
|
||||
|
||||
metadata = PaperMetadata(
|
||||
title=fragments[0].title,
|
||||
authors=fragments[0].authors,
|
||||
abstract=fragments[0].abstract,
|
||||
catalogs=fragments[0].catalogs,
|
||||
arxiv_id=fragments[0].arxiv_id
|
||||
)
|
||||
|
||||
# 格式化内容
|
||||
formatted_content = formatter.format(fragments, metadata)
|
||||
|
||||
try:
|
||||
html_formatter = PaperHtmlFormatter(fragments, file_save_path)
|
||||
html_output_dir = html_formatter.save_html()
|
||||
output_files.append(html_output_dir)
|
||||
except:
|
||||
pass
|
||||
return fragments, formatted_content, output_files
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
|
||||
raise
|
||||
def test_arxiv_splitter():
|
||||
"""测试ArXiv分割器的功能"""
|
||||
|
||||
# 测试配置
|
||||
test_cases = [
|
||||
{
|
||||
"arxiv_id": "2411.03663",
|
||||
"expected_title": "Large Language Models and Simple Scripts",
|
||||
"min_fragments": 10,
|
||||
},
|
||||
# {
|
||||
# "arxiv_id": "1805.10988",
|
||||
# "expected_title": "RAG vs Fine-tuning",
|
||||
# "min_fragments": 15,
|
||||
# }
|
||||
]
|
||||
|
||||
# 创建分割器实例
|
||||
splitter = ArxivSplitter(
|
||||
root_dir="private_upload/default_user"
|
||||
)
|
||||
|
||||
for case in test_cases:
|
||||
print(f"\nTesting paper: {case['arxiv_id']}")
|
||||
try:
|
||||
# fragments = await splitter.process(case['arxiv_id'])
|
||||
fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
|
||||
# 保存fragments
|
||||
for fragment in fragments:
|
||||
# 长度检查
|
||||
print((fragment.content))
|
||||
print(len(fragment.content))
|
||||
# 类型检查
|
||||
print(output_dir)
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_arxiv_splitter()
|
||||
@@ -0,0 +1,177 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LatexAuthorExtractor:
|
||||
def __init__(self):
|
||||
# Patterns for matching author blocks with balanced braces
|
||||
self.author_block_patterns = [
|
||||
# Standard LaTeX patterns with optional arguments
|
||||
r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
# Conference and journal specific patterns
|
||||
r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
# Academic publisher specific patterns
|
||||
r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
|
||||
]
|
||||
|
||||
# Cleaning patterns for LaTeX commands and formatting
|
||||
self.cleaning_patterns = [
|
||||
# Text formatting commands - preserve content
|
||||
(r'\\textbf\{([^}]+)\}', r'\1'),
|
||||
(r'\\textit\{([^}]+)\}', r'\1'),
|
||||
(r'\\emph\{([^}]+)\}', r'\1'),
|
||||
(r'\\texttt\{([^}]+)\}', r'\1'),
|
||||
(r'\\textrm\{([^}]+)\}', r'\1'),
|
||||
(r'\\text\{([^}]+)\}', r'\1'),
|
||||
|
||||
# Affiliation and footnote markers
|
||||
(r'\$\^{[^}]+}\$', ''),
|
||||
(r'\^{[^}]+}', ''),
|
||||
(r'\\thanks\{[^}]+\}', ''),
|
||||
(r'\\footnote\{[^}]+\}', ''),
|
||||
|
||||
# Email and contact formatting
|
||||
(r'\\email\{([^}]+)\}', r'\1'),
|
||||
(r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
|
||||
|
||||
# Institution formatting
|
||||
(r'\\inst\{[^}]+\}', ''),
|
||||
(r'\\affil\{[^}]+\}', ''),
|
||||
|
||||
# Special characters and symbols
|
||||
(r'\\&', '&'),
|
||||
(r'\\\\\s*', ' '),
|
||||
(r'\\,', ' '),
|
||||
(r'\\;', ' '),
|
||||
(r'\\quad', ' '),
|
||||
(r'\\qquad', ' '),
|
||||
|
||||
# Math mode content
|
||||
(r'\$[^$]+\$', ''),
|
||||
|
||||
# Common symbols
|
||||
(r'\\dagger', '†'),
|
||||
(r'\\ddagger', '‡'),
|
||||
(r'\\ast', '*'),
|
||||
(r'\\star', '★'),
|
||||
|
||||
# Remove remaining LaTeX commands
|
||||
(r'\\[a-zA-Z]+', ''),
|
||||
|
||||
# Clean up remaining special characters
|
||||
(r'[\\{}]', '')
|
||||
]
|
||||
|
||||
def extract_author_block(self, text: str) -> Optional[str]:
|
||||
"""
|
||||
Extract the complete author block from LaTeX text.
|
||||
|
||||
Args:
|
||||
text (str): Input LaTeX text
|
||||
|
||||
Returns:
|
||||
Optional[str]: Extracted author block or None if not found
|
||||
"""
|
||||
try:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
for pattern in self.author_block_patterns:
|
||||
match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return None
|
||||
|
||||
except (AttributeError, IndexError) as e:
|
||||
print(f"Error extracting author block: {e}")
|
||||
return None
|
||||
|
||||
def clean_tex_commands(self, text: str) -> str:
|
||||
"""
|
||||
Remove LaTeX commands and formatting from text while preserving content.
|
||||
|
||||
Args:
|
||||
text (str): Text containing LaTeX commands
|
||||
|
||||
Returns:
|
||||
str: Cleaned text with commands removed
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
cleaned_text = text
|
||||
|
||||
# Apply cleaning patterns
|
||||
for pattern, replacement in self.cleaning_patterns:
|
||||
cleaned_text = re.sub(pattern, replacement, cleaned_text)
|
||||
|
||||
# Clean up whitespace
|
||||
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
return cleaned_text
|
||||
|
||||
def extract_authors(self, text: str) -> Optional[str]:
|
||||
"""
|
||||
Extract and clean author information from LaTeX text.
|
||||
|
||||
Args:
|
||||
text (str): Input LaTeX text
|
||||
|
||||
Returns:
|
||||
Optional[str]: Cleaned author information or None if extraction fails
|
||||
"""
|
||||
try:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Extract author block
|
||||
author_block = self.extract_author_block(text)
|
||||
if not author_block:
|
||||
return None
|
||||
|
||||
# Clean LaTeX commands
|
||||
cleaned_authors = self.clean_tex_commands(author_block)
|
||||
return cleaned_authors or None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing text: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def test_author_extractor():
|
||||
"""Test the LatexAuthorExtractor with sample inputs."""
|
||||
test_cases = [
|
||||
# Basic test case
|
||||
(r"\author{John Doe}", "John Doe"),
|
||||
|
||||
# Test with multiple authors
|
||||
(r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
|
||||
|
||||
# Test with affiliations
|
||||
(r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
|
||||
|
||||
]
|
||||
|
||||
extractor = LatexAuthorExtractor()
|
||||
|
||||
for i, (input_tex, expected) in enumerate(test_cases, 1):
|
||||
result = extractor.extract_authors(input_tex)
|
||||
print(f"\nTest case {i}:")
|
||||
print(f"Input: {input_tex[:50]}...")
|
||||
print(f"Expected: {expected[:50]}...")
|
||||
print(f"Got: {result[:50]}...")
|
||||
print(f"Pass: {bool(result and result.strip() == expected.strip())}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_author_extractor()
|
||||
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
LaTeX Document Parser
|
||||
|
||||
This module provides functionality for parsing and extracting structured information from LaTeX documents,
|
||||
including metadata, document structure, and content. It uses modular design and clean architecture principles.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict
|
||||
|
||||
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_tex_file(file_path):
|
||||
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentStructure:
|
||||
title: str = ''
|
||||
authors: str = ''
|
||||
abstract: str = ''
|
||||
toc: List[Section] = field(default_factory=list)
|
||||
metadata: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def merge(self, other: 'DocumentStructure', strategy: str = 'smart') -> 'DocumentStructure':
|
||||
"""
|
||||
Merge this document structure with another one.
|
||||
|
||||
Args:
|
||||
other: Another DocumentStructure to merge with
|
||||
strategy: Merge strategy - 'smart' (default) or 'append'
|
||||
'smart' - Intelligently merge sections with same titles
|
||||
'append' - Simply append sections from other document
|
||||
"""
|
||||
merged = deepcopy(self)
|
||||
|
||||
# Merge title if needed
|
||||
if not merged.title and other.title:
|
||||
merged.title = other.title
|
||||
|
||||
# Merge abstract
|
||||
merged.abstract = self._merge_abstract(merged.abstract, other.abstract)
|
||||
|
||||
# Merge metadata
|
||||
merged.metadata.update(other.metadata)
|
||||
|
||||
if strategy == 'append':
|
||||
merged.toc.extend(deepcopy(other.toc))
|
||||
else: # smart merge
|
||||
# Create sections lookup for efficient merging
|
||||
sections_map = {s.title: s for s in merged.toc}
|
||||
|
||||
for other_section in other.toc:
|
||||
if other_section.title in sections_map:
|
||||
# Merge existing section
|
||||
idx = next(i for i, s in enumerate(merged.toc)
|
||||
if s.title == other_section.title)
|
||||
merged.toc[idx] = merged.toc[idx].merge(other_section)
|
||||
else:
|
||||
# Add new section
|
||||
merged.toc.append(deepcopy(other_section))
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _merge_abstract(abstract1: str, abstract2: str) -> str:
|
||||
"""Merge abstracts intelligently."""
|
||||
if not abstract1:
|
||||
return abstract2
|
||||
if not abstract2:
|
||||
return abstract1
|
||||
# Combine non-empty abstracts with a separator
|
||||
return f"{abstract1}\n\n{abstract2}"
|
||||
|
||||
def generate_toc_tree(self, indent_char: str = " ", abstract_preview_length: int = 0) -> str:
|
||||
"""
|
||||
Generate a tree-like string representation of the table of contents including abstract.
|
||||
|
||||
Args:
|
||||
indent_char: Character(s) used for indentation. Default is two spaces.
|
||||
abstract_preview_length: Maximum length of abstract preview. Default is 200 characters.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the hierarchical document structure with abstract
|
||||
"""
|
||||
|
||||
def _format_section(section: Section, level: int = 0) -> str:
|
||||
# Create the current section line with proper indentation
|
||||
current_line = f"{indent_char * level}{'•' if level > 0 else '○'} {section.title}\n"
|
||||
|
||||
# Recursively process subsections
|
||||
subsections = ""
|
||||
if section.subsections:
|
||||
subsections = "".join(_format_section(subsec, level + 1)
|
||||
for subsec in section.subsections)
|
||||
|
||||
return current_line + subsections
|
||||
|
||||
result = []
|
||||
|
||||
# Add document title if it exists
|
||||
if self.title:
|
||||
result.append(f"《{self.title}》\n")
|
||||
|
||||
# Add abstract if it exists
|
||||
if self.abstract:
|
||||
result.append("\n□ Abstract:")
|
||||
# Format abstract content with word wrap
|
||||
abstract_preview = self.abstract[:abstract_preview_length]
|
||||
if len(self.abstract) > abstract_preview_length:
|
||||
abstract_preview += "..."
|
||||
|
||||
# Split abstract into lines and indent them
|
||||
wrapped_lines = []
|
||||
current_line = ""
|
||||
for word in abstract_preview.split():
|
||||
if len(current_line) + len(word) + 1 <= 80: # 80 characters per line
|
||||
current_line = (current_line + " " + word).strip()
|
||||
else:
|
||||
wrapped_lines.append(current_line)
|
||||
current_line = word
|
||||
if current_line:
|
||||
wrapped_lines.append(current_line)
|
||||
|
||||
# Add formatted abstract lines
|
||||
for line in wrapped_lines:
|
||||
result.append(f"\n{indent_char}{line}")
|
||||
result.append("\n") # Add extra newline after abstract
|
||||
|
||||
# Add table of contents header if there are sections
|
||||
if self.toc:
|
||||
result.append("\n◈ Table of Contents:\n")
|
||||
|
||||
# Add all top-level sections and their subsections
|
||||
result.extend(_format_section(section, 0) for section in self.toc)
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Base class for LaTeX content extractors."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, content: str) -> str:
|
||||
"""Extract specific content from LaTeX document."""
|
||||
pass
|
||||
|
||||
|
||||
class TitleExtractor(BaseExtractor):
|
||||
"""Extracts title from LaTeX document."""
|
||||
|
||||
PATTERNS = [
|
||||
r'\\title{(.+?)}',
|
||||
r'\\title\[.*?\]{(.+?)}',
|
||||
r'\\Title{(.+?)}',
|
||||
r'\\TITLE{(.+?)}',
|
||||
r'\\begin{document}\s*\\section[*]?{(.+?)}',
|
||||
r'\\maketitle\s*\\section[*]?{(.+?)}',
|
||||
r'\\chapter[*]?{(.+?)}'
|
||||
]
|
||||
|
||||
def extract(self, content: str) -> str:
|
||||
"""Extract title using defined patterns."""
|
||||
for pattern in self.PATTERNS:
|
||||
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
|
||||
for match in matches:
|
||||
title = match.group(1).strip()
|
||||
if title:
|
||||
return clean_latex_commands(title)
|
||||
return ''
|
||||
|
||||
|
||||
class AbstractExtractor(BaseExtractor):
|
||||
"""Extracts abstract from LaTeX document."""
|
||||
|
||||
PATTERNS = [
|
||||
r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||
r'\\abstract{(.*?)}',
|
||||
r'\\ABSTRACT{(.*?)}',
|
||||
r'\\Abstract{(.*?)}',
|
||||
r'\\begin{Abstract}(.*?)\\end{Abstract}',
|
||||
r'\\section[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\section|\Z)',
|
||||
r'\\chapter[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\chapter|\Z)'
|
||||
]
|
||||
|
||||
def extract(self, content: str) -> str:
|
||||
"""Extract abstract using defined patterns."""
|
||||
for pattern in self.PATTERNS:
|
||||
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
|
||||
for match in matches:
|
||||
abstract = match.group(1).strip()
|
||||
if abstract:
|
||||
return clean_latex_commands(abstract)
|
||||
return ''
|
||||
|
||||
|
||||
class EssayStructureParser:
|
||||
"""Main class for parsing LaTeX documents."""
|
||||
|
||||
def __init__(self):
|
||||
self.title_extractor = TitleExtractor()
|
||||
self.abstract_extractor = AbstractExtractor()
|
||||
self.section_extractor = EnhancedSectionExtractor() # Using the enhanced extractor
|
||||
|
||||
def parse(self, content: str) -> DocumentStructure:
|
||||
"""Parse LaTeX document and extract structured information."""
|
||||
try:
|
||||
content = self._preprocess_content(content)
|
||||
|
||||
return DocumentStructure(
|
||||
title=self.title_extractor.extract(content),
|
||||
abstract=self.abstract_extractor.extract(content),
|
||||
toc=self.section_extractor.extract(content)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing LaTeX document: {str(e)}")
|
||||
raise
|
||||
|
||||
def _preprocess_content(self, content: str) -> str:
|
||||
"""Preprocess LaTeX content for parsing."""
|
||||
# Remove comments
|
||||
content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
|
||||
return content
|
||||
|
||||
|
||||
def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
|
||||
"""Print document structure in a readable format."""
|
||||
print(f"Title: {doc.title}\n")
|
||||
print(f"Abstract: {doc.abstract}\n")
|
||||
print("Table of Contents:")
|
||||
|
||||
def print_section(section: Section, indent: int = 0):
|
||||
print(" " * indent + f"- {section.title}")
|
||||
if section.content:
|
||||
preview = section.content[:max_content_length]
|
||||
if len(section.content) > max_content_length:
|
||||
preview += "..."
|
||||
print(" " * (indent + 1) + f"Content: {preview}")
|
||||
for subsection in section.subsections:
|
||||
print_section(subsection, indent + 1)
|
||||
|
||||
for section in doc.toc:
|
||||
print_section(section)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Test with a file
|
||||
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||
main_tex = read_tex_file(file_path)
|
||||
|
||||
# Parse main file
|
||||
parser = EssayStructureParser()
|
||||
main_doc = parser.parse(main_tex)
|
||||
|
||||
# Merge other documents
|
||||
file_path_list = [
|
||||
"test_cache/2411.03663/1_intro.tex",
|
||||
"test_cache/2411.03663/0_abstract.tex",
|
||||
"test_cache/2411.03663/2_pre.tex",
|
||||
"test_cache/2411.03663/3_method.tex",
|
||||
"test_cache/2411.03663/4_experiment.tex",
|
||||
"test_cache/2411.03663/5_related_work.tex",
|
||||
"test_cache/2411.03663/6_conclu.tex",
|
||||
"test_cache/2411.03663/reference.bib"
|
||||
]
|
||||
for file_path in file_path_list:
|
||||
tex_content = read_tex_file(file_path)
|
||||
additional_doc = parser.parse(tex_content)
|
||||
main_doc = main_doc.merge(additional_doc)
|
||||
|
||||
tree = main_doc.generate_toc_tree()
|
||||
pretty_print_structure(main_doc)
|
||||
@@ -0,0 +1,329 @@
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Set, Dict, Pattern, Optional, List, Tuple
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
"""Environment classification types."""
|
||||
PRESERVE = "preserve" # Preserve complete environment including commands
|
||||
REMOVE = "remove" # Remove environment completely
|
||||
EXTRACT = "extract" # Extract and clean content
|
||||
|
||||
|
||||
@dataclass
|
||||
class LatexConfig:
|
||||
"""Configuration for LaTeX processing."""
|
||||
preserve_envs: Set[str] = field(default_factory=lambda: {
|
||||
# Math environments - preserve complete content
|
||||
'equation', 'equation*', 'align', 'align*', 'displaymath',
|
||||
'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*',
|
||||
'multline', 'multline*', 'flalign', 'flalign*',
|
||||
'alignat', 'alignat*', 'cases', 'split', 'aligned',
|
||||
# Tables and figures - preserve structure and content
|
||||
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
|
||||
'figure', 'figure*', 'subfigure', 'wrapfigure',
|
||||
'minipage', 'tabbing', 'verbatim', 'longtable',
|
||||
'sidewaystable', 'sidewaysfigure', 'floatrow',
|
||||
# Arrays and matrices
|
||||
'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix',
|
||||
'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*',
|
||||
# Algorithms and code
|
||||
'algorithm', 'algorithmic', 'lstlisting', 'verbatim',
|
||||
'minted', 'listing', 'algorithmic*', 'algorithm2e',
|
||||
# Theorems and proofs
|
||||
'theorem', 'proof', 'definition', 'lemma', 'corollary',
|
||||
'proposition', 'example', 'remark', 'note', 'claim',
|
||||
'axiom', 'property', 'assumption', 'conjecture', 'observation',
|
||||
# Bibliography
|
||||
'thebibliography', 'bibliography', 'references'
|
||||
})
|
||||
|
||||
# 引用类命令的特殊处理配置
|
||||
citation_commands: Set[str] = field(default_factory=lambda: {
|
||||
# Basic citations
|
||||
'cite', 'citep', 'citet', 'citeyear', 'citeauthor',
|
||||
'citeyearpar', 'citetext', 'citenum',
|
||||
# Natbib citations
|
||||
'citefullauthor', 'citealp', 'citealt', 'citename',
|
||||
'citepalias', 'citetalias', 'citetext',
|
||||
# Cross-references
|
||||
'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref',
|
||||
'Cref', 'vref', 'Vref', 'fref', 'pref',
|
||||
# Hyperref
|
||||
'hyperref', 'href', 'url',
|
||||
# Labels
|
||||
'label', 'tag'
|
||||
})
|
||||
|
||||
preserve_commands: Set[str] = field(default_factory=lambda: {
|
||||
# Text formatting
|
||||
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
|
||||
'section', 'subsection', 'subsubsection', 'paragraph', 'part',
|
||||
'chapter', 'title', 'author', 'date', 'thanks',
|
||||
# Math operators and symbols
|
||||
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf',
|
||||
'partial', 'nabla', 'implies', 'iff', 'therefore',
|
||||
'exists', 'forall', 'in', 'subset', 'subseteq',
|
||||
# Greek letters and math symbols
|
||||
'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
|
||||
'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu',
|
||||
'nu', 'xi', 'pi', 'rho', 'sigma', 'tau',
|
||||
'upsilon', 'phi', 'chi', 'psi', 'omega',
|
||||
'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
|
||||
'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
|
||||
# Math commands
|
||||
'left', 'right', 'big', 'Big', 'bigg', 'Bigg',
|
||||
'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb',
|
||||
'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop',
|
||||
'operatorname', 'overline', 'underline', 'overbrace',
|
||||
'underbrace', 'overset', 'underset', 'stackrel',
|
||||
# Spacing and alignment
|
||||
'quad', 'qquad', 'hspace', 'vspace', 'medskip',
|
||||
'bigskip', 'smallskip', 'hfill', 'vfill', 'centering',
|
||||
'raggedright', 'raggedleft'
|
||||
})
|
||||
|
||||
remove_commands: Set[str] = field(default_factory=lambda: {
|
||||
# Document setup
|
||||
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
|
||||
'bibliographystyle', 'frontmatter', 'mainmatter',
|
||||
'newtheorem', 'theoremstyle', 'proofname',
|
||||
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
|
||||
'newenvironment',
|
||||
# Layout and spacing
|
||||
'pagestyle', 'thispagestyle', 'newpage', 'clearpage',
|
||||
'pagebreak', 'linebreak', 'newline', 'setlength',
|
||||
'setcounter', 'addtocounter', 'makeatletter',
|
||||
'makeatother', 'pagenumbering'
|
||||
})
|
||||
|
||||
latex_chars: Dict[str, str] = field(default_factory=lambda: {
|
||||
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
|
||||
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
|
||||
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
|
||||
'\\textasciitilde': '~', '\\textasciicircum': '^'
|
||||
})
|
||||
|
||||
# 保留原始格式的特殊命令模式
|
||||
special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [
|
||||
(r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'),
|
||||
(r'\\ref\*?{([^}]+)}', r'\\ref{\1}'),
|
||||
(r'\\label{([^}]+)}', r'\\label{\1}'),
|
||||
(r'\\eqref{([^}]+)}', r'\\eqref{\1}'),
|
||||
(r'\\autoref{([^}]+)}', r'\\autoref{\1}'),
|
||||
(r'\\url{([^}]+)}', r'\\url{\1}'),
|
||||
(r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}')
|
||||
])
|
||||
|
||||
|
||||
class LatexCleaner:
|
||||
"""Enhanced LaTeX text cleaner that preserves mathematical content and citations."""
|
||||
|
||||
def __init__(self, config: Optional[LatexConfig] = None):
|
||||
self.config = config or LatexConfig()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
# 初始化正则表达式缓存
|
||||
self._regex_cache = {}
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_env_pattern(self, env_name: str) -> Pattern:
|
||||
"""Get cached regex pattern for environment matching."""
|
||||
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
|
||||
|
||||
def _get_env_type(self, env_name: str) -> EnvType:
|
||||
"""Determine environment processing type."""
|
||||
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
|
||||
return EnvType.PRESERVE
|
||||
elif env_name in {'comment'}:
|
||||
return EnvType.REMOVE
|
||||
return EnvType.EXTRACT
|
||||
|
||||
def _preserve_special_commands(self, text: str) -> str:
|
||||
"""Preserve special commands like citations and references with their complete structure."""
|
||||
for pattern, replacement in self.config.special_command_patterns:
|
||||
if pattern not in self._regex_cache:
|
||||
self._regex_cache[pattern] = re.compile(pattern)
|
||||
|
||||
def replace_func(match):
|
||||
# 保持原始命令格式
|
||||
return match.group(0)
|
||||
|
||||
text = self._regex_cache[pattern].sub(replace_func, text)
|
||||
return text
|
||||
|
||||
def _process_environment(self, match: re.Match) -> str:
|
||||
"""Process LaTeX environments while preserving complete content for special environments."""
|
||||
try:
|
||||
env_name = match.group(1)
|
||||
content = match.group(2)
|
||||
env_type = self._get_env_type(env_name)
|
||||
|
||||
if env_type == EnvType.PRESERVE:
|
||||
# 完整保留环境内容
|
||||
complete_env = match.group(0)
|
||||
return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n"
|
||||
elif env_type == EnvType.REMOVE:
|
||||
return ' '
|
||||
else:
|
||||
# 处理嵌套环境
|
||||
return self._clean_nested_environments(content)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}")
|
||||
return match.group(0)
|
||||
|
||||
def _preserve_inline_math(self, text: str) -> str:
|
||||
"""Preserve complete inline math content."""
|
||||
|
||||
def preserve_math(match):
|
||||
return f" {match.group(0)} "
|
||||
|
||||
patterns = [
|
||||
(r'\$[^$]+\$', preserve_math),
|
||||
(r'\\[\(\[].*?\\[\)\]]', preserve_math),
|
||||
(r'\\begin{math}.*?\\end{math}', preserve_math)
|
||||
]
|
||||
|
||||
for pattern, handler in patterns:
|
||||
if pattern not in self._regex_cache:
|
||||
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||
text = self._regex_cache[pattern].sub(handler, text)
|
||||
|
||||
return text
|
||||
|
||||
def _clean_nested_environments(self, text: str) -> str:
|
||||
"""Process nested environments recursively."""
|
||||
pattern = r'\\begin{(\w+)}(.*?)\\end{\1}'
|
||||
if pattern not in self._regex_cache:
|
||||
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||
|
||||
return self._regex_cache[pattern].sub(self._process_environment, text)
|
||||
|
||||
def _clean_commands(self, text: str) -> str:
|
||||
"""Clean LaTeX commands while preserving important content."""
|
||||
# 首先处理特殊命令
|
||||
text = self._preserve_special_commands(text)
|
||||
|
||||
# 保留内联数学
|
||||
text = self._preserve_inline_math(text)
|
||||
|
||||
# 移除指定的命令
|
||||
for cmd in self.config.remove_commands:
|
||||
if cmd not in self._regex_cache:
|
||||
self._regex_cache[cmd] = re.compile(
|
||||
fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*'
|
||||
)
|
||||
text = self._regex_cache[cmd].sub('', text)
|
||||
|
||||
# 处理带内容的命令
|
||||
def handle_command(match: re.Match) -> str:
|
||||
cmd = match.group(1).rstrip('*')
|
||||
if cmd in self.config.preserve_commands or cmd in self.config.citation_commands:
|
||||
return match.group(0) # 完整保留命令和内容
|
||||
return ' '
|
||||
|
||||
if 'command_pattern' not in self._regex_cache:
|
||||
self._regex_cache['command_pattern'] = re.compile(
|
||||
r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}'
|
||||
)
|
||||
|
||||
text = self._regex_cache['command_pattern'].sub(handle_command, text)
|
||||
return text
|
||||
|
||||
def _normalize_text(self, text: str) -> str:
|
||||
"""Normalize text while preserving special content markers."""
|
||||
# 替换特殊字符
|
||||
for char, replacement in self.config.latex_chars.items():
|
||||
text = text.replace(char, replacement)
|
||||
|
||||
# 清理空白字符,同时保留环境标记
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text)
|
||||
text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text)
|
||||
|
||||
# 保持块级环境之间的分隔
|
||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
def clean_text(self, text: str) -> str:
|
||||
"""Clean LaTeX text while preserving mathematical content, citations, and special environments."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# 移除注释
|
||||
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
|
||||
|
||||
# 处理环境
|
||||
text = self._clean_nested_environments(text)
|
||||
|
||||
# 清理命令并规范化
|
||||
text = self._clean_commands(text)
|
||||
text = self._normalize_text(text)
|
||||
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning text: {e}")
|
||||
return text # 发生错误时返回原始文本
|
||||
|
||||
|
||||
def clean_latex_commands(text: str) -> str:
|
||||
"""Convenience function for quick text cleaning with default config."""
|
||||
cleaner = LatexCleaner()
|
||||
return cleaner.clean_text(text)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
text = r"""
|
||||
\documentclass{article}
|
||||
\begin{document}
|
||||
|
||||
\section{Introduction}
|
||||
This is a reference to \cite{smith2020} and equation \eqref{eq:main}.
|
||||
|
||||
\begin{equation}\label{eq:main}
|
||||
E = mc^2 \times \sum_{i=1}^{n} x_i
|
||||
\end{equation}
|
||||
|
||||
See Figure \ref{fig:example} for details.
|
||||
|
||||
\begin{figure}
|
||||
\includegraphics{image.png}
|
||||
\caption{Example figure\label
|
||||
\textbf{Important} result: $E=mc^2$ and
|
||||
\begin{equation}
|
||||
F = ma
|
||||
\end{equation}
|
||||
\label{sec:intro}
|
||||
"""
|
||||
|
||||
# Custom configuration
|
||||
config = LatexConfig(
|
||||
preserve_envs={},
|
||||
preserve_commands={'textbf', 'emph'},
|
||||
latex_chars={'~': ' ', '\\&': '&'}
|
||||
)
|
||||
|
||||
|
||||
def read_tex_file(file_path):
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
return content
|
||||
except FileNotFoundError:
|
||||
return "文件未找到,请检查路径是否正确。"
|
||||
except Exception as e:
|
||||
return f"读取文件时发生错误: {e}"
|
||||
|
||||
|
||||
# 使用函数
|
||||
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||
content = read_tex_file(file_path)
|
||||
cleaner = LatexCleaner(config)
|
||||
text = cleaner.clean_text(text)
|
||||
print(text)
|
||||
@@ -0,0 +1,396 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LaTeXPatterns:
|
||||
"""LaTeX模式存储类,用于集中管理所有LaTeX相关的正则表达式模式"""
|
||||
special_envs = {
|
||||
'math': [
|
||||
# 基础数学环境
|
||||
r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
|
||||
r'\$\$.*?\$\$',
|
||||
r'\$[^$]+\$',
|
||||
# 矩阵环境
|
||||
r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
|
||||
# 数组环境
|
||||
r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
|
||||
# 其他数学环境
|
||||
r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'table': [
|
||||
# 基础表格环境
|
||||
r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
|
||||
# 复杂表格环境
|
||||
r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
|
||||
# 自定义表格环境
|
||||
r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
|
||||
# 表格注释环境
|
||||
r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'figure': [
|
||||
# 图片环境
|
||||
r'\\begin{figure\*?}.*?\\end{figure\*?}',
|
||||
r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
|
||||
# 图片插入命令
|
||||
r'\\includegraphics(\[.*?\])?\{.*?\}',
|
||||
# tikz 图形环境
|
||||
r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
|
||||
# 其他图形环境
|
||||
r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'algorithm': [
|
||||
# 算法环境
|
||||
r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
|
||||
# 代码块环境
|
||||
r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
|
||||
# 伪代码环境
|
||||
r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'list': [
|
||||
# 列表环境
|
||||
r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
|
||||
# 自定义列表环境
|
||||
r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'theorem': [
|
||||
# 定理类环境
|
||||
r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
|
||||
# 其他证明环境
|
||||
r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'box': [
|
||||
# 文本框环境
|
||||
r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
|
||||
# 强调环境
|
||||
r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'quote': [
|
||||
# 引用环境
|
||||
r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'bibliography': [
|
||||
# 参考文献环境
|
||||
r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
|
||||
],
|
||||
|
||||
'index': [
|
||||
# 索引环境
|
||||
r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
|
||||
r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
|
||||
]
|
||||
}
|
||||
# 章节模式
|
||||
section_patterns = [
|
||||
# 基础章节命令
|
||||
r'\\chapter\{([^}]+)\}',
|
||||
r'\\section\{([^}]+)\}',
|
||||
r'\\subsection\{([^}]+)\}',
|
||||
r'\\subsubsection\{([^}]+)\}',
|
||||
r'\\paragraph\{([^}]+)\}',
|
||||
r'\\subparagraph\{([^}]+)\}',
|
||||
|
||||
# 带星号的变体(不编号)
|
||||
r'\\chapter\*\{([^}]+)\}',
|
||||
r'\\section\*\{([^}]+)\}',
|
||||
r'\\subsection\*\{([^}]+)\}',
|
||||
r'\\subsubsection\*\{([^}]+)\}',
|
||||
r'\\paragraph\*\{([^}]+)\}',
|
||||
r'\\subparagraph\*\{([^}]+)\}',
|
||||
|
||||
# 特殊章节
|
||||
r'\\part\{([^}]+)\}',
|
||||
r'\\part\*\{([^}]+)\}',
|
||||
r'\\appendix\{([^}]+)\}',
|
||||
|
||||
# 前言部分
|
||||
r'\\frontmatter\{([^}]+)\}',
|
||||
r'\\mainmatter\{([^}]+)\}',
|
||||
r'\\backmatter\{([^}]+)\}',
|
||||
|
||||
# 目录相关
|
||||
r'\\tableofcontents',
|
||||
r'\\listoffigures',
|
||||
r'\\listoftables',
|
||||
|
||||
# 自定义章节命令
|
||||
r'\\addchap\{([^}]+)\}', # KOMA-Script类
|
||||
r'\\addsec\{([^}]+)\}', # KOMA-Script类
|
||||
r'\\minisec\{([^}]+)\}', # KOMA-Script类
|
||||
|
||||
# 带可选参数的章节命令
|
||||
r'\\chapter\[([^]]+)\]\{([^}]+)\}',
|
||||
r'\\section\[([^]]+)\]\{([^}]+)\}',
|
||||
r'\\subsection\[([^]]+)\]\{([^}]+)\}'
|
||||
]
|
||||
|
||||
# 包含模式
|
||||
include_patterns = [
|
||||
r'\\(input|include|subfile)\{([^}]+)\}'
|
||||
]
|
||||
|
||||
metadata_patterns = {
|
||||
# 标题相关
|
||||
'title': [
|
||||
r'\\title\{([^}]+)\}',
|
||||
r'\\Title\{([^}]+)\}',
|
||||
r'\\doctitle\{([^}]+)\}',
|
||||
r'\\subtitle\{([^}]+)\}',
|
||||
r'\\chapter\*?\{([^}]+)\}', # 第一章可能作为标题
|
||||
r'\\maketitle\s*\\section\*?\{([^}]+)\}' # 第一节可能作为标题
|
||||
],
|
||||
|
||||
# 摘要相关
|
||||
'abstract': [
|
||||
r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||
r'\\abstract\{([^}]+)\}',
|
||||
r'\\begin{摘要}(.*?)\\end{摘要}',
|
||||
r'\\begin{Summary}(.*?)\\end{Summary}',
|
||||
r'\\begin{synopsis}(.*?)\\end{synopsis}',
|
||||
r'\\begin{abstracten}(.*?)\\end{abstracten}' # 英文摘要
|
||||
],
|
||||
|
||||
# 作者信息
|
||||
'author': [
|
||||
r'\\author\{([^}]+)\}',
|
||||
r'\\Author\{([^}]+)\}',
|
||||
r'\\authorinfo\{([^}]+)\}',
|
||||
r'\\authors\{([^}]+)\}',
|
||||
r'\\author\[([^]]+)\]\{([^}]+)\}', # 带附加信息的作者
|
||||
r'\\begin{authors}(.*?)\\end{authors}'
|
||||
],
|
||||
|
||||
# 日期相关
|
||||
'date': [
|
||||
r'\\date\{([^}]+)\}',
|
||||
r'\\Date\{([^}]+)\}',
|
||||
r'\\submitdate\{([^}]+)\}',
|
||||
r'\\publishdate\{([^}]+)\}',
|
||||
r'\\revisiondate\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 关键词
|
||||
'keywords': [
|
||||
r'\\keywords\{([^}]+)\}',
|
||||
r'\\Keywords\{([^}]+)\}',
|
||||
r'\\begin{keywords}(.*?)\\end{keywords}',
|
||||
r'\\key\{([^}]+)\}',
|
||||
r'\\begin{关键词}(.*?)\\end{关键词}'
|
||||
],
|
||||
|
||||
# 机构/单位
|
||||
'institution': [
|
||||
r'\\institute\{([^}]+)\}',
|
||||
r'\\institution\{([^}]+)\}',
|
||||
r'\\affiliation\{([^}]+)\}',
|
||||
r'\\organization\{([^}]+)\}',
|
||||
r'\\department\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 学科/主题
|
||||
'subject': [
|
||||
r'\\subject\{([^}]+)\}',
|
||||
r'\\Subject\{([^}]+)\}',
|
||||
r'\\field\{([^}]+)\}',
|
||||
r'\\discipline\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 版本信息
|
||||
'version': [
|
||||
r'\\version\{([^}]+)\}',
|
||||
r'\\revision\{([^}]+)\}',
|
||||
r'\\release\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 许可证/版权
|
||||
'license': [
|
||||
r'\\license\{([^}]+)\}',
|
||||
r'\\copyright\{([^}]+)\}',
|
||||
r'\\begin{license}(.*?)\\end{license}'
|
||||
],
|
||||
|
||||
# 联系方式
|
||||
'contact': [
|
||||
r'\\email\{([^}]+)\}',
|
||||
r'\\phone\{([^}]+)\}',
|
||||
r'\\address\{([^}]+)\}',
|
||||
r'\\contact\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 致谢
|
||||
'acknowledgments': [
|
||||
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
|
||||
r'\\acknowledgments\{([^}]+)\}',
|
||||
r'\\thanks\{([^}]+)\}',
|
||||
r'\\begin{致谢}(.*?)\\end{致谢}'
|
||||
],
|
||||
|
||||
# 项目/基金
|
||||
'funding': [
|
||||
r'\\funding\{([^}]+)\}',
|
||||
r'\\grant\{([^}]+)\}',
|
||||
r'\\project\{([^}]+)\}',
|
||||
r'\\support\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 分类号/编号
|
||||
'classification': [
|
||||
r'\\classification\{([^}]+)\}',
|
||||
r'\\serialnumber\{([^}]+)\}',
|
||||
r'\\id\{([^}]+)\}',
|
||||
r'\\doi\{([^}]+)\}'
|
||||
],
|
||||
|
||||
# 语言
|
||||
'language': [
|
||||
r'\\documentlanguage\{([^}]+)\}',
|
||||
r'\\lang\{([^}]+)\}',
|
||||
r'\\language\{([^}]+)\}'
|
||||
]
|
||||
}
|
||||
latex_only_patterns = {
|
||||
# 文档类和包引入
|
||||
r'\\documentclass(\[.*?\])?\{.*?\}',
|
||||
r'\\usepackage(\[.*?\])?\{.*?\}',
|
||||
# 常见的文档设置命令
|
||||
r'\\setlength\{.*?\}\{.*?\}',
|
||||
r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
|
||||
r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
|
||||
r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
|
||||
# 页面设置相关
|
||||
r'\\pagestyle\{.*?\}',
|
||||
r'\\thispagestyle\{.*?\}',
|
||||
# 其他常见的设置命令
|
||||
r'\\bibliographystyle\{.*?\}',
|
||||
r'\\bibliography\{.*?\}',
|
||||
r'\\setcounter\{.*?\}\{.*?\}',
|
||||
# 字体和文本设置命令
|
||||
r'\\makeFNbottom',
|
||||
r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}', # 匹配字体大小设置
|
||||
r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
|
||||
r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
|
||||
r'\\renewcommand\\footnoterule\{.*?\}',
|
||||
r'\\color\{.*?\}',
|
||||
|
||||
# 页面和节标题设置
|
||||
r'\\setcounter\{secnumdepth\}\{.*?\}',
|
||||
r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
|
||||
r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
|
||||
r'\\renewcommand\{?\\figurename\}?\{.*?\}',
|
||||
|
||||
# 字体样式设置
|
||||
r'\\sectionfont\{.*?\}',
|
||||
r'\\subsectionfont\{.*?\}',
|
||||
r'\\subsubsectionfont\{.*?\}',
|
||||
|
||||
# 间距和布局设置
|
||||
r'\\setstretch\{.*?\}',
|
||||
r'\\setlength\{\\skip\\footins\}\{.*?\}',
|
||||
r'\\setlength\{\\footnotesep\}\{.*?\}',
|
||||
r'\\setlength\{\\jot\}\{.*?\}',
|
||||
r'\\hrule\s+width\s+.*?\s+height\s+.*?',
|
||||
|
||||
# makeatletter 和 makeatother
|
||||
r'\\makeatletter\s*',
|
||||
r'\\makeatother\s*',
|
||||
r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}', # 带有上标的脚注
|
||||
# r'\\footnotetext\{[^}]*\}', # 普通脚注
|
||||
# r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}', # 带有邮箱的脚注
|
||||
# r'\\footnotetext\{.*?(?:ESI|DOI).*?\}', # 带有 DOI 或 ESI 引用的脚注
|
||||
# 文档结构命令
|
||||
r'\\begin\{document\}',
|
||||
r'\\end\{document\}',
|
||||
r'\\maketitle',
|
||||
r'\\printbibliography',
|
||||
r'\\newpage',
|
||||
|
||||
# 输入文件命令
|
||||
r'\\input\{[^}]*\}',
|
||||
r'\\input\{.*?\.tex\}', # 特别匹配 .tex 后缀的输入
|
||||
|
||||
# 脚注相关
|
||||
# r'\\footnotetext\[\d+\]\{[^}]*\}', # 带编号的脚注
|
||||
|
||||
# 致谢环境
|
||||
r'\\begin\{ack\}',
|
||||
r'\\end\{ack\}',
|
||||
r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}', # 匹配整个致谢环境及其内容
|
||||
|
||||
# 其他文档控制命令
|
||||
r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
|
||||
}
|
||||
math_envs = [
|
||||
# 基础数学环境
|
||||
(r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'), # 单行公式
|
||||
(r'\\begin{align\*?}.*?\\end{align\*?}', 'align'), # 多行对齐公式
|
||||
(r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'), # 多行居中公式
|
||||
(r'\$\$.*?\$\$', 'display'), # 行间公式
|
||||
(r'\$.*?\$', 'inline'), # 行内公式
|
||||
|
||||
# 矩阵环境
|
||||
(r'\\begin{matrix}.*?\\end{matrix}', 'matrix'), # 基础矩阵
|
||||
(r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'), # 圆括号矩阵
|
||||
(r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'), # 方括号矩阵
|
||||
(r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'), # 竖线矩阵
|
||||
(r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'), # 双竖线矩阵
|
||||
(r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'), # 小号矩阵
|
||||
|
||||
# 数组环境
|
||||
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
|
||||
(r'\\begin{cases}.*?\\end{cases}', 'cases'), # 分段函数
|
||||
|
||||
# 多行公式环境
|
||||
(r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'), # 多行单个公式
|
||||
(r'\\begin{split}.*?\\end{split}', 'split'), # 拆分长公式
|
||||
(r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'), # 对齐环境带间距控制
|
||||
(r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'), # 完全左对齐
|
||||
|
||||
# 特殊数学环境
|
||||
(r'\\begin{subequations}.*?\\end{subequations}', 'subequations'), # 子公式编号
|
||||
(r'\\begin{gathered}.*?\\end{gathered}', 'gathered'), # 居中对齐组
|
||||
(r'\\begin{aligned}.*?\\end{aligned}', 'aligned'), # 内部对齐组
|
||||
|
||||
# 定理类环境
|
||||
(r'\\begin{theorem}.*?\\end{theorem}', 'theorem'), # 定理
|
||||
(r'\\begin{lemma}.*?\\end{lemma}', 'lemma'), # 引理
|
||||
(r'\\begin{proof}.*?\\end{proof}', 'proof'), # 证明
|
||||
|
||||
# 数学模式中的表格环境
|
||||
(r'\\begin{tabular}.*?\\end{tabular}', 'tabular'), # 表格
|
||||
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
|
||||
|
||||
# 其他专业数学环境
|
||||
(r'\\begin{CD}.*?\\end{CD}', 'CD'), # 交换图
|
||||
(r'\\begin{boxed}.*?\\end{boxed}', 'boxed'), # 带框公式
|
||||
(r'\\begin{empheq}.*?\\end{empheq}', 'empheq'), # 强调公式
|
||||
|
||||
# 化学方程式环境 (需要加载 mhchem 包)
|
||||
(r'\\begin{reaction}.*?\\end{reaction}', 'reaction'), # 化学反应式
|
||||
(r'\\ce\{.*?\}', 'chemequation'), # 化学方程式
|
||||
|
||||
# 物理单位环境 (需要加载 siunitx 包)
|
||||
(r'\\SI\{.*?\}\{.*?\}', 'SI'), # 物理单位
|
||||
(r'\\si\{.*?\}', 'si'), # 单位
|
||||
|
||||
# 补充环境
|
||||
(r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'), # breqn包的自动换行公式
|
||||
(r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'), # breqn包的显示数学模式
|
||||
(r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'), # breqn包的公式组
|
||||
]
|
||||
|
||||
# 示例使用函数
|
||||
|
||||
# 使用示例
|
||||
@@ -0,0 +1,416 @@
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectionLevel(Enum):
|
||||
CHAPTER = 0
|
||||
SECTION = 1
|
||||
SUBSECTION = 2
|
||||
SUBSUBSECTION = 3
|
||||
PARAGRAPH = 4
|
||||
SUBPARAGRAPH = 5
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, SectionLevel):
|
||||
return NotImplemented
|
||||
return self.value < other.value
|
||||
|
||||
def __le__(self, other):
|
||||
if not isinstance(other, SectionLevel):
|
||||
return NotImplemented
|
||||
return self.value <= other.value
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, SectionLevel):
|
||||
return NotImplemented
|
||||
return self.value > other.value
|
||||
|
||||
def __ge__(self, other):
|
||||
if not isinstance(other, SectionLevel):
|
||||
return NotImplemented
|
||||
return self.value >= other.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Section:
|
||||
level: SectionLevel
|
||||
title: str
|
||||
content: str = ''
|
||||
bibliography: str = ''
|
||||
subsections: List['Section'] = field(default_factory=list)
|
||||
|
||||
def merge(self, other: 'Section') -> 'Section':
|
||||
"""Merge this section with another section."""
|
||||
if self.title != other.title or self.level != other.level:
|
||||
raise ValueError("Can only merge sections with same title and level")
|
||||
|
||||
merged = deepcopy(self)
|
||||
merged.content = self._merge_content(self.content, other.content)
|
||||
|
||||
# Create subsections lookup for efficient merging
|
||||
subsections_map = {s.title: s for s in merged.subsections}
|
||||
|
||||
for other_subsection in other.subsections:
|
||||
if other_subsection.title in subsections_map:
|
||||
# Merge existing subsection
|
||||
idx = next(i for i, s in enumerate(merged.subsections)
|
||||
if s.title == other_subsection.title)
|
||||
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
|
||||
else:
|
||||
# Add new subsection
|
||||
merged.subsections.append(deepcopy(other_subsection))
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _merge_content(content1: str, content2: str) -> str:
|
||||
"""Merge content strings intelligently."""
|
||||
if not content1:
|
||||
return content2
|
||||
if not content2:
|
||||
return content1
|
||||
# Combine non-empty contents with a separator
|
||||
return f"{content1}\n\n{content2}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LatexEnvironment:
|
||||
"""表示LaTeX环境的数据类"""
|
||||
name: str
|
||||
start: int
|
||||
end: int
|
||||
content: str
|
||||
raw: str
|
||||
|
||||
|
||||
class EnhancedSectionExtractor:
|
||||
"""Enhanced section extractor with comprehensive content handling and hierarchy management."""
|
||||
|
||||
def __init__(self, preserve_environments: bool = True):
|
||||
"""
|
||||
初始化Section提取器
|
||||
|
||||
Args:
|
||||
preserve_environments: 是否保留特定环境(如equation, figure等)的原始LaTeX代码
|
||||
"""
|
||||
self.preserve_environments = preserve_environments
|
||||
|
||||
# Section级别定义
|
||||
self.section_levels = {
|
||||
'chapter': SectionLevel.CHAPTER,
|
||||
'section': SectionLevel.SECTION,
|
||||
'subsection': SectionLevel.SUBSECTION,
|
||||
'subsubsection': SectionLevel.SUBSUBSECTION,
|
||||
'paragraph': SectionLevel.PARAGRAPH,
|
||||
'subparagraph': SectionLevel.SUBPARAGRAPH
|
||||
}
|
||||
|
||||
# 需要保留的环境类型
|
||||
self.important_environments = {
|
||||
'equation', 'equation*', 'align', 'align*',
|
||||
'figure', 'table', 'algorithm', 'algorithmic',
|
||||
'definition', 'theorem', 'lemma', 'proof',
|
||||
'itemize', 'enumerate', 'description'
|
||||
}
|
||||
|
||||
# 改进的section pattern
|
||||
self.section_pattern = (
|
||||
r'\\(?P<type>chapter|section|subsection|subsubsection|paragraph|subparagraph)'
|
||||
r'\*?' # Optional star
|
||||
r'(?:\[(?P<short>.*?)\])?' # Optional short title
|
||||
r'{(?P<title>(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support
|
||||
)
|
||||
|
||||
# 环境匹配模式
|
||||
self.environment_pattern = (
|
||||
r'\\begin{(?P<env_name>[^}]+)}'
|
||||
r'(?P<env_content>.*?)'
|
||||
r'\\end{(?P=env_name)}'
|
||||
)
|
||||
|
||||
def _find_environments(self, content: str) -> List[LatexEnvironment]:
|
||||
"""
|
||||
查找文档中的所有LaTeX环境。
|
||||
支持嵌套环境的处理。
|
||||
"""
|
||||
environments = []
|
||||
stack = []
|
||||
|
||||
# 使用正则表达式查找所有begin和end标记
|
||||
begin_pattern = r'\\begin{([^}]+)}'
|
||||
end_pattern = r'\\end{([^}]+)}'
|
||||
|
||||
# 组合模式来同时匹配begin和end
|
||||
tokens = []
|
||||
for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content):
|
||||
if match.group(1): # begin标记
|
||||
tokens.append(('begin', match.group(1), match.start()))
|
||||
else: # end标记
|
||||
tokens.append(('end', match.group(2), match.start()))
|
||||
|
||||
# 处理环境嵌套
|
||||
for token_type, env_name, pos in tokens:
|
||||
if token_type == 'begin':
|
||||
stack.append((env_name, pos))
|
||||
elif token_type == 'end' and stack:
|
||||
if stack[-1][0] == env_name:
|
||||
start_env_name, start_pos = stack.pop()
|
||||
env_content = content[start_pos:pos]
|
||||
raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')]
|
||||
|
||||
if start_env_name in self.important_environments:
|
||||
environments.append(LatexEnvironment(
|
||||
name=start_env_name,
|
||||
start=start_pos,
|
||||
end=pos + len('\\end{' + env_name + '}'),
|
||||
content=env_content,
|
||||
raw=raw_content
|
||||
))
|
||||
|
||||
return sorted(environments, key=lambda x: x.start)
|
||||
|
||||
def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]:
|
||||
"""
|
||||
保护重要的LaTeX环境,用占位符替换它们。
|
||||
返回处理后的内容和恢复映射。
|
||||
"""
|
||||
environments = self._find_environments(content)
|
||||
replacements = {}
|
||||
|
||||
# 从后向前替换,避免位置改变的问题
|
||||
for env in reversed(environments):
|
||||
if env.name in self.important_environments:
|
||||
placeholder = f'__ENV_{len(replacements)}__'
|
||||
replacements[placeholder] = env.raw
|
||||
content = content[:env.start] + placeholder + content[env.end:]
|
||||
|
||||
return content, replacements
|
||||
|
||||
def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str:
|
||||
"""
|
||||
恢复之前保护的环境。
|
||||
"""
|
||||
for placeholder, original in replacements.items():
|
||||
content = content.replace(placeholder, original)
|
||||
return content
|
||||
|
||||
def extract(self, content: str) -> List[Section]:
|
||||
"""
|
||||
从LaTeX文档中提取sections及其内容。
|
||||
|
||||
Args:
|
||||
content: LaTeX文档内容
|
||||
|
||||
Returns:
|
||||
List[Section]: 提取的section列表,包含层次结构
|
||||
"""
|
||||
try:
|
||||
# 预处理:保护重要环境
|
||||
if self.preserve_environments:
|
||||
content, env_replacements = self._protect_environments(content)
|
||||
|
||||
# 查找所有sections
|
||||
sections = self._find_all_sections(content)
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
# 处理sections
|
||||
root_sections = self._process_sections(content, sections)
|
||||
|
||||
# 如果需要,恢复环境
|
||||
if self.preserve_environments:
|
||||
for section in self._traverse_sections(root_sections):
|
||||
section.content = self._restore_environments(section.content, env_replacements)
|
||||
|
||||
return root_sections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting sections: {str(e)}")
|
||||
raise
|
||||
|
||||
def _find_all_sections(self, content: str) -> List[dict]:
|
||||
"""查找所有section命令及其位置。"""
|
||||
sections = []
|
||||
|
||||
for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE):
|
||||
section_type = match.group('type').lower()
|
||||
if section_type not in self.section_levels:
|
||||
continue
|
||||
|
||||
section = {
|
||||
'type': section_type,
|
||||
'level': self.section_levels[section_type],
|
||||
'title': self._clean_title(match.group('title')),
|
||||
'start': match.start(),
|
||||
'command_end': match.end(),
|
||||
}
|
||||
sections.append(section)
|
||||
|
||||
return sorted(sections, key=lambda x: x['start'])
|
||||
|
||||
def _process_sections(self, content: str, sections: List[dict]) -> List[Section]:
|
||||
"""处理sections以构建层次结构和提取内容。"""
|
||||
# 计算content范围
|
||||
self._calculate_content_ranges(content, sections)
|
||||
|
||||
# 构建层次结构
|
||||
root_sections = []
|
||||
section_stack = []
|
||||
|
||||
for section_info in sections:
|
||||
new_section = Section(
|
||||
level=section_info['level'],
|
||||
title=section_info['title'],
|
||||
content=self._extract_clean_content(content, section_info),
|
||||
subsections=[]
|
||||
)
|
||||
|
||||
# 调整堆栈以找到正确的父section
|
||||
while section_stack and section_stack[-1].level.value >= new_section.level.value:
|
||||
section_stack.pop()
|
||||
|
||||
if section_stack:
|
||||
section_stack[-1].subsections.append(new_section)
|
||||
else:
|
||||
root_sections.append(new_section)
|
||||
|
||||
section_stack.append(new_section)
|
||||
|
||||
return root_sections
|
||||
|
||||
def _calculate_content_ranges(self, content: str, sections: List[dict]):
|
||||
for i, current in enumerate(sections):
|
||||
content_start = current['command_end']
|
||||
|
||||
# 找到下一个section(无论什么级别)
|
||||
content_end = len(content)
|
||||
for next_section in sections[i + 1:]:
|
||||
content_end = next_section['start']
|
||||
break
|
||||
|
||||
current['content_range'] = (content_start, content_end)
|
||||
|
||||
def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]):
|
||||
"""为每个section计算内容范围。"""
|
||||
for i, current in enumerate(sections):
|
||||
content_start = current['command_end']
|
||||
|
||||
# 找到下一个同级或更高级的section
|
||||
content_end = len(content)
|
||||
for next_section in sections[i + 1:]:
|
||||
if next_section['level'] <= current['level']:
|
||||
content_end = next_section['start']
|
||||
break
|
||||
|
||||
current['content_range'] = (content_start, content_end)
|
||||
|
||||
def _extract_clean_content(self, content: str, section_info: dict) -> str:
|
||||
"""提取并清理section内容。"""
|
||||
start, end = section_info['content_range']
|
||||
raw_content = content[start:end]
|
||||
|
||||
# 清理内容
|
||||
clean_content = self._clean_content(raw_content)
|
||||
return clean_content
|
||||
|
||||
def _clean_content(self, content: str) -> str:
|
||||
"""清理LaTeX内容同时保留重要信息。"""
|
||||
# 移除注释
|
||||
content = re.sub(r'(?<!\\)%.*?\n', '\n', content)
|
||||
|
||||
# LaTeX命令处理规则
|
||||
replacements = [
|
||||
# 保留引用
|
||||
(r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'),
|
||||
# 保留脚注
|
||||
(r'\\footnote{(.*?)}', r'[footnote:\1]'),
|
||||
# 处理引用
|
||||
(r'\\ref{(.*?)}', r'[ref:\1]'),
|
||||
# 保留URL
|
||||
(r'\\url{(.*?)}', r'[url:\1]'),
|
||||
# 保留超链接
|
||||
(r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'),
|
||||
# 处理文本格式命令
|
||||
(r'\\(?:textbf|textit|emph){(.*?)}', r'\1'),
|
||||
# 保留特殊字符
|
||||
(r'\\([&%$#_{}])', r'\1'),
|
||||
]
|
||||
|
||||
# 应用所有替换规则
|
||||
for pattern, replacement in replacements:
|
||||
content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
# 清理多余的空白
|
||||
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||
return content.strip()
|
||||
|
||||
def _clean_title(self, title: str) -> str:
|
||||
"""清理section标题。"""
|
||||
# 处理嵌套的花括号
|
||||
while '{' in title:
|
||||
title = re.sub(r'{([^{}]*)}', r'\1', title)
|
||||
|
||||
# 处理LaTeX命令
|
||||
title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title)
|
||||
title = re.sub(r'\\([&%$#_{}])', r'\1', title)
|
||||
|
||||
return title.strip()
|
||||
|
||||
def _traverse_sections(self, sections: List[Section]) -> List[Section]:
|
||||
"""遍历所有sections(包括子sections)。"""
|
||||
result = []
|
||||
for section in sections:
|
||||
result.append(section)
|
||||
result.extend(self._traverse_sections(section.subsections))
|
||||
return result
|
||||
|
||||
|
||||
def test_enhanced_extractor():
|
||||
"""使用复杂的测试用例测试提取器。"""
|
||||
test_content = r"""
|
||||
\section{Complex Examples}
|
||||
Here's a complex section with various environments.
|
||||
|
||||
\begin{equation}
|
||||
E = mc^2
|
||||
\end{equation}
|
||||
|
||||
\subsection{Nested Environments}
|
||||
This subsection has nested environments.
|
||||
|
||||
\begin{figure}
|
||||
\begin{equation*}
|
||||
f(x) = \int_0^x g(t) dt
|
||||
\end{equation*}
|
||||
\caption{A nested equation in a figure}
|
||||
\end{figure}
|
||||
|
||||
"""
|
||||
|
||||
extractor = EnhancedSectionExtractor()
|
||||
sections = extractor.extract(test_content)
|
||||
|
||||
def print_section(section, level=0):
|
||||
print("\n" + " " * level + f"[{section.level.name}] {section.title}")
|
||||
if section.content:
|
||||
content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content
|
||||
print(" " * (level + 1) + f"Content: {content_preview}")
|
||||
for subsection in section.subsections:
|
||||
print_section(subsection, level + 1)
|
||||
|
||||
print("\nExtracted Section Structure:")
|
||||
for section in sections:
|
||||
print_section(section)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_enhanced_extractor()
|
||||
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectionFragment:
|
||||
"""Arxiv论文片段数据类"""
|
||||
title: str # 论文标题
|
||||
authors: str
|
||||
abstract: str # 论文摘要
|
||||
catalogs: str # 文章各章节的目录结构
|
||||
arxiv_id: str = "" # 添加 arxiv_id 属性
|
||||
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
|
||||
content: str = '' # 当前片段的内容
|
||||
bibliography: str = '' # 当前片段的参考文献
|
||||
@@ -0,0 +1,266 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Set, Optional
|
||||
|
||||
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
|
||||
|
||||
|
||||
class TexUtils:
|
||||
"""TeX文档处理器类"""
|
||||
|
||||
def __init__(self, ):
|
||||
"""
|
||||
初始化TeX处理器
|
||||
|
||||
Args:
|
||||
char_range: 字符数范围(最小值, 最大值)
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化LaTeX环境和命令模式
|
||||
self._init_patterns()
|
||||
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
|
||||
|
||||
def _init_patterns(self):
|
||||
"""初始化LaTeX模式匹配规则"""
|
||||
# 特殊环境模式
|
||||
self.special_envs = LaTeXPatterns.special_envs
|
||||
# 章节模式
|
||||
self.section_patterns = LaTeXPatterns.section_patterns
|
||||
# 包含模式
|
||||
self.include_patterns = LaTeXPatterns.include_patterns
|
||||
# 元数据模式
|
||||
self.metadata_patterns = LaTeXPatterns.metadata_patterns
|
||||
|
||||
def read_file(self, file_path: str) -> Optional[str]:
|
||||
"""
|
||||
读取TeX文件内容,支持多种编码
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: 文件内容或None
|
||||
"""
|
||||
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
self.logger.warning(f"Failed to read {file_path} with all encodings")
|
||||
return None
|
||||
|
||||
def find_main_tex_file(self, directory: str) -> Optional[str]:
|
||||
"""
|
||||
查找主TeX文件
|
||||
|
||||
Args:
|
||||
directory: 目录路径
|
||||
|
||||
Returns:
|
||||
Optional[str]: 主文件路径或None
|
||||
"""
|
||||
tex_files = list(Path(directory).rglob("*.tex"))
|
||||
if not tex_files:
|
||||
return None
|
||||
|
||||
# 按优先级查找
|
||||
for tex_file in tex_files:
|
||||
content = self.read_file(str(tex_file))
|
||||
if content:
|
||||
if r'\documentclass' in content:
|
||||
return str(tex_file)
|
||||
if tex_file.name.lower() == 'main.tex':
|
||||
return str(tex_file)
|
||||
|
||||
# 返回最大的tex文件
|
||||
return str(max(tex_files, key=lambda x: x.stat().st_size))
|
||||
|
||||
def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
|
||||
"""
|
||||
解析TeX文件中的include引用
|
||||
|
||||
Args:
|
||||
tex_file: TeX文件路径
|
||||
processed: 已处理的文件集合
|
||||
|
||||
Returns:
|
||||
List[str]: 相关文件路径列表
|
||||
"""
|
||||
if processed is None:
|
||||
processed = set()
|
||||
|
||||
if tex_file in processed:
|
||||
return []
|
||||
|
||||
processed.add(tex_file)
|
||||
result = [tex_file]
|
||||
content = self.read_file(tex_file)
|
||||
|
||||
if not content:
|
||||
return result
|
||||
|
||||
base_dir = Path(tex_file).parent
|
||||
for pattern in self.include_patterns:
|
||||
for match in re.finditer(pattern, content):
|
||||
included_file = match.group(2)
|
||||
if not included_file.endswith('.tex'):
|
||||
included_file += '.tex'
|
||||
|
||||
full_path = str(base_dir / included_file)
|
||||
if os.path.exists(full_path) and full_path not in processed:
|
||||
result.extend(self.resolve_includes(full_path, processed))
|
||||
|
||||
return result
|
||||
|
||||
def resolve_references(self, tex_file: str, path_dir: str = None) -> str:
|
||||
"""
|
||||
解析TeX文件中的参考文献引用,返回所有引用文献的内容,只保留title、author和journal字段。
|
||||
如果在tex_file目录下没找到bib文件,会在path_dir中查找。
|
||||
|
||||
Args:
|
||||
tex_file: TeX文件路径
|
||||
path_dir: 额外的参考文献搜索路径
|
||||
|
||||
Returns:
|
||||
str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔
|
||||
"""
|
||||
all_references = [] # 存储所有参考文献内容
|
||||
content = self.read_file(tex_file)
|
||||
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 扩展参考文献引用的模式
|
||||
bib_patterns = [
|
||||
r'\\bibliography\{([^}]+)\}',
|
||||
r'\\addbibresource\{([^}]+)\}',
|
||||
r'\\bibliographyfile\{([^}]+)\}',
|
||||
r'\\begin\{thebibliography\}',
|
||||
r'\\bibinput\{([^}]+)\}',
|
||||
r'\\newrefsection\{([^}]+)\}'
|
||||
]
|
||||
|
||||
base_dir = Path(tex_file).parent
|
||||
found_in_tex_dir = False
|
||||
|
||||
# 首先在tex文件目录下查找显式引用的bib文件
|
||||
for pattern in bib_patterns:
|
||||
for match in re.finditer(pattern, content):
|
||||
if not match.groups():
|
||||
continue
|
||||
|
||||
bib_files = match.group(1).split(',')
|
||||
for bib_file in bib_files:
|
||||
bib_file = bib_file.strip()
|
||||
if not bib_file.endswith('.bib'):
|
||||
bib_file += '.bib'
|
||||
|
||||
full_path = str(base_dir / bib_file)
|
||||
if os.path.exists(full_path):
|
||||
found_in_tex_dir = True
|
||||
bib_content = self.read_file(full_path)
|
||||
if bib_content:
|
||||
processed_refs = self._process_bib_content(bib_content)
|
||||
all_references.extend(processed_refs)
|
||||
|
||||
# 如果在tex文件目录下没找到bib文件,且提供了额外搜索路径
|
||||
if not found_in_tex_dir and path_dir:
|
||||
search_dir = Path(path_dir)
|
||||
try:
|
||||
for bib_path in search_dir.glob('**/*.bib'):
|
||||
bib_content = self.read_file(str(bib_path))
|
||||
if bib_content:
|
||||
processed_refs = self._process_bib_content(bib_content)
|
||||
all_references.extend(processed_refs)
|
||||
except Exception as e:
|
||||
print(f"Error searching in path_dir: {e}")
|
||||
|
||||
# 合并所有参考文献内容,用空行分隔
|
||||
return "\n\n".join(all_references)
|
||||
|
||||
def _process_bib_content(self, content: str) -> List[str]:
|
||||
"""
|
||||
处理bib文件内容,提取每个参考文献的特定字段
|
||||
|
||||
Args:
|
||||
content: bib文件内容
|
||||
|
||||
Returns:
|
||||
List[str]: 处理后的参考文献列表
|
||||
"""
|
||||
processed_refs = []
|
||||
# 匹配完整的参考文献条目
|
||||
ref_pattern = r'@\w+\{[^@]*\}'
|
||||
# 匹配参考文献类型和键值
|
||||
entry_start_pattern = r'@(\w+)\{([^,]*?),'
|
||||
# 匹配字段
|
||||
field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}'
|
||||
|
||||
# 查找所有参考文献条目
|
||||
for ref_match in re.finditer(ref_pattern, content, re.DOTALL):
|
||||
ref_content = ref_match.group(0)
|
||||
|
||||
# 获取参考文献类型和键值
|
||||
entry_match = re.match(entry_start_pattern, ref_content)
|
||||
if not entry_match:
|
||||
continue
|
||||
|
||||
entry_type, cite_key = entry_match.groups()
|
||||
|
||||
# 提取需要的字段
|
||||
needed_fields = {'title': None, 'author': None, 'journal': None}
|
||||
for field_match in re.finditer(field_pattern, ref_content):
|
||||
field_name, field_value = field_match.groups()
|
||||
field_name = field_name.lower()
|
||||
if field_name in needed_fields:
|
||||
needed_fields[field_name] = field_value.strip()
|
||||
|
||||
# 构建新的参考文献条目
|
||||
if any(needed_fields.values()): # 如果至少有一个需要的字段
|
||||
ref_lines = [f"@{entry_type}{{{cite_key},"]
|
||||
for field_name, field_value in needed_fields.items():
|
||||
if field_value:
|
||||
ref_lines.append(f" {field_name}={{{field_value}}},")
|
||||
ref_lines[-1] = ref_lines[-1][:-1] # 移除最后一个逗号
|
||||
ref_lines.append("}")
|
||||
|
||||
processed_refs.append("\n".join(ref_lines))
|
||||
|
||||
return processed_refs
|
||||
|
||||
def _extract_inline_references(self, content: str) -> str:
|
||||
"""
|
||||
从tex文件内容中提取直接写在文件中的参考文献
|
||||
|
||||
Args:
|
||||
content: tex文件内容
|
||||
|
||||
Returns:
|
||||
str: 提取的参考文献内容,如果没有找到则返回空字符串
|
||||
"""
|
||||
# 查找参考文献环境
|
||||
bib_start = r'\\begin\{thebibliography\}'
|
||||
bib_end = r'\\end\{thebibliography\}'
|
||||
|
||||
start_match = re.search(bib_start, content)
|
||||
end_match = re.search(bib_end, content)
|
||||
|
||||
if start_match and end_match:
|
||||
return content[start_match.start():end_match.end()]
|
||||
|
||||
return ""
|
||||
|
||||
def _preprocess_content(self, content: str) -> str:
|
||||
"""预处理TeX内容"""
|
||||
# 移除注释
|
||||
content = re.sub(r'(?m)%.*$', '', content)
|
||||
# 规范化空白字符
|
||||
# content = re.sub(r'\s+', ' ', content)
|
||||
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||
return content.strip()
|
||||
@@ -0,0 +1,165 @@
|
||||
import atexit
|
||||
import os
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.ingestion import run_transformations
|
||||
from llama_index.core.schema import TextNode
|
||||
from loguru import logger
|
||||
|
||||
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||
|
||||
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||
Now, you have context information as below:
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
|
||||
---------------------
|
||||
{query_str}
|
||||
"""
|
||||
|
||||
QUESTION_ANSWER_RECORD = """\
|
||||
{{
|
||||
"type": "This is a previous conversation with the user",
|
||||
"question": "{question}",
|
||||
"answer": "{answer}",
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class SaveLoad():
|
||||
|
||||
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||
import os, glob
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if not os.path.exists(checkpoint_dir): return False
|
||||
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
|
||||
return True
|
||||
|
||||
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||
logger.info(f'saving vector store to: {checkpoint_dir}')
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||
|
||||
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||
logger.info('loading checkpoint from disk')
|
||||
from llama_index.core import StorageContext, load_index_from_storage
|
||||
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
|
||||
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
|
||||
return self.vs_index
|
||||
else:
|
||||
return self.create_new_vs()
|
||||
|
||||
def create_new_vs(self):
|
||||
return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)
|
||||
|
||||
def purge(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
|
||||
self.vs_index = self.create_new_vs()
|
||||
|
||||
|
||||
class LlamaIndexRagWorker(SaveLoad):
|
||||
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||
self.debug_mode = True
|
||||
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
||||
self.user_name = user_name
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
|
||||
# 确保checkpoint_dir存在
|
||||
if checkpoint_dir:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
|
||||
|
||||
# 初始化向量存储
|
||||
if auto_load_checkpoint and self.does_checkpoint_exist():
|
||||
logger.info("Loading existing vector store from checkpoint")
|
||||
self.vs_index = self.load_from_checkpoint()
|
||||
else:
|
||||
logger.info("Creating new vector store")
|
||||
self.vs_index = self.create_new_vs()
|
||||
|
||||
# 注册退出时保存
|
||||
atexit.register(self.save_to_checkpoint)
|
||||
|
||||
def add_text_to_vector_store(self, text: str) -> None:
|
||||
"""添加文本到向量存储"""
|
||||
try:
|
||||
logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
|
||||
node = TextNode(text=text)
|
||||
nodes = run_transformations(
|
||||
[node],
|
||||
self.vs_index._transformations,
|
||||
show_progress=True
|
||||
)
|
||||
self.vs_index.insert_nodes(nodes)
|
||||
|
||||
# 立即保存
|
||||
self.save_to_checkpoint()
|
||||
|
||||
if self.debug_mode:
|
||||
self.inspect_vector_store()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding text to vector store: {str(e)}")
|
||||
raise
|
||||
|
||||
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||
"""保存向量存储到检查点"""
|
||||
try:
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = self.checkpoint_dir
|
||||
logger.info(f'Saving vector store to: {checkpoint_dir}')
|
||||
if checkpoint_dir:
|
||||
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||
logger.info('Vector store saved successfully')
|
||||
else:
|
||||
logger.warning('No checkpoint directory specified, skipping save')
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint: {str(e)}")
|
||||
raise
|
||||
|
||||
def assign_embedding_model(self):
|
||||
pass
|
||||
|
||||
def inspect_vector_store(self):
|
||||
# This function is for debugging
|
||||
self.vs_index.storage_context.index_store.to_dict()
|
||||
docstore = self.vs_index.storage_context.docstore.docs
|
||||
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
|
||||
logger.info('\n++ --------inspect_vector_store begin--------')
|
||||
logger.info(vector_store_preview)
|
||||
logger.info('oo --------inspect_vector_store end--------')
|
||||
return vector_store_preview
|
||||
|
||||
def add_documents_to_vector_store(self, document_list):
|
||||
documents = [Document(text=t) for t in document_list]
|
||||
documents_nodes = run_transformations(
|
||||
documents, # type: ignore
|
||||
self.vs_index._transformations,
|
||||
show_progress=True
|
||||
)
|
||||
self.vs_index.insert_nodes(documents_nodes)
|
||||
if self.debug_mode: self.inspect_vector_store()
|
||||
|
||||
def remember_qa(self, question, answer):
|
||||
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
|
||||
self.add_text_to_vector_store(formatted_str)
|
||||
|
||||
def retrieve_from_store_with_query(self, query):
|
||||
if self.debug_mode: self.inspect_vector_store()
|
||||
retriever = self.vs_index.as_retriever()
|
||||
return retriever.retrieve(query)
|
||||
|
||||
def build_prompt(self, query, nodes):
|
||||
context_str = self.generate_node_array_preview(nodes)
|
||||
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
||||
|
||||
def generate_node_array_preview(self, nodes):
|
||||
buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
||||
if self.debug_mode: logger.info(buf)
|
||||
return buf
|
||||
@@ -0,0 +1,104 @@
|
||||
import atexit
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
from loguru import logger
|
||||
|
||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||
|
||||
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||
Now, you have context information as below:
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
|
||||
---------------------
|
||||
{query_str}
|
||||
"""
|
||||
|
||||
QUESTION_ANSWER_RECORD = """\
|
||||
{{
|
||||
"type": "This is a previous conversation with the user",
|
||||
"question": "{question}",
|
||||
"answer": "{answer}",
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class MilvusSaveLoad():
|
||||
|
||||
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||
import os, glob
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if not os.path.exists(checkpoint_dir): return False
|
||||
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
|
||||
return True
|
||||
|
||||
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||
logger.info(f'saving vector store to: {checkpoint_dir}')
|
||||
# if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
# self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||
|
||||
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||
logger.info('loading checkpoint from disk')
|
||||
from llama_index.core import StorageContext, load_index_from_storage
|
||||
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
|
||||
try:
|
||||
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
|
||||
return self.vs_index
|
||||
except:
|
||||
return self.create_new_vs(checkpoint_dir)
|
||||
else:
|
||||
return self.create_new_vs(checkpoint_dir)
|
||||
|
||||
def create_new_vs(self, checkpoint_dir, overwrite=False):
|
||||
vector_store = MilvusVectorStore(
|
||||
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
|
||||
dim=self.embed_model.embedding_dimension(),
|
||||
overwrite=overwrite
|
||||
)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
|
||||
embed_model=self.embed_model)
|
||||
return index
|
||||
|
||||
def purge(self):
|
||||
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
|
||||
|
||||
|
||||
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
|
||||
|
||||
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||
self.debug_mode = True
|
||||
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
||||
self.user_name = user_name
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
if auto_load_checkpoint:
|
||||
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
||||
else:
|
||||
self.vs_index = self.create_new_vs(checkpoint_dir)
|
||||
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
||||
|
||||
def inspect_vector_store(self):
|
||||
# This function is for debugging
|
||||
try:
|
||||
self.vs_index.storage_context.index_store.to_dict()
|
||||
docstore = self.vs_index.storage_context.docstore.docs
|
||||
if not docstore.items():
|
||||
raise ValueError("cannot inspect")
|
||||
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
|
||||
except:
|
||||
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
|
||||
vector_store_preview = "\n".join(
|
||||
[f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
|
||||
)
|
||||
logger.info('\n++ --------inspect_vector_store begin--------')
|
||||
logger.info(vector_store_preview)
|
||||
logger.info('oo --------inspect_vector_store end--------')
|
||||
return vector_store_preview
|
||||
@@ -0,0 +1,47 @@
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
|
||||
'.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
|
||||
|
||||
|
||||
def read_docx_doc(file_path):
|
||||
if file_path.split(".")[-1] == "docx":
|
||||
from docx import Document
|
||||
doc = Document(file_path)
|
||||
file_content = "\n".join([para.text for para in doc.paragraphs])
|
||||
else:
|
||||
try:
|
||||
import win32com.client
|
||||
word = win32com.client.Dispatch("Word.Application")
|
||||
word.visible = False
|
||||
# 打开文件
|
||||
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
|
||||
# file_content = doc.Content.Text
|
||||
doc = word.ActiveDocument
|
||||
file_content = doc.Range().Text
|
||||
doc.Close()
|
||||
word.Quit()
|
||||
except:
|
||||
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
||||
return file_content
|
||||
|
||||
|
||||
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
|
||||
import os
|
||||
|
||||
|
||||
def extract_text(file_path):
|
||||
_, ext = os.path.splitext(file_path.lower())
|
||||
|
||||
# 使用 SimpleDirectoryReader 处理它支持的文件格式
|
||||
if ext in ['.docx', '.doc']:
|
||||
return read_docx_doc(file_path)
|
||||
try:
|
||||
reader = SimpleDirectoryReader(input_files=[file_path])
|
||||
documents = reader.load_data()
|
||||
if len(documents) > 0:
|
||||
return documents[0].text
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,56 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.schema import TransformComponent
|
||||
from llama_index.core.service_context import ServiceContext
|
||||
from llama_index.core.settings import (
|
||||
Settings,
|
||||
callback_manager_from_settings_or_context,
|
||||
transformations_from_settings_or_context,
|
||||
)
|
||||
from llama_index.core.storage.storage_context import StorageContext
|
||||
|
||||
|
||||
class GptacVectorStoreIndex(VectorStoreIndex):
|
||||
|
||||
@classmethod
|
||||
def default_vector_store(
|
||||
cls,
|
||||
storage_context: Optional[StorageContext] = None,
|
||||
show_progress: bool = False,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
transformations: Optional[List[TransformComponent]] = None,
|
||||
# deprecated
|
||||
service_context: Optional[ServiceContext] = None,
|
||||
embed_model=None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create index from documents.
|
||||
|
||||
Args:
|
||||
documents (Optional[Sequence[BaseDocument]]): List of documents to
|
||||
build the index from.
|
||||
|
||||
"""
|
||||
storage_context = storage_context or StorageContext.from_defaults()
|
||||
docstore = storage_context.docstore
|
||||
callback_manager = (
|
||||
callback_manager
|
||||
or callback_manager_from_settings_or_context(Settings, service_context)
|
||||
)
|
||||
transformations = transformations or transformations_from_settings_or_context(
|
||||
Settings, service_context
|
||||
)
|
||||
|
||||
with callback_manager.as_trace("index_construction"):
|
||||
return cls(
|
||||
nodes=[],
|
||||
storage_context=storage_context,
|
||||
callback_manager=callback_manager,
|
||||
show_progress=show_progress,
|
||||
transformations=transformations,
|
||||
service_context=service_context,
|
||||
embed_model=embed_model,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,16 +1,17 @@
|
||||
# From project chatglm-langchain
|
||||
|
||||
import threading
|
||||
from toolbox import Singleton
|
||||
import os
|
||||
import shutil
|
||||
import os
|
||||
import uuid
|
||||
import tqdm
|
||||
import shutil
|
||||
import threading
|
||||
import numpy as np
|
||||
from toolbox import Singleton
|
||||
from loguru import logger
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Tuple
|
||||
import numpy as np
|
||||
from crazy_functions.vector_fns.general_file_loader import load_file
|
||||
|
||||
embedding_model_dict = {
|
||||
@@ -150,17 +151,17 @@ class LocalDocQA:
|
||||
failed_files = []
|
||||
if isinstance(filepath, str):
|
||||
if not os.path.exists(filepath):
|
||||
print("路径不存在")
|
||||
logger.error("路径不存在")
|
||||
return None
|
||||
elif os.path.isfile(filepath):
|
||||
file = os.path.split(filepath)[-1]
|
||||
try:
|
||||
docs = load_file(filepath, SENTENCE_SIZE)
|
||||
print(f"{file} 已成功加载")
|
||||
logger.info(f"{file} 已成功加载")
|
||||
loaded_files.append(filepath)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(f"{file} 未能成功加载")
|
||||
logger.error(e)
|
||||
logger.error(f"{file} 未能成功加载")
|
||||
return None
|
||||
elif os.path.isdir(filepath):
|
||||
docs = []
|
||||
@@ -170,23 +171,23 @@ class LocalDocQA:
|
||||
docs += load_file(fullfilepath, SENTENCE_SIZE)
|
||||
loaded_files.append(fullfilepath)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
failed_files.append(file)
|
||||
|
||||
if len(failed_files) > 0:
|
||||
print("以下文件未能成功加载:")
|
||||
logger.error("以下文件未能成功加载:")
|
||||
for file in failed_files:
|
||||
print(f"{file}\n")
|
||||
logger.error(f"{file}\n")
|
||||
|
||||
else:
|
||||
docs = []
|
||||
for file in filepath:
|
||||
docs += load_file(file, SENTENCE_SIZE)
|
||||
print(f"{file} 已成功加载")
|
||||
logger.info(f"{file} 已成功加载")
|
||||
loaded_files.append(file)
|
||||
|
||||
if len(docs) > 0:
|
||||
print("文件加载完毕,正在生成向量库")
|
||||
logger.info("文件加载完毕,正在生成向量库")
|
||||
if vs_path and os.path.isdir(vs_path):
|
||||
try:
|
||||
self.vector_store = FAISS.load_local(vs_path, text2vec)
|
||||
@@ -233,7 +234,7 @@ class LocalDocQA:
|
||||
prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt = prompt.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
|
||||
# print(prompt)
|
||||
# logger.info(prompt)
|
||||
response = {"query": query, "source_documents": related_docs_with_score}
|
||||
return response, prompt
|
||||
|
||||
@@ -262,7 +263,7 @@ def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_co
|
||||
else:
|
||||
pass
|
||||
# file_status = "文件未成功加载,请重新上传文件"
|
||||
# print(file_status)
|
||||
# logger.info(file_status)
|
||||
return local_doc_qa, vs_path
|
||||
|
||||
@Singleton
|
||||
@@ -278,7 +279,7 @@ class knowledge_archive_interface():
|
||||
if self.text2vec_large_chinese is None:
|
||||
# < -------------------预热文本向量化模组--------------- >
|
||||
from toolbox import ProxyNetworkActivate
|
||||
print('Checking Text2vec ...')
|
||||
logger.info('Checking Text2vec ...')
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||
self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import re, requests, unicodedata, os
|
||||
from toolbox import update_ui, get_log_folder
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from toolbox import CatchException, report_exception, get_conf
|
||||
import re, requests, unicodedata, os
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from loguru import logger
|
||||
|
||||
def download_arxiv_(url_pdf):
|
||||
if 'arxiv.org' not in url_pdf:
|
||||
if ('.' in url_pdf) and ('/' not in url_pdf):
|
||||
new_url = 'https://arxiv.org/abs/'+url_pdf
|
||||
print('下载编号:', url_pdf, '自动定位:', new_url)
|
||||
logger.info('下载编号:', url_pdf, '自动定位:', new_url)
|
||||
# download_arxiv_(new_url)
|
||||
return download_arxiv_(new_url)
|
||||
else:
|
||||
print('不能识别的URL!')
|
||||
logger.info('不能识别的URL!')
|
||||
return None
|
||||
if 'abs' in url_pdf:
|
||||
url_pdf = url_pdf.replace('abs', 'pdf')
|
||||
@@ -42,15 +44,12 @@ def download_arxiv_(url_pdf):
|
||||
requests_pdf_url = url_pdf
|
||||
file_path = download_dir+title_str
|
||||
|
||||
print('下载中')
|
||||
logger.info('下载中')
|
||||
proxies = get_conf('proxies')
|
||||
r = requests.get(requests_pdf_url, proxies=proxies)
|
||||
with open(file_path, 'wb+') as f:
|
||||
f.write(r.content)
|
||||
print('下载完成')
|
||||
|
||||
# print('输出下载命令:','aria2c -o \"%s\" %s'%(title_str,url_pdf))
|
||||
# subprocess.call('aria2c --all-proxy=\"172.18.116.150:11084\" -o \"%s\" %s'%(download_dir+title_str,url_pdf), shell=True)
|
||||
logger.info('下载完成')
|
||||
|
||||
x = "%s %s %s.bib" % (paper_id, other_info['year'], other_info['authors'])
|
||||
x = x.replace('?', '?')\
|
||||
@@ -63,19 +62,9 @@ def download_arxiv_(url_pdf):
|
||||
|
||||
|
||||
def get_name(_url_):
|
||||
import os
|
||||
from bs4 import BeautifulSoup
|
||||
print('正在获取文献名!')
|
||||
print(_url_)
|
||||
|
||||
# arxiv_recall = {}
|
||||
# if os.path.exists('./arxiv_recall.pkl'):
|
||||
# with open('./arxiv_recall.pkl', 'rb') as f:
|
||||
# arxiv_recall = pickle.load(f)
|
||||
|
||||
# if _url_ in arxiv_recall:
|
||||
# print('在缓存中')
|
||||
# return arxiv_recall[_url_]
|
||||
logger.info('正在获取文献名!')
|
||||
logger.info(_url_)
|
||||
|
||||
proxies = get_conf('proxies')
|
||||
res = requests.get(_url_, proxies=proxies)
|
||||
@@ -92,7 +81,7 @@ def get_name(_url_):
|
||||
other_details['abstract'] = abstract
|
||||
except:
|
||||
other_details['year'] = ''
|
||||
print('年份获取失败')
|
||||
logger.info('年份获取失败')
|
||||
|
||||
# get author
|
||||
try:
|
||||
@@ -101,7 +90,7 @@ def get_name(_url_):
|
||||
other_details['authors'] = authors
|
||||
except:
|
||||
other_details['authors'] = ''
|
||||
print('authors获取失败')
|
||||
logger.info('authors获取失败')
|
||||
|
||||
# get comment
|
||||
try:
|
||||
@@ -116,11 +105,11 @@ def get_name(_url_):
|
||||
other_details['comment'] = ''
|
||||
except:
|
||||
other_details['comment'] = ''
|
||||
print('年份获取失败')
|
||||
logger.info('年份获取失败')
|
||||
|
||||
title_str = BeautifulSoup(
|
||||
res.text, 'html.parser').find('title').contents[0]
|
||||
print('获取成功:', title_str)
|
||||
logger.info('获取成功:', title_str)
|
||||
# arxiv_recall[_url_] = (title_str+'.pdf', other_details)
|
||||
# with open('./arxiv_recall.pkl', 'wb') as f:
|
||||
# pickle.dump(arxiv_recall, f)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
@CatchException
|
||||
def 交互功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
|
||||
@@ -16,8 +16,8 @@ Testing:
|
||||
|
||||
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, is_the_upload_folder
|
||||
from toolbox import promote_file_to_downloadzone, get_log_folder, update_ui_lastest_msg
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
||||
from .crazy_utils import input_clipping, try_install_deps
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
||||
from crazy_functions.crazy_utils import input_clipping, try_install_deps
|
||||
from crazy_functions.gen_fns.gen_fns_shared import is_function_successfully_generated
|
||||
from crazy_functions.gen_fns.gen_fns_shared import get_class_name
|
||||
from crazy_functions.gen_fns.gen_fns_shared import subprocess_worker
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from toolbox import CatchException, update_ui, gen_time_str
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
import copy, json
|
||||
|
||||
@CatchException
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
"""
|
||||
|
||||
|
||||
import time
|
||||
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, ProxyNetworkActivate
|
||||
from toolbox import get_conf, select_api_key, update_ui_lastest_msg, Singleton
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
||||
from crazy_functions.crazy_utils import input_clipping, try_install_deps
|
||||
from crazy_functions.agent_fns.persistent import GradioMultiuserManagerForPersistentClasses
|
||||
from crazy_functions.agent_fns.auto_agent import AutoGenMath
|
||||
import time
|
||||
from loguru import logger
|
||||
|
||||
def remove_model_prefix(llm):
|
||||
if llm.startswith('api2d-'): llm = llm.replace('api2d-', '')
|
||||
@@ -80,12 +81,12 @@ def 多智能体终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
||||
persistent_key = f"{user_uuid}->多智能体终端"
|
||||
if persistent_class_multi_user_manager.already_alive(persistent_key):
|
||||
# 当已经存在一个正在运行的多智能体终端时,直接将用户输入传递给它,而不是再次启动一个新的多智能体终端
|
||||
print('[debug] feed new user input')
|
||||
logger.info('[debug] feed new user input')
|
||||
executor = persistent_class_multi_user_manager.get(persistent_key)
|
||||
exit_reason = yield from executor.main_process_ui_control(txt, create_or_resume="resume")
|
||||
else:
|
||||
# 运行多智能体终端 (首次)
|
||||
print('[debug] create new executor instance')
|
||||
logger.info('[debug] create new executor instance')
|
||||
history = []
|
||||
chatbot.append(["正在启动: 多智能体终端", "插件动态生成, 执行开始, 作者 Microsoft & Binary-Husky."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_exception
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
fast_debug = False
|
||||
|
||||
|
||||
def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, os
|
||||
# pip install python-docx 用于docx格式,跨平台
|
||||
# pip install pywin32 用于doc格式,仅支持Win平台
|
||||
for index, fp in enumerate(file_manifest):
|
||||
if fp.split(".")[-1] == "docx":
|
||||
from docx import Document
|
||||
doc = Document(fp)
|
||||
file_content = "\n".join([para.text for para in doc.paragraphs])
|
||||
else:
|
||||
try:
|
||||
import win32com.client
|
||||
word = win32com.client.Dispatch("Word.Application")
|
||||
word.visible = False
|
||||
# 打开文件
|
||||
doc = word.Documents.Open(os.getcwd() + '/' + fp)
|
||||
# file_content = doc.Content.Text
|
||||
doc = word.ActiveDocument
|
||||
file_content = doc.Range().Text
|
||||
doc.Close()
|
||||
word.Quit()
|
||||
except:
|
||||
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
||||
|
||||
# private_upload里面的文件名在解压zip后容易出现乱码(rar和7z格式正常),故可以只分析文章内容,不输入文件名
|
||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||
from request_llms.bridge_all import model_info
|
||||
max_token = model_info[llm_kwargs['llm_model']]['max_token']
|
||||
TOKEN_LIMIT_PER_FRAGMENT = max_token * 3 // 4
|
||||
paper_fragments = breakdown_text_to_satisfy_token_limit(txt=file_content, limit=TOKEN_LIMIT_PER_FRAGMENT, llm_model=llm_kwargs['llm_model'])
|
||||
this_paper_history = []
|
||||
for i, paper_frag in enumerate(paper_fragments):
|
||||
i_say = f'请对下面的文章片段用中文做概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{paper_frag}```'
|
||||
i_say_show_user = f'请对下面的文章片段做概述: {os.path.abspath(fp)}的第{i+1}/{len(paper_fragments)}个片段。'
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say,
|
||||
inputs_show_user=i_say_show_user,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=[],
|
||||
sys_prompt="总结文章。"
|
||||
)
|
||||
|
||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
||||
history.extend([i_say_show_user,gpt_say])
|
||||
this_paper_history.extend([i_say_show_user,gpt_say])
|
||||
|
||||
# 已经对该文章的所有片段总结完毕,如果文章被切分了,
|
||||
if len(paper_fragments) > 1:
|
||||
i_say = f"根据以上的对话,总结文章{os.path.abspath(fp)}的主要内容。"
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say,
|
||||
inputs_show_user=i_say,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=this_paper_history,
|
||||
sys_prompt="总结文章。"
|
||||
)
|
||||
|
||||
history.extend([i_say,gpt_say])
|
||||
this_paper_history.extend([i_say,gpt_say])
|
||||
|
||||
res = write_history_to_file(history)
|
||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||
chatbot.append(("完成了吗?", res))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
res = write_history_to_file(history)
|
||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||
chatbot.append(("所有文件都总结完成了吗?", res))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
@CatchException
|
||||
def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
import glob, os
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append([
|
||||
"函数插件功能?",
|
||||
"批量总结Word文档。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
from docx import Document
|
||||
except:
|
||||
report_exception(chatbot, history,
|
||||
a=f"解析项目: {txt}",
|
||||
b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade python-docx pywin32```。")
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
# 清空历史,以免输入溢出
|
||||
history = []
|
||||
|
||||
# 检测输入参数,如没有给定输入参数,直接退出
|
||||
if os.path.exists(txt):
|
||||
project_folder = txt
|
||||
else:
|
||||
if txt == "": txt = '空空如也的输入栏'
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
# 搜索需要处理的文件清单
|
||||
if txt.endswith('.docx') or txt.endswith('.doc'):
|
||||
file_manifest = [txt]
|
||||
else:
|
||||
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.docx', recursive=True)] + \
|
||||
[f for f in glob.glob(f'{project_folder}/**/*.doc', recursive=True)]
|
||||
|
||||
# 如果没找到任何文件
|
||||
if len(file_manifest) == 0:
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.docx或doc文件: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
# 开始正式执行任务
|
||||
yield from 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, report_exception, select_api_key, update_ui, get_conf
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone, get_log_folder
|
||||
|
||||
def split_audio_file(filename, split_duration=1000):
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from loguru import logger
|
||||
|
||||
from toolbox import update_ui, promote_file_to_downloadzone, gen_time_str
|
||||
from toolbox import CatchException, report_exception
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import read_and_clean_pdf_text
|
||||
from .crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
|
||||
|
||||
|
||||
def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
file_write_buffer = []
|
||||
for file_name in file_manifest:
|
||||
print('begin analysis on:', file_name)
|
||||
logger.info('begin analysis on:', file_name)
|
||||
############################## <第 0 步,切割PDF> ##################################
|
||||
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
||||
# 的长度必须小于 2500 个 Token
|
||||
@@ -38,7 +40,7 @@ def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot,
|
||||
last_iteration_result = paper_meta # 初始值是摘要
|
||||
MAX_WORD_TOTAL = 4096 * 0.7
|
||||
n_fragment = len(paper_fragments)
|
||||
if n_fragment >= 20: print('文章极长,不能达到预期效果')
|
||||
if n_fragment >= 20: logger.warning('文章极长,不能达到预期效果')
|
||||
for i in range(n_fragment):
|
||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} Chinese characters: {paper_fragments[i]}"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from loguru import logger
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_exception
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
|
||||
fast_debug = False
|
||||
@@ -57,7 +58,6 @@ def readPdf(pdfPath):
|
||||
layout = device.get_result()
|
||||
for obj in layout._objs:
|
||||
if isinstance(obj, pdfminer.layout.LTTextBoxHorizontal):
|
||||
# print(obj.get_text())
|
||||
outTextList.append(obj.get_text())
|
||||
|
||||
return outTextList
|
||||
@@ -66,7 +66,7 @@ def readPdf(pdfPath):
|
||||
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, glob, os
|
||||
from bs4 import BeautifulSoup
|
||||
print('begin analysis on:', file_manifest)
|
||||
logger.info('begin analysis on:', file_manifest)
|
||||
for index, fp in enumerate(file_manifest):
|
||||
if ".tex" in fp:
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||
|
||||
496
crazy_functions/批量文件询问.py
普通文件
496
crazy_functions/批量文件询问.py
普通文件
@@ -0,0 +1,496 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Dict, Generator
|
||||
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text
|
||||
from request_llms.bridge_all import model_info
|
||||
from toolbox import update_ui, CatchException, report_exception
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileFragment:
|
||||
"""文件片段数据类,用于组织处理单元"""
|
||||
file_path: str
|
||||
content: str
|
||||
rel_path: str
|
||||
fragment_index: int
|
||||
total_fragments: int
|
||||
|
||||
|
||||
class BatchDocumentSummarizer:
|
||||
"""优化的文档总结器 - 批处理版本"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化总结器"""
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.failed_files = []
|
||||
self.file_summaries_map = {}
|
||||
|
||||
def _get_token_limit(self) -> int:
|
||||
"""获取模型token限制"""
|
||||
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||
return max_token * 3 // 4
|
||||
|
||||
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
|
||||
"""创建批处理输入"""
|
||||
inputs_array = []
|
||||
inputs_show_user_array = []
|
||||
history_array = []
|
||||
|
||||
for frag in fragments:
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)},'
|
||||
f'用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||
f'文件内容是 ```{frag.content}```')
|
||||
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
|
||||
else:
|
||||
i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)},'
|
||||
f'内容是 ```{frag.content}```')
|
||||
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
|
||||
|
||||
inputs_array.append(i_say)
|
||||
inputs_show_user_array.append(i_say_show_user)
|
||||
history_array.append([])
|
||||
|
||||
return inputs_array, inputs_show_user_array, history_array
|
||||
|
||||
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
|
||||
"""包装了超时控制的文件处理函数"""
|
||||
|
||||
def timeout_handler():
|
||||
thread = threading.current_thread()
|
||||
if hasattr(thread, '_timeout_occurred'):
|
||||
thread._timeout_occurred = True
|
||||
|
||||
# 设置超时标记
|
||||
thread = threading.current_thread()
|
||||
thread._timeout_occurred = False
|
||||
|
||||
# 设置超时定时器
|
||||
timer = threading.Timer(self.watch_dog_patience, timeout_handler)
|
||||
timer.start()
|
||||
|
||||
try:
|
||||
fp, project_folder = file_info
|
||||
fragments = []
|
||||
|
||||
# 定期检查是否超时
|
||||
def check_timeout():
|
||||
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
|
||||
raise TimeoutError("处理超时")
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "检查文件大小"
|
||||
mutable_status[1] = time.time()
|
||||
check_timeout()
|
||||
|
||||
# 文件大小检查
|
||||
if os.path.getsize(fp) > self.max_file_size:
|
||||
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
|
||||
mutable_status[2] = "文件过大"
|
||||
return fragments
|
||||
|
||||
check_timeout()
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "提取文件内容"
|
||||
mutable_status[1] = time.time()
|
||||
|
||||
# 提取内容
|
||||
content = extract_text(fp)
|
||||
if content is None:
|
||||
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
|
||||
mutable_status[2] = "格式不支持"
|
||||
return fragments
|
||||
elif not content.strip():
|
||||
self.failed_files.append((fp, "文件内容为空"))
|
||||
mutable_status[2] = "内容为空"
|
||||
return fragments
|
||||
|
||||
check_timeout()
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "分割文本"
|
||||
mutable_status[1] = time.time()
|
||||
|
||||
# 分割文本
|
||||
try:
|
||||
paper_fragments = breakdown_text_to_satisfy_token_limit(
|
||||
txt=content,
|
||||
limit=self._get_token_limit(),
|
||||
llm_model=self.llm_kwargs['llm_model']
|
||||
)
|
||||
except Exception as e:
|
||||
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
|
||||
mutable_status[2] = "分割失败"
|
||||
return fragments
|
||||
|
||||
check_timeout()
|
||||
|
||||
# 处理片段
|
||||
rel_path = os.path.relpath(fp, project_folder)
|
||||
for i, frag in enumerate(paper_fragments):
|
||||
if frag.strip():
|
||||
fragments.append(FileFragment(
|
||||
file_path=fp,
|
||||
content=frag,
|
||||
rel_path=rel_path,
|
||||
fragment_index=i,
|
||||
total_fragments=len(paper_fragments)
|
||||
))
|
||||
|
||||
mutable_status[2] = "处理完成"
|
||||
return fragments
|
||||
|
||||
except TimeoutError as e:
|
||||
self.failed_files.append((fp, "处理超时"))
|
||||
mutable_status[2] = "处理超时"
|
||||
return []
|
||||
except Exception as e:
|
||||
self.failed_files.append((fp, f"处理失败:{str(e)}"))
|
||||
mutable_status[2] = "处理异常"
|
||||
return []
|
||||
finally:
|
||||
timer.cancel()
|
||||
|
||||
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Generator, List
|
||||
"""并行准备所有文件的处理片段"""
|
||||
all_fragments = []
|
||||
total_files = len(file_paths)
|
||||
|
||||
# 配置参数
|
||||
self.refresh_interval = 0.2 # UI刷新间隔
|
||||
self.watch_dog_patience = 5 # 看门狗超时时间
|
||||
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
|
||||
self.max_workers = min(32, len(file_paths)) # 最多32个线程
|
||||
|
||||
# 创建有超时控制的线程池
|
||||
executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
|
||||
# 用于跨线程状态传递的可变列表 - 增加文件名信息
|
||||
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
|
||||
|
||||
# 创建文件处理任务
|
||||
file_infos = [(fp, project_folder) for fp in file_paths]
|
||||
|
||||
# 提交所有任务,使用带超时控制的处理函数
|
||||
futures = [
|
||||
executor.submit(
|
||||
self._process_single_file_with_timeout,
|
||||
file_info,
|
||||
mutable_status_array[i]
|
||||
) for i, file_info in enumerate(file_infos)
|
||||
]
|
||||
|
||||
# 更新UI的计数器
|
||||
cnt = 0
|
||||
|
||||
try:
|
||||
# 监控任务执行
|
||||
while True:
|
||||
time.sleep(self.refresh_interval)
|
||||
cnt += 1
|
||||
|
||||
# 检查任务完成状态
|
||||
worker_done = [f.done() for f in futures]
|
||||
|
||||
# 更新状态显示
|
||||
status_str = ""
|
||||
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
|
||||
# 获取文件名(去掉路径)
|
||||
file_name = os.path.basename(file_path)
|
||||
if worker_done[i]:
|
||||
status_str += f"文件 {file_name}: {desc}\n"
|
||||
else:
|
||||
status_str += f"文件 {file_name}: {status} {desc}\n"
|
||||
|
||||
# 更新UI
|
||||
self.chatbot[-1] = [
|
||||
"处理进度",
|
||||
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
|
||||
]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 检查是否所有任务完成
|
||||
if all(worker_done):
|
||||
break
|
||||
|
||||
finally:
|
||||
# 确保线程池正确关闭
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
# 收集结果
|
||||
processed_files = 0
|
||||
for future in futures:
|
||||
try:
|
||||
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
|
||||
all_fragments.extend(fragments)
|
||||
processed_files += 1
|
||||
except concurrent.futures.TimeoutError:
|
||||
# 处理获取结果超时
|
||||
file_index = futures.index(future)
|
||||
self.failed_files.append((file_paths[file_index], "结果获取超时"))
|
||||
continue
|
||||
except Exception as e:
|
||||
# 处理其他异常
|
||||
file_index = futures.index(future)
|
||||
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
|
||||
continue
|
||||
|
||||
# 最终进度更新
|
||||
self.chatbot.append([
|
||||
"文件处理完成",
|
||||
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
|
||||
])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return all_fragments
|
||||
|
||||
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
|
||||
"""批量处理文件片段"""
|
||||
from collections import defaultdict
|
||||
batch_size = 64 # 每批处理的片段数
|
||||
max_retries = 3 # 最大重试次数
|
||||
retry_delay = 5 # 重试延迟(秒)
|
||||
results = defaultdict(list)
|
||||
|
||||
# 按批次处理
|
||||
for i in range(0, len(fragments), batch_size):
|
||||
batch = fragments[i:i + batch_size]
|
||||
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
|
||||
sys_prompt_array = ["请总结以下内容:"] * len(batch)
|
||||
|
||||
# 添加重试机制
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
for j, frag in enumerate(batch):
|
||||
summary = response_collection[j * 2 + 1]
|
||||
if summary and summary.strip():
|
||||
results[frag.rel_path].append({
|
||||
'index': frag.fragment_index,
|
||||
'summary': summary,
|
||||
'total': frag.total_fragments
|
||||
})
|
||||
break # 成功处理,跳出重试循环
|
||||
|
||||
except Exception as e:
|
||||
if retry == max_retries - 1: # 最后一次重试失败
|
||||
for frag in batch:
|
||||
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
|
||||
else:
|
||||
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
|
||||
time.sleep(retry_delay)
|
||||
|
||||
return results
|
||||
|
||||
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
|
||||
"""准备最终总结请求"""
|
||||
if not self.file_summaries_map:
|
||||
return (["无可用的文件总结"], ["生成最终总结"], [[]])
|
||||
|
||||
summaries = list(self.file_summaries_map.values())
|
||||
if all(not summary for summary in summaries):
|
||||
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
|
||||
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
|
||||
else:
|
||||
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
|
||||
|
||||
return ([i_say], [i_say], [summaries])
|
||||
|
||||
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||
"""处理所有文件"""
|
||||
total_files = len(file_paths)
|
||||
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 1. 准备所有文件片段
|
||||
# 在 process_files 函数中:
|
||||
fragments = yield from self.prepare_fragments(project_folder, file_paths)
|
||||
if not fragments:
|
||||
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
|
||||
return "没有可处理的文件内容"
|
||||
|
||||
# 2. 批量处理所有文件片段
|
||||
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
file_summaries = yield from self._process_fragments_batch(fragments)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
|
||||
return "处理过程发生错误"
|
||||
|
||||
# 3. 为每个文件生成整体总结
|
||||
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 处理每个文件的总结
|
||||
for rel_path, summaries in file_summaries.items():
|
||||
if len(summaries) > 1: # 多片段文件需要生成整体总结
|
||||
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
|
||||
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||
else:
|
||||
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
|
||||
|
||||
try:
|
||||
summary_texts = [s['summary'] for s in sorted_summaries]
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=[i_say],
|
||||
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=[summary_texts],
|
||||
sys_prompt_array=["你是一个优秀的助手,"],
|
||||
)
|
||||
self.file_summaries_map[rel_path] = response_collection[1]
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
|
||||
self.file_summaries_map[rel_path] = "总结生成失败"
|
||||
else: # 单片段文件直接使用其唯一的总结
|
||||
self.file_summaries_map[rel_path] = summaries[0]['summary']
|
||||
|
||||
# 4. 生成最终总结
|
||||
if total_files ==1:
|
||||
return "文件数为1,此时不调用总结模块"
|
||||
else:
|
||||
try:
|
||||
# 收集所有文件的总结用于生成最终总结
|
||||
file_summaries_for_final = []
|
||||
for rel_path, summary in self.file_summaries_map.items():
|
||||
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
|
||||
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
|
||||
self.plugin_kwargs['advanced_arg'])
|
||||
else:
|
||||
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
|
||||
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=[final_summary_prompt],
|
||||
inputs_show_user_array=["生成最终总结报告"],
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=[file_summaries_for_final],
|
||||
sys_prompt_array=["总结所有文件内容。"],
|
||||
max_workers=1
|
||||
)
|
||||
|
||||
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
|
||||
return "生成总结失败"
|
||||
|
||||
def save_results(self, final_summary: str):
|
||||
"""保存结果到文件"""
|
||||
from toolbox import promote_file_to_downloadzone, write_history_to_file
|
||||
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
|
||||
import os
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 创建各种格式化器
|
||||
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
|
||||
result_files = []
|
||||
|
||||
# 保存 Markdown
|
||||
md_content = md_formatter.create_document()
|
||||
result_file_md = write_history_to_file(
|
||||
history=[md_content], # 直接传入内容列表
|
||||
file_basename=f"文档总结_{timestamp}.md"
|
||||
)
|
||||
result_files.append(result_file_md)
|
||||
|
||||
# 保存 HTML
|
||||
html_content = html_formatter.create_document()
|
||||
result_file_html = write_history_to_file(
|
||||
history=[html_content],
|
||||
file_basename=f"文档总结_{timestamp}.html"
|
||||
)
|
||||
result_files.append(result_file_html)
|
||||
|
||||
# 保存 Word
|
||||
doc = word_formatter.create_document()
|
||||
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
|
||||
result_file_docx = os.path.join(
|
||||
os.path.dirname(result_file_md),
|
||||
f"文档总结_{timestamp}.docx"
|
||||
)
|
||||
doc.save(result_file_docx)
|
||||
result_files.append(result_file_docx)
|
||||
|
||||
# 添加到下载区
|
||||
for file in result_files:
|
||||
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||
|
||||
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
|
||||
@CatchException
|
||||
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 优化版本"""
|
||||
# 初始化
|
||||
import glob
|
||||
import re
|
||||
from crazy_functions.rag_fns.rag_file_support import supports_format
|
||||
from toolbox import report_exception
|
||||
|
||||
summarizer = BatchDocumentSummarizer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
chatbot.append(["函数插件功能", f"作者:lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 验证输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 获取文件列表
|
||||
project_folder = txt
|
||||
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||
|
||||
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||
|
||||
if not file_manifest:
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 处理所有文件并生成总结
|
||||
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 保存结果
|
||||
summarizer.save_results(final_summary)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
@@ -1,9 +1,9 @@
|
||||
from toolbox import CatchException, report_exception, get_log_folder, gen_time_str
|
||||
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from .crazy_utils import read_and_clean_pdf_text
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||
from .pdf_fns.parse_pdf import parse_pdf, get_avail_grobid_url, translate_pdf
|
||||
from shared_utils.colorful import *
|
||||
import copy
|
||||
@@ -60,7 +60,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
# 清空历史,以免输入溢出
|
||||
history = []
|
||||
|
||||
from .crazy_utils import get_files_from_everything
|
||||
from crazy_functions.crazy_utils import get_files_from_everything
|
||||
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
|
||||
if len(file_manifest) > 0:
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from loguru import logger
|
||||
from toolbox import CatchException, update_ui, gen_time_str, promote_file_to_downloadzone
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
@@ -34,10 +35,10 @@ def eval_manim(code):
|
||||
return f'gpt_log/{time_str}.mp4'
|
||||
except subprocess.CalledProcessError as e:
|
||||
output = e.output.decode()
|
||||
print(f"Command returned non-zero exit status {e.returncode}: {output}.")
|
||||
logger.error(f"Command returned non-zero exit status {e.returncode}: {output}.")
|
||||
return f"Evaluating python script failed: {e.output}."
|
||||
except:
|
||||
print('generating mp4 failed')
|
||||
logger.error('generating mp4 failed')
|
||||
return "Generating mp4 failed."
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from loguru import logger
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_exception
|
||||
from .crazy_utils import read_and_clean_pdf_text
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
fast_debug = False
|
||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
|
||||
def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import tiktoken
|
||||
print('begin analysis on:', file_name)
|
||||
logger.info('begin analysis on:', file_name)
|
||||
|
||||
############################## <第 0 步,切割PDF> ##################################
|
||||
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
||||
@@ -36,7 +35,7 @@ def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
||||
last_iteration_result = paper_meta # 初始值是摘要
|
||||
MAX_WORD_TOTAL = 4096
|
||||
n_fragment = len(paper_fragments)
|
||||
if n_fragment >= 20: print('文章极长,不能达到预期效果')
|
||||
if n_fragment >= 20: logger.warning('文章极长,不能达到预期效果')
|
||||
for i in range(n_fragment):
|
||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words: {paper_fragments[i]}"
|
||||
@@ -57,7 +56,7 @@ def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
||||
chatbot.append([i_say_show_user, gpt_say])
|
||||
|
||||
############################## <第 4 步,设置一个token上限,防止回答时Token溢出> ##################################
|
||||
from .crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
_, final_results = input_clipping("", final_results, max_token_limit=3200)
|
||||
yield from update_ui(chatbot=chatbot, history=final_results) # 注意这里的历史记录被替代了
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from loguru import logger
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_exception
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
fast_debug = False
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, os
|
||||
print('begin analysis on:', file_manifest)
|
||||
logger.info('begin analysis on:', file_manifest)
|
||||
for index, fp in enumerate(file_manifest):
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||
file_content = f.read()
|
||||
@@ -16,22 +16,20 @@ def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
|
||||
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
if not fast_debug:
|
||||
msg = '正常'
|
||||
# ** gpt request **
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
i_say, i_say_show_user, llm_kwargs, chatbot, history=[], sys_prompt=system_prompt) # 带超时倒计时
|
||||
msg = '正常'
|
||||
# ** gpt request **
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
i_say, i_say_show_user, llm_kwargs, chatbot, history=[], sys_prompt=system_prompt) # 带超时倒计时
|
||||
|
||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
||||
history.append(i_say_show_user); history.append(gpt_say)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
if not fast_debug: time.sleep(2)
|
||||
|
||||
if not fast_debug:
|
||||
res = write_history_to_file(history)
|
||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||
chatbot.append(("完成了吗?", res))
|
||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
||||
history.append(i_say_show_user); history.append(gpt_say)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
time.sleep(2)
|
||||
|
||||
res = write_history_to_file(history)
|
||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||
chatbot.append(("完成了吗?", res))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, update_ui, report_exception
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.plugin_template.plugin_class_template import (
|
||||
GptAcademicPluginTemplate,
|
||||
)
|
||||
@@ -201,8 +201,7 @@ def 解析历史输入(history, llm_kwargs, file_manifest, chatbot, plugin_kwarg
|
||||
MAX_WORD_TOTAL = 4096
|
||||
n_txt = len(txt)
|
||||
last_iteration_result = "从以下文本中提取摘要。"
|
||||
if n_txt >= 20:
|
||||
print("文章极长,不能达到预期效果")
|
||||
|
||||
for i in range(n_txt):
|
||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_txt
|
||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words in Chinese: {txt[i]}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg, get_log_folder, get_user
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
|
||||
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
|
||||
from loguru import logger
|
||||
install_msg ="""
|
||||
|
||||
1. python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
@@ -40,7 +40,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
except Exception as e:
|
||||
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
# from .crazy_utils import try_install_deps
|
||||
# from crazy_functions.crazy_utils import try_install_deps
|
||||
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
||||
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
||||
return
|
||||
@@ -60,7 +60,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
# < -------------------预热文本向量化模组--------------- >
|
||||
chatbot.append(['<br/>'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
print('Checking Text2vec ...')
|
||||
logger.info('Checking Text2vec ...')
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||
HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
||||
@@ -68,7 +68,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
# < -------------------构建知识库--------------- >
|
||||
chatbot.append(['<br/>'.join(file_manifest), "正在构建知识库..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
print('Establishing knowledge archive ...')
|
||||
logger.info('Establishing knowledge archive ...')
|
||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||
kai = knowledge_archive_interface()
|
||||
vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
|
||||
@@ -93,7 +93,7 @@ def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
except Exception as e:
|
||||
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
# from .crazy_utils import try_install_deps
|
||||
# from crazy_functions.crazy_utils import try_install_deps
|
||||
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
||||
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
||||
return
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from request_llms.bridge_all import model_info
|
||||
@@ -23,8 +23,8 @@ def google(query, proxies):
|
||||
item = {'title': title, 'link': link}
|
||||
results.append(item)
|
||||
|
||||
for r in results:
|
||||
print(r['link'])
|
||||
# for r in results:
|
||||
# print(r['link'])
|
||||
return results
|
||||
|
||||
def scrape_text(url, proxies) -> str:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from request_llms.bridge_all import model_info
|
||||
@@ -22,8 +22,8 @@ def bing_search(query, proxies=None):
|
||||
item = {'title': title, 'link': link}
|
||||
results.append(item)
|
||||
|
||||
for r in results:
|
||||
print(r['link'])
|
||||
# for r in results:
|
||||
# print(r['link'])
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ def parseNotebook(filename, enable_markdown=1):
|
||||
|
||||
|
||||
def ipynb解释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
enable_markdown = plugin_kwargs.get("advanced_arg", "1")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, update_ui, get_conf
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
import datetime
|
||||
@CatchException
|
||||
def 同时问询(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, get_conf, markdown_convertion
|
||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||
from crazy_functions.live_audio.aliyunASR import AliyunASR
|
||||
from loguru import logger
|
||||
|
||||
import threading, time
|
||||
import numpy as np
|
||||
from .live_audio.aliyunASR import AliyunASR
|
||||
import json
|
||||
import re
|
||||
|
||||
@@ -42,9 +44,9 @@ class AsyncGptTask():
|
||||
gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt,
|
||||
observe_window=observe_window[index], console_slience=True)
|
||||
except ConnectionAbortedError as token_exceed_err:
|
||||
print('至少一个线程任务Token溢出而失败', e)
|
||||
logger.error('至少一个线程任务Token溢出而失败', e)
|
||||
except Exception as e:
|
||||
print('至少一个线程任务意外失败', e)
|
||||
logger.error('至少一个线程任务意外失败', e)
|
||||
|
||||
def add_async_gpt_task(self, i_say, chatbot_index, llm_kwargs, history, system_prompt):
|
||||
self.observe_future.append([""])
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_exception
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
|
||||
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, glob, os
|
||||
print('begin analysis on:', file_manifest)
|
||||
for index, fp in enumerate(file_manifest):
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||
file_content = f.read()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from toolbox import CatchException, report_exception, promote_file_to_downloadzone
|
||||
from toolbox import update_ui, update_ui_lastest_msg, disable_auto_promotion, write_history_to_file
|
||||
import logging
|
||||
|
||||
@@ -180,6 +180,7 @@ version: '3'
|
||||
services:
|
||||
gpt_academic_with_latex:
|
||||
image: ghcr.io/binary-husky/gpt_academic_with_latex:master # (Auto Built by Dockerfile: docs/GithubAction+NoLocal+Latex)
|
||||
# 对于ARM64设备,请将以上镜像名称替换为 ghcr.io/binary-husky/gpt_academic_with_latex_arm:master
|
||||
environment:
|
||||
# 请查阅 `config.py` 以查看所有的配置信息
|
||||
API_KEY: ' sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx '
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# 此Dockerfile不再维护,请前往docs/GithubAction+JittorLLMs
|
||||
@@ -1,57 +0,0 @@
|
||||
# docker build -t gpt-academic-all-capacity -f docs/GithubAction+AllCapacity --network=host --build-arg http_proxy=http://localhost:10881 --build-arg https_proxy=http://localhost:10881 .
|
||||
# docker build -t gpt-academic-all-capacity -f docs/GithubAction+AllCapacityBeta --network=host .
|
||||
# docker run -it --net=host gpt-academic-all-capacity bash
|
||||
|
||||
# 从NVIDIA源,从而支持显卡(检查宿主的nvidia-smi中的cuda版本必须>=11.3)
|
||||
FROM fuqingxu/11.3.1-runtime-ubuntu20.04-with-texlive:latest
|
||||
|
||||
# edge-tts需要的依赖,某些pip包所需的依赖
|
||||
RUN apt update && apt install ffmpeg build-essential -y
|
||||
|
||||
# use python3 as the system default python
|
||||
WORKDIR /gpt
|
||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.8
|
||||
|
||||
# # 非必要步骤,更换pip源 (以下三行,可以删除)
|
||||
# RUN echo '[global]' > /etc/pip.conf && \
|
||||
# echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
|
||||
# echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
|
||||
|
||||
# 下载pytorch
|
||||
RUN python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
# 准备pip依赖
|
||||
RUN python3 -m pip install openai numpy arxiv rich
|
||||
RUN python3 -m pip install colorama Markdown pygments pymupdf
|
||||
RUN python3 -m pip install python-docx moviepy pdfminer
|
||||
RUN python3 -m pip install zh_langchain==0.2.1 pypinyin
|
||||
RUN python3 -m pip install rarfile py7zr
|
||||
RUN python3 -m pip install aliyun-python-sdk-core==2.13.3 pyOpenSSL webrtcvad scipy git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git
|
||||
# 下载分支
|
||||
WORKDIR /gpt
|
||||
RUN git clone --depth=1 https://github.com/binary-husky/gpt_academic.git
|
||||
WORKDIR /gpt/gpt_academic
|
||||
RUN git clone --depth=1 https://github.com/OpenLMLab/MOSS.git request_llms/moss
|
||||
|
||||
RUN python3 -m pip install -r requirements.txt
|
||||
RUN python3 -m pip install -r request_llms/requirements_moss.txt
|
||||
RUN python3 -m pip install -r request_llms/requirements_qwen.txt
|
||||
RUN python3 -m pip install -r request_llms/requirements_chatglm.txt
|
||||
RUN python3 -m pip install -r request_llms/requirements_newbing.txt
|
||||
RUN python3 -m pip install nougat-ocr
|
||||
|
||||
|
||||
# 预热Tiktoken模块
|
||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||
|
||||
# 安装知识库插件的额外依赖
|
||||
RUN apt-get update && apt-get install libgl1 -y
|
||||
RUN pip3 install transformers protobuf langchain sentence-transformers faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk --upgrade
|
||||
RUN pip3 install unstructured[all-docs] --upgrade
|
||||
RUN python3 -c 'from check_proxy import warm_up_vectordb; warm_up_vectordb()'
|
||||
RUN rm -rf /usr/local/lib/python3.8/dist-packages/tests
|
||||
|
||||
|
||||
# COPY .cache /root/.cache
|
||||
# COPY config_private.py config_private.py
|
||||
# 启动
|
||||
CMD ["python3", "-u", "main.py"]
|
||||
@@ -1,35 +1,34 @@
|
||||
# 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
||||
# 此Dockerfile适用于"无本地模型"的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
||||
# - 1 修改 `config.py`
|
||||
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
|
||||
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
|
||||
|
||||
FROM fuqingxu/python311_texlive_ctex:latest
|
||||
ENV PATH "$PATH:/usr/local/texlive/2022/bin/x86_64-linux"
|
||||
ENV PATH "$PATH:/usr/local/texlive/2023/bin/x86_64-linux"
|
||||
ENV PATH "$PATH:/usr/local/texlive/2024/bin/x86_64-linux"
|
||||
ENV PATH "$PATH:/usr/local/texlive/2025/bin/x86_64-linux"
|
||||
ENV PATH "$PATH:/usr/local/texlive/2026/bin/x86_64-linux"
|
||||
|
||||
# 指定路径
|
||||
FROM menghuan1918/ubuntu_uv_ctex:latest
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
WORKDIR /gpt
|
||||
|
||||
RUN pip3 install openai numpy arxiv rich
|
||||
RUN pip3 install colorama Markdown pygments pymupdf
|
||||
RUN pip3 install python-docx pdfminer
|
||||
RUN pip3 install nougat-ocr
|
||||
|
||||
# 装载项目文件
|
||||
COPY . .
|
||||
|
||||
# 先复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装依赖
|
||||
RUN pip3 install -r requirements.txt
|
||||
RUN pip install --break-system-packages openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& if [ "$(uname -m)" = "x86_64" ]; then \
|
||||
pip install --break-system-packages nougat-ocr; \
|
||||
fi \
|
||||
&& pip cache purge \
|
||||
&& rm -rf /root/.cache/pip/*
|
||||
|
||||
# edge-tts需要的依赖
|
||||
RUN apt update && apt install ffmpeg -y
|
||||
# 创建非root用户
|
||||
RUN useradd -m gptuser && chown -R gptuser /gpt
|
||||
USER gptuser
|
||||
|
||||
# 最后才复制代码文件,这样代码更新时只需重建最后几层,可以大幅减少docker pull所需的大小
|
||||
COPY --chown=gptuser:gptuser . .
|
||||
|
||||
# 可选步骤,用于预热模块
|
||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||
|
||||
# 启动
|
||||
CMD ["python3", "-u", "main.py"]
|
||||
|
||||
@@ -4,7 +4,7 @@ We currently support fastapi in order to solve sub-path deploy issue.
|
||||
|
||||
1. change CUSTOM_PATH setting in `config.py`
|
||||
|
||||
``` sh
|
||||
```sh
|
||||
nano config.py
|
||||
```
|
||||
|
||||
@@ -35,9 +35,8 @@ if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
|
||||
3. Go!
|
||||
|
||||
``` sh
|
||||
```sh
|
||||
python main.py
|
||||
```
|
||||
|
||||
文件差异内容过多而无法显示
加载差异
@@ -106,5 +106,24 @@
|
||||
"解析PDF_DOC2X_转Latex": "ParsePDF_DOC2X_toLatex",
|
||||
"解析PDF_基于DOC2X": "ParsePDF_basedDOC2X",
|
||||
"解析PDF_简单拆解": "ParsePDF_simpleDecomposition",
|
||||
"解析PDF_DOC2X_单文件": "ParsePDF_DOC2X_singleFile"
|
||||
"解析PDF_DOC2X_单文件": "ParsePDF_DOC2X_singleFile",
|
||||
"注释Python项目": "CommentPythonProject",
|
||||
"注释源代码": "CommentSourceCode",
|
||||
"log亮黄": "log_yellow",
|
||||
"log亮绿": "log_green",
|
||||
"log亮红": "log_red",
|
||||
"log亮紫": "log_purple",
|
||||
"log亮蓝": "log_blue",
|
||||
"Rag问答": "RagQA",
|
||||
"sprint红": "sprint_red",
|
||||
"sprint绿": "sprint_green",
|
||||
"sprint黄": "sprint_yellow",
|
||||
"sprint蓝": "sprint_blue",
|
||||
"sprint紫": "sprint_purple",
|
||||
"sprint靛": "sprint_indigo",
|
||||
"sprint亮红": "sprint_bright_red",
|
||||
"sprint亮绿": "sprint_bright_green",
|
||||
"sprint亮黄": "sprint_bright_yellow",
|
||||
"sprint亮蓝": "sprint_bright_blue",
|
||||
"sprint亮紫": "sprint_bright_purple"
|
||||
}
|
||||
216
instruction.txt
普通文件
216
instruction.txt
普通文件
@@ -0,0 +1,216 @@
|
||||
|
||||
1、GPT Academic 项目结构
|
||||
.
|
||||
├── Dockerfile
|
||||
├── LICENSE
|
||||
├── README.md
|
||||
├── check_proxy.py
|
||||
├── config.py
|
||||
├── config_private.py
|
||||
├── core_functional.py
|
||||
├── crazy_functional.py
|
||||
├── crazy_functions
|
||||
│ ├── Arxiv_论文对话.py
|
||||
│ ├── Conversation_To_File.py
|
||||
│ ├── Image_Generate.py
|
||||
│ ├── Image_Generate_Wrap.py
|
||||
│ ├── Internet_GPT.py
|
||||
│ ├── Internet_GPT_Wrap.py
|
||||
│ ├── Latex_Function.py
|
||||
│ ├── Latex_Function_Wrap.py
|
||||
│ ├── Latex全文润色.py
|
||||
│ ├── Latex全文翻译.py
|
||||
│ ├── Markdown_Translate.py
|
||||
│ ├── PDF_Translate.py
|
||||
│ ├── PDF_Translate_Wrap.py
|
||||
│ ├── Rag_Interface.py
|
||||
│ ├── Social_Helper.py
|
||||
│ ├── SourceCode_Analyse.py
|
||||
│ ├── SourceCode_Comment.py
|
||||
│ ├── SourceCode_Comment_Wrap.py
|
||||
│ ├── __init__.py
|
||||
│ │ ├── auto_agent.py
|
||||
│ │ ├── echo_agent.py
|
||||
│ │ ├── general.py
|
||||
│ │ ├── persistent.py
|
||||
│ │ ├── pipe.py
|
||||
│ │ ├── python_comment_agent.py
|
||||
│ │ ├── python_comment_compare.html
|
||||
│ │ └── watchdog.py
|
||||
│ ├── ast_fns
|
||||
│ │ └── comment_remove.py
|
||||
│ ├── chatglm微调工具.py
|
||||
│ ├── crazy_utils.py
|
||||
│ ├── diagram_fns
|
||||
│ │ └── file_tree.py
|
||||
│ ├── game_fns
|
||||
│ │ ├── game_ascii_art.py
|
||||
│ │ ├── game_interactive_story.py
|
||||
│ │ └── game_utils.py
|
||||
│ ├── gen_fns
|
||||
│ │ └── gen_fns_shared.py
|
||||
│ ├── ipc_fns
|
||||
│ │ └── mp.py
|
||||
│ ├── json_fns
|
||||
│ │ ├── pydantic_io.py
|
||||
│ │ └── select_tool.py
|
||||
│ ├── latex_fns
|
||||
│ │ ├── latex_actions.py
|
||||
│ │ ├── latex_pickle_io.py
|
||||
│ │ └── latex_toolbox.py
|
||||
│ ├── live_audio
|
||||
│ │ ├── aliyunASR.py
|
||||
│ │ └── audio_io.py
|
||||
│ ├── multi_stage
|
||||
│ │ └── multi_stage_utils.py
|
||||
│ ├── rag_essay_fns
|
||||
│ │ └── multi_stage_utils.py
|
||||
│ ├── pdf_fns
|
||||
│ │ ├── breakdown_txt.py
|
||||
│ │ ├── parse_pdf.py
|
||||
│ │ ├── parse_pdf_grobid.py
|
||||
│ │ ├── parse_pdf_legacy.py
|
||||
│ │ ├── parse_pdf_via_doc2x.py
|
||||
│ │ ├── parse_word.py
|
||||
│ │ ├── report_gen_html.py
|
||||
│ │ ├── report_template.html
|
||||
│ │ └── report_template_v2.html
|
||||
│ ├── plugin_template
|
||||
│ │ └── plugin_class_template.py
|
||||
│ ├── prompts
|
||||
│ │ └── internet.py
|
||||
│ ├── rag_fns
|
||||
│ │ ├── llama_index_worker.py
|
||||
│ │ ├── milvus_worker.py
|
||||
│ │ ├── rag_file_support.py
|
||||
│ │ └── vector_store_index.py
|
||||
│ ├── vector_fns
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── general_file_loader.py
|
||||
│ │ └── vector_database.py
|
||||
│ ├── vt_fns
|
||||
│ │ ├── vt_call_plugin.py
|
||||
│ │ ├── vt_modify_config.py
|
||||
│ │ └── vt_state.py
|
||||
│ ├── 下载arxiv论文翻译摘要.py
|
||||
│ ├── 互动小游戏.py
|
||||
│ ├── 交互功能函数模板.py
|
||||
│ ├── 函数动态生成.py
|
||||
│ ├── 命令行助手.py
|
||||
│ ├── 多智能体.py
|
||||
│ ├── 总结word文档.py
|
||||
│ ├── 总结音视频.py
|
||||
│ ├── 批量总结PDF文档.py
|
||||
│ ├── 批量总结PDF文档pdfminer.py
|
||||
│ ├── 批量文件询问.py
|
||||
│ ├── 批量翻译PDF文档_NOUGAT.py
|
||||
│ ├── 数学动画生成manim.py
|
||||
│ ├── 理解PDF文档内容.py
|
||||
│ ├── 生成函数注释.py
|
||||
│ ├── 生成多种Mermaid图表.py
|
||||
│ ├── 知识库问答.py
|
||||
│ ├── 联网的ChatGPT.py
|
||||
│ ├── 联网的ChatGPT_bing版.py
|
||||
│ ├── 虚空终端.py
|
||||
│ ├── 解析JupyterNotebook.py
|
||||
│ ├── 询问多个大语言模型.py
|
||||
│ ├── 语音助手.py
|
||||
│ ├── 读文章写摘要.py
|
||||
│ ├── 谷歌检索小助手.py
|
||||
│ ├── 辅助功能.py
|
||||
│ └── 高级功能函数模板.py
|
||||
├── docker-compose.yml
|
||||
├── instruction.txt
|
||||
├── main.py
|
||||
├── multi_language.py
|
||||
├── requirements.txt
|
||||
├── shared_utils
|
||||
│ ├── advanced_markdown_format.py
|
||||
│ ├── char_visual_effect.py
|
||||
│ ├── colorful.py
|
||||
│ ├── config_loader.py
|
||||
│ ├── connect_void_terminal.py
|
||||
│ ├── cookie_manager.py
|
||||
│ ├── fastapi_server.py
|
||||
│ ├── handle_upload.py
|
||||
│ ├── key_pattern_manager.py
|
||||
│ ├── logging.py
|
||||
│ ├── map_names.py
|
||||
│ └── text_mask.py
|
||||
├── toolbox.py
|
||||
└── version
|
||||
|
||||
2、light_rag的实现方案路径为crazy_functions/rag_fns/LightRAG,主要功能实现文件为operate.py,rag使用到的其他文件为prompt.py、base.py、storage.py、utils.py,请参考实现方案实现插件功能。light_rag的使用案例可以参考crazy_functions/rag_fns/LightRAG/examples路径下的lightrag_hf_demo.py、lightrag_lmdeploy_demo.py:
|
||||
路径目录结构为
|
||||
|
||||
├── README.md
|
||||
├── examples
|
||||
│ ├── batch_eval.py
|
||||
│ ├── generate_query.py
|
||||
│ ├── graph_visual_with_html.py
|
||||
│ ├── graph_visual_with_neo4j.py
|
||||
│ ├── lightrag_azure_openai_demo.py
|
||||
│ ├── lightrag_bedrock_demo.py
|
||||
│ ├── lightrag_hf_demo.py
|
||||
│ ├── lightrag_ollama_demo.py
|
||||
│ ├── lightrag_openai_compatible_demo.py
|
||||
│ ├── lightrag_openai_demo.py
|
||||
│ └── vram_management_demo.py
|
||||
├── lightrag
|
||||
│ ├── __init__.py
|
||||
│ ├── base.py
|
||||
│ ├── lightrag.py
|
||||
│ ├── llm.py
|
||||
│ ├── operate.py
|
||||
│ ├── prompt.py
|
||||
│ ├── storage.py
|
||||
│ └── utils.py
|
||||
├── reproduce
|
||||
│ ├── Step_0.py
|
||||
│ ├── Step_1.py
|
||||
│ ├── Step_1_openai_compatible.py
|
||||
│ ├── Step_2.py
|
||||
│ ├── Step_3.py
|
||||
│ └── Step_3_openai_compatible.py
|
||||
├── requirements.txt
|
||||
└── setup.py
|
||||
|
||||
|
||||
3、我需要开发一个rag插件,请帮我实现一个插件,插件的名称是rag论文总结,插件主入口在crazy_functions/Arxiv_论文对话.py中的Rag论文对话函数,插件的功能步骤分为文件处理和RAG两个步骤,以下是具体的一些要求:
|
||||
I. 函数头如下:
|
||||
@CatchException
|
||||
def rag论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
II. 函数返回可参考crazy_functions/批量文件询问.py中的“批量文件询问”函数,主要采用yield方式
|
||||
|
||||
3、对于RAG,我希望采用light_rag的方案,参考已有方案其主要的功能实现是:
|
||||
主要功能包括:
|
||||
a. 分别为project和arxiv创建rag_handler,project类的fragment类内容为
|
||||
@dataclass
|
||||
class DocFragment:
|
||||
"""文本片段数据类"""
|
||||
file_path: str # 原始文件路径
|
||||
content: str # 片段内容
|
||||
segment_index: int # 片段序号
|
||||
total_segments: int # 总片段数
|
||||
rel_path: str # 相对路径
|
||||
arxiv的fragment内容为:
|
||||
@dataclass
|
||||
class ArxivFragment:
|
||||
"""Arxiv论文片段数据类"""
|
||||
file_path: str
|
||||
content: str
|
||||
segment_index: int
|
||||
total_segments: int
|
||||
rel_path: str
|
||||
segment_type: str
|
||||
title: str
|
||||
abstract: str
|
||||
section: str
|
||||
is_appendix: bool
|
||||
b 如果目录下不存在抽取好的实体或关系的摘要,利用`_handle_entity_relation_summary`函数对d步骤生成的文本块进行实体或关系的摘要,并将其存储在project或者arxiv的路径下,路径为获取fragment.file_path的前三级目录(按照“/”区分每一级),如果原目录存在抽取好的,请直接使用,不再重复抽取。
|
||||
f 利用`_handle_single_entity_extraction` 和 `_handle_single_relationship_extraction`:从记录中提取单个实体或关系信息。
|
||||
g `_merge_nodes_then_upsert` 和 `_merge_edges_then_upsert`:合并并插入节点或边。
|
||||
h `extract_entities`:处理多个文本块,提取实体和关系,并存储在知识图谱和向量数据库中。
|
||||
i `local_query`:根据查询提取关键词并生成响应。
|
||||
|
||||
31
main.py
31
main.py
@@ -13,16 +13,10 @@ help_menu_description = \
|
||||
</br></br>如何语音对话: 请阅读Wiki
|
||||
</br></br>如何临时更换API_KEY: 在输入区输入临时API_KEY后提交(网页刷新后失效)"""
|
||||
|
||||
from loguru import logger
|
||||
def enable_log(PATH_LOGGING):
|
||||
import logging
|
||||
admin_log_path = os.path.join(PATH_LOGGING, "admin")
|
||||
os.makedirs(admin_log_path, exist_ok=True)
|
||||
log_dir = os.path.join(admin_log_path, "chat_secrets.log")
|
||||
try:logging.basicConfig(filename=log_dir, level=logging.INFO, encoding="utf-8", format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
||||
except:logging.basicConfig(filename=log_dir, level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
||||
# Disable logging output from the 'httpx' logger
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
print(f"所有对话记录将自动保存在本地目录{log_dir}, 请注意自我隐私保护哦!")
|
||||
from shared_utils.logging import setup_logging
|
||||
setup_logging(PATH_LOGGING)
|
||||
|
||||
def encode_plugin_info(k, plugin)->str:
|
||||
import copy
|
||||
@@ -42,9 +36,16 @@ def main():
|
||||
import gradio as gr
|
||||
if gr.__version__ not in ['3.32.9', '3.32.10', '3.32.11']:
|
||||
raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.")
|
||||
from request_llms.bridge_all import predict
|
||||
|
||||
# 一些基础工具
|
||||
from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
|
||||
|
||||
# 对话、日志记录
|
||||
enable_log(get_conf("PATH_LOGGING"))
|
||||
|
||||
# 对话句柄
|
||||
from request_llms.bridge_all import predict
|
||||
|
||||
# 读取配置
|
||||
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
|
||||
CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = get_conf('CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
|
||||
@@ -61,8 +62,6 @@ def main():
|
||||
from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
|
||||
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
|
||||
|
||||
# 对话、日志记录
|
||||
enable_log(PATH_LOGGING)
|
||||
|
||||
# 一些普通功能模块
|
||||
from core_functional import get_core_functions
|
||||
@@ -118,8 +117,8 @@ def main():
|
||||
choices=[
|
||||
"常规对话",
|
||||
"多模型对话",
|
||||
"智能召回 RAG",
|
||||
# "智能上下文",
|
||||
# "智能召回 RAG",
|
||||
], value="常规对话",
|
||||
interactive=True, label='', show_label=False,
|
||||
elem_classes='normal_mut_select', elem_id="gpt-submit-dropdown").style(container=False)
|
||||
@@ -339,9 +338,9 @@ def main():
|
||||
# Gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
|
||||
def run_delayed_tasks():
|
||||
import threading, webbrowser, time
|
||||
print(f"如果浏览器没有自动打开,请复制并转到以下URL:")
|
||||
if DARK_MODE: print(f"\t「暗色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||
else: print(f"\t「亮色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||
logger.info(f"如果浏览器没有自动打开,请复制并转到以下URL:")
|
||||
if DARK_MODE: logger.info(f"\t「暗色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||
else: logger.info(f"\t「亮色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||
|
||||
def auto_updates(): time.sleep(0); auto_update()
|
||||
def open_browser(): time.sleep(2); webbrowser.open_new_tab(f"http://localhost:{PORT}")
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
2. predict_no_ui_long_connection(...)
|
||||
"""
|
||||
import tiktoken, copy, re
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from toolbox import get_conf, trimmed_format_exc, apply_gpt_academic_string_mask, read_one_api_model_name
|
||||
@@ -51,9 +52,9 @@ class LazyloadTiktoken(object):
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=128)
|
||||
def get_encoder(model):
|
||||
print('正在加载tokenizer,如果是第一次运行,可能需要一点时间下载参数')
|
||||
logger.info('正在加载tokenizer,如果是第一次运行,可能需要一点时间下载参数')
|
||||
tmp = tiktoken.encoding_for_model(model)
|
||||
print('加载tokenizer完毕')
|
||||
logger.info('加载tokenizer完毕')
|
||||
return tmp
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
@@ -83,7 +84,7 @@ try:
|
||||
API_URL = get_conf("API_URL")
|
||||
if API_URL != "https://api.openai.com/v1/chat/completions":
|
||||
openai_endpoint = API_URL
|
||||
print("警告!API_URL配置选项将被弃用,请更换为API_URL_REDIRECT配置")
|
||||
logger.warning("警告!API_URL配置选项将被弃用,请更换为API_URL_REDIRECT配置")
|
||||
except:
|
||||
pass
|
||||
# 新版配置
|
||||
@@ -248,6 +249,27 @@ model_info = {
|
||||
"token_cnt": get_token_num_gpt4,
|
||||
},
|
||||
|
||||
"o1-preview": {
|
||||
"fn_with_ui": chatgpt_ui,
|
||||
"fn_without_ui": chatgpt_noui,
|
||||
"endpoint": openai_endpoint,
|
||||
"max_token": 128000,
|
||||
"tokenizer": tokenizer_gpt4,
|
||||
"token_cnt": get_token_num_gpt4,
|
||||
"openai_disable_system_prompt": True,
|
||||
"openai_disable_stream": True,
|
||||
},
|
||||
"o1-mini": {
|
||||
"fn_with_ui": chatgpt_ui,
|
||||
"fn_without_ui": chatgpt_noui,
|
||||
"endpoint": openai_endpoint,
|
||||
"max_token": 128000,
|
||||
"tokenizer": tokenizer_gpt4,
|
||||
"token_cnt": get_token_num_gpt4,
|
||||
"openai_disable_system_prompt": True,
|
||||
"openai_disable_stream": True,
|
||||
},
|
||||
|
||||
"gpt-4-turbo": {
|
||||
"fn_with_ui": chatgpt_ui,
|
||||
"fn_without_ui": chatgpt_noui,
|
||||
@@ -363,6 +385,14 @@ model_info = {
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"glm-4-plus":{
|
||||
"fn_with_ui": zhipu_ui,
|
||||
"fn_without_ui": zhipu_noui,
|
||||
"endpoint": None,
|
||||
"max_token": 10124 * 8,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
|
||||
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
|
||||
"api2d-gpt-4": {
|
||||
@@ -407,22 +437,46 @@ model_info = {
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
# Gemini
|
||||
# Note: now gemini-pro is an alias of gemini-1.0-pro.
|
||||
# Warning: gemini-pro-vision has been deprecated.
|
||||
# Support for gemini-pro-vision has been removed.
|
||||
"gemini-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": False,
|
||||
"max_token": 1024 * 32,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-pro-vision": {
|
||||
"gemini-1.0-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": False,
|
||||
"max_token": 1024 * 32,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-1.5-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": True,
|
||||
"max_token": 1024 * 204800,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-1.5-flash": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": True,
|
||||
"max_token": 1024 * 204800,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
|
||||
# cohere
|
||||
"cohere-command-r-plus": {
|
||||
@@ -638,7 +692,7 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||
try:
|
||||
from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
|
||||
@@ -654,7 +708,7 @@ if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 上海AI-LAB书生大模型 -=-=-=-=-=-=-
|
||||
if "internlm" in AVAIL_LLM_MODELS:
|
||||
try:
|
||||
@@ -671,7 +725,7 @@ if "internlm" in AVAIL_LLM_MODELS:
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
if "chatglm_onnx" in AVAIL_LLM_MODELS:
|
||||
try:
|
||||
from .bridge_chatglmonnx import predict_no_ui_long_connection as chatglm_onnx_noui
|
||||
@@ -687,7 +741,7 @@ if "chatglm_onnx" in AVAIL_LLM_MODELS:
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 通义-本地模型 -=-=-=-=-=-=-
|
||||
if "qwen-local" in AVAIL_LLM_MODELS:
|
||||
try:
|
||||
@@ -705,7 +759,7 @@ if "qwen-local" in AVAIL_LLM_MODELS:
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 通义-在线模型 -=-=-=-=-=-=-
|
||||
if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-max" in AVAIL_LLM_MODELS: # zhipuai
|
||||
try:
|
||||
@@ -741,7 +795,7 @@ if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 零一万物模型 -=-=-=-=-=-=-
|
||||
yi_models = ["yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview"]
|
||||
if any(item in yi_models for item in AVAIL_LLM_MODELS):
|
||||
@@ -821,7 +875,7 @@ if any(item in yi_models for item in AVAIL_LLM_MODELS):
|
||||
},
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 讯飞星火认知大模型 -=-=-=-=-=-=-
|
||||
if "spark" in AVAIL_LLM_MODELS:
|
||||
try:
|
||||
@@ -839,7 +893,7 @@ if "spark" in AVAIL_LLM_MODELS:
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||
try:
|
||||
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
||||
@@ -856,8 +910,8 @@ if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||
logger.error(trimmed_format_exc())
|
||||
if any(x in AVAIL_LLM_MODELS for x in ("sparkv3", "sparkv3.5", "sparkv4")): # 讯飞星火认知大模型
|
||||
try:
|
||||
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
||||
from .bridge_spark import predict as spark_ui
|
||||
@@ -891,7 +945,7 @@ if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
if "llama2" in AVAIL_LLM_MODELS: # llama2
|
||||
try:
|
||||
from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui
|
||||
@@ -907,7 +961,7 @@ if "llama2" in AVAIL_LLM_MODELS: # llama2
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 智谱 -=-=-=-=-=-=-
|
||||
if "zhipuai" in AVAIL_LLM_MODELS: # zhipuai 是glm-4的别名,向后兼容配置
|
||||
try:
|
||||
@@ -922,7 +976,7 @@ if "zhipuai" in AVAIL_LLM_MODELS: # zhipuai 是glm-4的别名,向后兼容
|
||||
},
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 幻方-深度求索大模型 -=-=-=-=-=-=-
|
||||
if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
|
||||
try:
|
||||
@@ -939,7 +993,7 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
|
||||
}
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- 幻方-深度求索大模型在线API -=-=-=-=-=-=-
|
||||
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
|
||||
try:
|
||||
@@ -967,7 +1021,7 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
|
||||
},
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
logger.error(trimmed_format_exc())
|
||||
# -=-=-=-=-=-=- one-api 对齐支持 -=-=-=-=-=-=-
|
||||
for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
|
||||
# 为了更灵活地接入one-api多模型管理界面,设计了此接口,例子:AVAIL_LLM_MODELS = ["one-api-mixtral-8x7b(max_token=6666)"]
|
||||
@@ -980,7 +1034,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
|
||||
# 如果是已知模型,则尝试获取其信息
|
||||
original_model_info = model_info.get(origin_model_name.replace("one-api-", "", 1), None)
|
||||
except:
|
||||
print(f"one-api模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
logger.error(f"one-api模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
continue
|
||||
this_model_info = {
|
||||
"fn_with_ui": chatgpt_ui,
|
||||
@@ -1011,7 +1065,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("vllm-")]:
|
||||
try:
|
||||
_, max_token_tmp = read_one_api_model_name(model)
|
||||
except:
|
||||
print(f"vllm模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
logger.error(f"vllm模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
continue
|
||||
model_info.update({
|
||||
model: {
|
||||
@@ -1038,7 +1092,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("ollama-")]:
|
||||
try:
|
||||
_, max_token_tmp = read_one_api_model_name(model)
|
||||
except:
|
||||
print(f"ollama模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
logger.error(f"ollama模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||
continue
|
||||
model_info.update({
|
||||
model: {
|
||||
@@ -1074,6 +1128,24 @@ if len(AZURE_CFG_ARRAY) > 0:
|
||||
if azure_model_name not in AVAIL_LLM_MODELS:
|
||||
AVAIL_LLM_MODELS += [azure_model_name]
|
||||
|
||||
# -=-=-=-=-=-=- Openrouter模型对齐支持 -=-=-=-=-=-=-
|
||||
# 为了更灵活地接入Openrouter路由,设计了此接口
|
||||
for model in [m for m in AVAIL_LLM_MODELS if m.startswith("openrouter-")]:
|
||||
from request_llms.bridge_openrouter import predict_no_ui_long_connection as openrouter_noui
|
||||
from request_llms.bridge_openrouter import predict as openrouter_ui
|
||||
model_info.update({
|
||||
model: {
|
||||
"fn_with_ui": openrouter_ui,
|
||||
"fn_without_ui": openrouter_noui,
|
||||
# 以下参数参考gpt-4o-mini的配置, 请根据实际情况修改
|
||||
"endpoint": openai_endpoint,
|
||||
"has_multimodal_capacity": True,
|
||||
"max_token": 128000,
|
||||
"tokenizer": tokenizer_gpt4,
|
||||
"token_cnt": get_token_num_gpt4,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
|
||||
# -=-=-=-=-=-=-=-=-=- ☝️ 以上是模型路由 -=-=-=-=-=-=-=-=-=
|
||||
@@ -1219,5 +1291,5 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot,
|
||||
if additional_fn: # 根据基础功能区 ModelOverride 参数调整模型类型
|
||||
llm_kwargs, additional_fn, method = execute_model_override(llm_kwargs, additional_fn, method)
|
||||
|
||||
# 更新一下llm_kwargs的参数,否则会出现参数不匹配的问题
|
||||
yield from method(inputs, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, stream, additional_fn)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class GetGLM3Handle(LocalLLMHandle):
|
||||
|
||||
def load_model_and_tokenizer(self):
|
||||
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
|
||||
import os, glob
|
||||
import os
|
||||
import platform
|
||||
@@ -45,15 +45,13 @@ class GetGLM3Handle(LocalLLMHandle):
|
||||
chatglm_model = AutoModel.from_pretrained(
|
||||
pretrained_model_name_or_path=_model_name_,
|
||||
trust_remote_code=True,
|
||||
device="cuda",
|
||||
load_in_4bit=True,
|
||||
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
||||
)
|
||||
elif LOCAL_MODEL_QUANT == "INT8": # INT8
|
||||
chatglm_model = AutoModel.from_pretrained(
|
||||
pretrained_model_name_or_path=_model_name_,
|
||||
trust_remote_code=True,
|
||||
device="cuda",
|
||||
load_in_8bit=True,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
else:
|
||||
chatglm_model = AutoModel.from_pretrained(
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from loguru import logger
|
||||
from toolbox import update_ui, get_conf
|
||||
from multiprocessing import Process, Pipe
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
import importlib
|
||||
from toolbox import update_ui, get_conf
|
||||
from multiprocessing import Process, Pipe
|
||||
|
||||
load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"
|
||||
|
||||
@@ -78,7 +79,7 @@ class GetGLMFTHandle(Process):
|
||||
config.pre_seq_len = model_args['pre_seq_len']
|
||||
config.prefix_projection = model_args['prefix_projection']
|
||||
|
||||
print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
|
||||
logger.info(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
|
||||
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
|
||||
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
|
||||
new_prefix_state_dict = {}
|
||||
@@ -88,7 +89,7 @@ class GetGLMFTHandle(Process):
|
||||
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||
|
||||
if model_args['quantization_bit'] is not None and model_args['quantization_bit'] != 0:
|
||||
print(f"Quantized to {model_args['quantization_bit']} bit")
|
||||
logger.info(f"Quantized to {model_args['quantization_bit']} bit")
|
||||
model = model.quantize(model_args['quantization_bit'])
|
||||
model = model.cuda()
|
||||
if model_args['pre_seq_len'] is not None:
|
||||
|
||||
@@ -12,11 +12,12 @@ import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
import traceback
|
||||
import requests
|
||||
import random
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# config_private.py放自己的秘密如API和代理网址
|
||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
|
||||
@@ -133,21 +134,32 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
||||
observe_window = None:
|
||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
|
||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
||||
|
||||
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||
else: stream = True
|
||||
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=False
|
||||
from .bridge_all import model_info
|
||||
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
||||
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
|
||||
except requests.exceptions.ReadTimeout as e:
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
|
||||
if not stream:
|
||||
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||
chunkjson = json.loads(response.content.decode())
|
||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||
return gpt_replying_buffer
|
||||
|
||||
stream_response = response.iter_lines()
|
||||
result = ''
|
||||
@@ -190,10 +202,13 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
||||
if (time.time()-observe_window[1]) > watch_dog_patience:
|
||||
raise RuntimeError("用户取消了程序。")
|
||||
else: raise RuntimeError("意外Json结构:"+delta)
|
||||
if json_data and json_data['finish_reason'] == 'content_filter':
|
||||
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
|
||||
if json_data and json_data['finish_reason'] == 'length':
|
||||
|
||||
finish_reason = json_data.get('finish_reason', None) if json_data else None
|
||||
if finish_reason == 'content_filter':
|
||||
raise RuntimeError("由于提问含不合规内容被过滤。")
|
||||
if finish_reason == 'length':
|
||||
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -208,7 +223,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||
"""
|
||||
from .bridge_all import model_info
|
||||
from request_llms.bridge_all import model_info
|
||||
if is_any_api_key(inputs):
|
||||
chatbot._cookies['api_key'] = inputs
|
||||
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
||||
@@ -237,6 +252,10 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chatbot.append((_inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||
|
||||
# 禁用stream的特殊模型处理
|
||||
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||
else: stream = True
|
||||
|
||||
# check mis-behavior
|
||||
if is_the_upload_folder(user_input):
|
||||
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
||||
@@ -270,7 +289,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=True
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
||||
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
|
||||
except:
|
||||
retry += 1
|
||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||
@@ -278,10 +297,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
|
||||
gpt_replying_buffer = ""
|
||||
|
||||
is_head_of_the_stream = True
|
||||
if not stream:
|
||||
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
|
||||
return
|
||||
|
||||
if stream:
|
||||
gpt_replying_buffer = ""
|
||||
is_head_of_the_stream = True
|
||||
stream_response = response.iter_lines()
|
||||
while True:
|
||||
try:
|
||||
@@ -317,7 +341,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||
# logging.info(f'[response] {gpt_replying_buffer}')
|
||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||
break
|
||||
# 处理数据流的主体
|
||||
@@ -343,12 +366,24 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||
print(error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
|
||||
logger.error(error_msg)
|
||||
return
|
||||
return # return from stream-branch
|
||||
|
||||
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
||||
try:
|
||||
chunkjson = json.loads(response.content.decode())
|
||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||
history[-1] = gpt_replying_buffer
|
||||
chatbot[-1] = (history[-2], history[-1])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
except Exception as e:
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text) # 刷新界面
|
||||
|
||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||
from .bridge_all import model_info
|
||||
from request_llms.bridge_all import model_info
|
||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||
if "reduce the length" in error_msg:
|
||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||
@@ -381,6 +416,8 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
"""
|
||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
|
||||
if not is_any_api_key(llm_kwargs['api_key']):
|
||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||
|
||||
@@ -409,10 +446,16 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
else:
|
||||
enable_multimodal_capacity = False
|
||||
|
||||
conversation_cnt = len(history) // 2
|
||||
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)
|
||||
|
||||
if openai_disable_system_prompt:
|
||||
messages = [{"role": "user", "content": system_prompt}]
|
||||
else:
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
if not enable_multimodal_capacity:
|
||||
# 不使用多模态能力
|
||||
conversation_cnt = len(history) // 2
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
@@ -434,8 +477,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
messages.append(what_i_ask_now)
|
||||
else:
|
||||
# 多模态能力
|
||||
conversation_cnt = len(history) // 2
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
@@ -486,7 +527,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
])
|
||||
logging.info("Random select model:" + model)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
@@ -496,10 +536,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
"n": 1,
|
||||
"stream": stream,
|
||||
}
|
||||
# try:
|
||||
# print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
|
||||
# except:
|
||||
# print('输入中可能存在乱码。')
|
||||
|
||||
return headers,payload
|
||||
|
||||
|
||||
|
||||
@@ -8,15 +8,15 @@
|
||||
2. predict_no_ui_long_connection:支持多线程
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
import base64
|
||||
import os
|
||||
import glob
|
||||
from loguru import logger
|
||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, \
|
||||
update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files
|
||||
update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files, log_chat
|
||||
|
||||
|
||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
|
||||
@@ -100,7 +100,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||
|
||||
raw_input = inputs
|
||||
logging.info(f'[raw_input] {raw_input}')
|
||||
def make_media_input(inputs, image_paths):
|
||||
for image_path in image_paths:
|
||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
||||
@@ -185,7 +184,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||
lastmsg = chatbot[-1][-1] + f"\n\n\n\n「{llm_kwargs['llm_model']}调用结束,该模型不具备上下文对话能力,如需追问,请及时切换模型。」"
|
||||
yield from update_ui_lastest_msg(lastmsg, chatbot, history, delay=1)
|
||||
logging.info(f'[response] {gpt_replying_buffer}')
|
||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||
break
|
||||
# 处理数据流的主体
|
||||
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||
@@ -210,7 +209,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||
print(error_msg)
|
||||
logger.error(error_msg)
|
||||
return
|
||||
|
||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key=""):
|
||||
@@ -301,10 +300,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths):
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
}
|
||||
try:
|
||||
print(f" {llm_kwargs['llm_model']} : {inputs[:100]} ..........")
|
||||
except:
|
||||
print('输入中可能存在乱码。')
|
||||
|
||||
return headers, payload, api_key
|
||||
|
||||
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
# 借鉴了 https://github.com/GaiZhenbiao/ChuanhuChatGPT 项目
|
||||
|
||||
"""
|
||||
该文件中主要包含三个函数
|
||||
|
||||
不具备多线程能力的函数:
|
||||
1. predict: 正常对话时使用,具备完备的交互功能,不可多线程
|
||||
|
||||
具备多线程调用能力的函数
|
||||
2. predict_no_ui_long_connection:支持多线程
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import gradio as gr
|
||||
import logging
|
||||
import traceback
|
||||
import requests
|
||||
import importlib
|
||||
|
||||
# config_private.py放自己的秘密如API和代理网址
|
||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc
|
||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG = \
|
||||
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG')
|
||||
|
||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||
|
||||
def get_full_error(chunk, stream_response):
|
||||
"""
|
||||
获取完整的从Openai返回的报错
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
chunk += next(stream_response)
|
||||
except:
|
||||
break
|
||||
return chunk
|
||||
|
||||
|
||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
|
||||
"""
|
||||
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
|
||||
inputs:
|
||||
是本次问询的输入
|
||||
sys_prompt:
|
||||
系统静默prompt
|
||||
llm_kwargs:
|
||||
chatGPT的内部调优参数
|
||||
history:
|
||||
是之前的对话列表
|
||||
observe_window = None:
|
||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||
"""
|
||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=False
|
||||
from .bridge_all import model_info
|
||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
||||
except requests.exceptions.ReadTimeout as e:
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
|
||||
stream_response = response.iter_lines()
|
||||
result = ''
|
||||
while True:
|
||||
try: chunk = next(stream_response).decode()
|
||||
except StopIteration:
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
chunk = next(stream_response).decode() # 失败了,重试一次?再失败就没办法了。
|
||||
if len(chunk)==0: continue
|
||||
if not chunk.startswith('data:'):
|
||||
error_msg = get_full_error(chunk.encode('utf8'), stream_response).decode()
|
||||
if "reduce the length" in error_msg:
|
||||
raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
|
||||
else:
|
||||
raise RuntimeError("OpenAI拒绝了请求:" + error_msg)
|
||||
if ('data: [DONE]' in chunk): break # api2d 正常完成
|
||||
json_data = json.loads(chunk.lstrip('data:'))['choices'][0]
|
||||
delta = json_data["delta"]
|
||||
if len(delta) == 0: break
|
||||
if "role" in delta: continue
|
||||
if "content" in delta:
|
||||
result += delta["content"]
|
||||
if not console_slience: print(delta["content"], end='')
|
||||
if observe_window is not None:
|
||||
# 观测窗,把已经获取的数据显示出去
|
||||
if len(observe_window) >= 1: observe_window[0] += delta["content"]
|
||||
# 看门狗,如果超过期限没有喂狗,则终止
|
||||
if len(observe_window) >= 2:
|
||||
if (time.time()-observe_window[1]) > watch_dog_patience:
|
||||
raise RuntimeError("用户取消了程序。")
|
||||
else: raise RuntimeError("意外Json结构:"+delta)
|
||||
if json_data['finish_reason'] == 'content_filter':
|
||||
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
|
||||
if json_data['finish_reason'] == 'length':
|
||||
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
||||
return result
|
||||
|
||||
|
||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
|
||||
"""
|
||||
发送至chatGPT,流式获取输出。
|
||||
用于基础的对话功能。
|
||||
inputs 是本次问询的输入
|
||||
top_p, temperature是chatGPT的内部调优参数
|
||||
history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
|
||||
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||
"""
|
||||
if additional_fn is not None:
|
||||
from core_functional import handle_core_functionality
|
||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||
|
||||
raw_input = inputs
|
||||
logging.info(f'[raw_input] {raw_input}')
|
||||
chatbot.append((inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||
|
||||
try:
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
|
||||
except RuntimeError as e:
|
||||
chatbot[-1] = (inputs, f"您提供的api-key不满足要求,不包含任何可用于{llm_kwargs['llm_model']}的api-key。您可能选择了错误的模型或请求源。")
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
||||
return
|
||||
|
||||
history.append(inputs); history.append("")
|
||||
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=True
|
||||
from .bridge_all import model_info
|
||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
||||
except:
|
||||
retry += 1
|
||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
|
||||
gpt_replying_buffer = ""
|
||||
|
||||
is_head_of_the_stream = True
|
||||
if stream:
|
||||
stream_response = response.iter_lines()
|
||||
while True:
|
||||
try:
|
||||
chunk = next(stream_response)
|
||||
except StopIteration:
|
||||
# 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="非Openai官方接口返回了错误:" + chunk.decode()) # 刷新界面
|
||||
return
|
||||
|
||||
# print(chunk.decode()[6:])
|
||||
if is_head_of_the_stream and (r'"object":"error"' not in chunk.decode()):
|
||||
# 数据流的第一帧不携带content
|
||||
is_head_of_the_stream = False; continue
|
||||
|
||||
if chunk:
|
||||
try:
|
||||
chunk_decoded = chunk.decode()
|
||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||
if 'data: [DONE]' in chunk_decoded:
|
||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||
logging.info(f'[response] {gpt_replying_buffer}')
|
||||
break
|
||||
# 处理数据流的主体
|
||||
chunkjson = json.loads(chunk_decoded[6:])
|
||||
status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}"
|
||||
delta = chunkjson['choices'][0]["delta"]
|
||||
if "content" in delta:
|
||||
gpt_replying_buffer = gpt_replying_buffer + delta["content"]
|
||||
history[-1] = gpt_replying_buffer
|
||||
chatbot[-1] = (history[-2], history[-1])
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
|
||||
except Exception as e:
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # 刷新界面
|
||||
chunk = get_full_error(chunk, stream_response)
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||
print(error_msg)
|
||||
return
|
||||
|
||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||
from .bridge_all import model_info
|
||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||
if "reduce the length" in error_msg:
|
||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||
history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
|
||||
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
|
||||
# history = [] # 清除历史
|
||||
elif "does not exist" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
|
||||
elif "Incorrect API key" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务. " + openai_website)
|
||||
elif "exceeded your current quota" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务." + openai_website)
|
||||
elif "account is not active" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Your account is not active. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
||||
elif "associated with a deactivated account" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] You are associated with a deactivated account. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
||||
elif "bad forward key" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
|
||||
elif "Not enough point" in error_msg:
|
||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Not enough point. API2D账户点数不足.")
|
||||
else:
|
||||
from toolbox import regular_txt_to_markdown
|
||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
|
||||
return chatbot, history
|
||||
|
||||
def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
|
||||
"""
|
||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||
"""
|
||||
if not is_any_api_key(llm_kwargs['api_key']):
|
||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
conversation_cnt = len(history) // 2
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
what_i_have_asked["role"] = "user"
|
||||
what_i_have_asked["content"] = history[index]
|
||||
what_gpt_answer = {}
|
||||
what_gpt_answer["role"] = "assistant"
|
||||
what_gpt_answer["content"] = history[index+1]
|
||||
if what_i_have_asked["content"] != "":
|
||||
if what_gpt_answer["content"] == "": continue
|
||||
if what_gpt_answer["content"] == timeout_bot_msg: continue
|
||||
messages.append(what_i_have_asked)
|
||||
messages.append(what_gpt_answer)
|
||||
else:
|
||||
messages[-1]['content'] = what_gpt_answer['content']
|
||||
|
||||
what_i_ask_now = {}
|
||||
what_i_ask_now["role"] = "user"
|
||||
what_i_ask_now["content"] = inputs
|
||||
messages.append(what_i_ask_now)
|
||||
|
||||
payload = {
|
||||
"model": llm_kwargs['llm_model'].strip('api2d-'),
|
||||
"messages": messages,
|
||||
"temperature": llm_kwargs['temperature'], # 1.0,
|
||||
"top_p": llm_kwargs['top_p'], # 1.0,
|
||||
"n": 1,
|
||||
"stream": stream,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 0,
|
||||
}
|
||||
try:
|
||||
print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
|
||||
except:
|
||||
print('输入中可能存在乱码。')
|
||||
return headers,payload
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@
|
||||
具备多线程调用能力的函数
|
||||
2. predict_no_ui_long_connection:支持多线程
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import json
|
||||
import requests
|
||||
from loguru import logger
|
||||
from toolbox import get_conf, update_ui, trimmed_format_exc, encode_image, every_image_file_in_path, log_chat
|
||||
|
||||
picture_system_prompt = "\n当回复图像时,必须说明正在回复哪张图像。所有图像仅在最后一个问题中提供,即使它们在历史记录中被提及。请使用'这是第X张图像:'的格式来指明您正在描述的是哪张图像。"
|
||||
Claude_3_Models = ["claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620"]
|
||||
|
||||
@@ -101,7 +102,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
stream_response = response.iter_lines()
|
||||
result = ''
|
||||
while True:
|
||||
@@ -116,12 +117,11 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
||||
if need_to_pass:
|
||||
pass
|
||||
elif is_last_chunk:
|
||||
# logging.info(f'[response] {result}')
|
||||
# logger.info(f'[response] {result}')
|
||||
break
|
||||
else:
|
||||
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
||||
result += chunkjson['delta']['text']
|
||||
print(chunkjson['delta']['text'], end='')
|
||||
if observe_window is not None:
|
||||
# 观测窗,把已经获取的数据显示出去
|
||||
if len(observe_window) >= 1:
|
||||
@@ -134,7 +134,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
||||
chunk = get_full_error(chunk, stream_response)
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
print(error_msg)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError("Json解析不合常规")
|
||||
|
||||
return result
|
||||
@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
stream_response = response.iter_lines()
|
||||
gpt_replying_buffer = ""
|
||||
|
||||
@@ -217,7 +217,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
pass
|
||||
elif is_last_chunk:
|
||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||
# logging.info(f'[response] {gpt_replying_buffer}')
|
||||
# logger.info(f'[response] {gpt_replying_buffer}')
|
||||
break
|
||||
else:
|
||||
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
||||
@@ -230,7 +230,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
||||
chunk = get_full_error(chunk, stream_response)
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
print(error_msg)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError("Json解析不合常规")
|
||||
|
||||
def multiple_picture_types(image_paths):
|
||||
|
||||
@@ -13,11 +13,9 @@
|
||||
import json
|
||||
import time
|
||||
import gradio as gr
|
||||
import logging
|
||||
import traceback
|
||||
import requests
|
||||
import importlib
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
# config_private.py放自己的秘密如API和代理网址
|
||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||
@@ -98,7 +96,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
|
||||
stream_response = response.iter_lines()
|
||||
result = ''
|
||||
@@ -153,7 +151,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||
|
||||
raw_input = inputs
|
||||
# logging.info(f'[raw_input] {raw_input}')
|
||||
# logger.info(f'[raw_input] {raw_input}')
|
||||
chatbot.append((inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||
|
||||
@@ -237,7 +235,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||
print(error_msg)
|
||||
logger.error(error_msg)
|
||||
return
|
||||
|
||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
model_name = "deepseek-coder-6.7b-instruct"
|
||||
cmd_to_install = "未知" # "`pip install -r request_llms/requirements_qwen.txt`"
|
||||
|
||||
import os
|
||||
from toolbox import ProxyNetworkActivate
|
||||
from toolbox import get_conf
|
||||
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
|
||||
from request_llms.local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
|
||||
from threading import Thread
|
||||
from loguru import logger
|
||||
import torch
|
||||
import os
|
||||
|
||||
def download_huggingface_model(model_name, max_retry, local_dir):
|
||||
from huggingface_hub import snapshot_download
|
||||
@@ -15,7 +16,7 @@ def download_huggingface_model(model_name, max_retry, local_dir):
|
||||
snapshot_download(repo_id=model_name, local_dir=local_dir, resume_download=True)
|
||||
break
|
||||
except Exception as e:
|
||||
print(f'\n\n下载失败,重试第{i}次中...\n\n')
|
||||
logger.error(f'\n\n下载失败,重试第{i}次中...\n\n')
|
||||
return local_dir
|
||||
# ------------------------------------------------------------------------------------------------------------------------
|
||||
# 🔌💻 Local Model
|
||||
@@ -112,7 +113,6 @@ class GetCoderLMHandle(LocalLLMHandle):
|
||||
generated_text = ""
|
||||
for new_text in self._streamer:
|
||||
generated_text += new_text
|
||||
# print(generated_text)
|
||||
yield generated_text
|
||||
|
||||
|
||||
|
||||
@@ -8,15 +8,15 @@ import os
|
||||
import time
|
||||
from request_llms.com_google import GoogleChatInit
|
||||
from toolbox import ChatBotWithCookies
|
||||
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat
|
||||
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat, encode_image
|
||||
|
||||
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||
|
||||
|
||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
|
||||
console_slience=False):
|
||||
def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], sys_prompt:str="", observe_window:list=[],
|
||||
console_slience:bool=False):
|
||||
# 检查API_KEY
|
||||
if get_conf("GEMINI_API_KEY") == "":
|
||||
raise ValueError(f"请配置 GEMINI_API_KEY。")
|
||||
@@ -44,9 +44,20 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
||||
raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
|
||||
return gpt_replying_buffer
|
||||
|
||||
def make_media_input(inputs, image_paths):
|
||||
image_base64_array = []
|
||||
for image_path in image_paths:
|
||||
path = os.path.abspath(image_path)
|
||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={path}"></div>'
|
||||
base64 = encode_image(path)
|
||||
image_base64_array.append(base64)
|
||||
return inputs, image_base64_array
|
||||
|
||||
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
|
||||
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
|
||||
|
||||
from .bridge_all import model_info
|
||||
|
||||
# 检查API_KEY
|
||||
if get_conf("GEMINI_API_KEY") == "":
|
||||
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
|
||||
@@ -57,18 +68,17 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
from core_functional import handle_core_functionality
|
||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||
|
||||
if "vision" in llm_kwargs["llm_model"]:
|
||||
have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
|
||||
if not have_recent_file:
|
||||
chatbot.append((inputs, "没有检测到任何近期上传的图像文件,请上传jpg格式的图片,此外,请注意拓展名需要小写"))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待图片") # 刷新界面
|
||||
return
|
||||
def make_media_input(inputs, image_paths):
|
||||
for image_path in image_paths:
|
||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
||||
return inputs
|
||||
if have_recent_file:
|
||||
inputs = make_media_input(inputs, image_paths)
|
||||
# multimodal capacity
|
||||
# inspired by codes in bridge_chatgpt
|
||||
has_multimodal_capacity = model_info[llm_kwargs['llm_model']].get('has_multimodal_capacity', False)
|
||||
if has_multimodal_capacity:
|
||||
has_recent_image_upload, image_paths = have_any_recent_upload_image_files(chatbot, pop=True)
|
||||
else:
|
||||
has_recent_image_upload, image_paths = False, []
|
||||
if has_recent_image_upload:
|
||||
inputs, image_base64_array = make_media_input(inputs, image_paths)
|
||||
else:
|
||||
inputs, image_base64_array = inputs, []
|
||||
|
||||
chatbot.append((inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
@@ -76,7 +86,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
|
||||
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity)
|
||||
break
|
||||
except Exception as e:
|
||||
retry += 1
|
||||
@@ -112,7 +122,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
llm_kwargs = {'llm_model': 'gemini-pro'}
|
||||
|
||||
@@ -65,10 +65,10 @@ class GetInternlmHandle(LocalLLMHandle):
|
||||
|
||||
def llm_stream_generator(self, **kwargs):
|
||||
import torch
|
||||
import logging
|
||||
import copy
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
from loguru import logger as logging
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
||||
|
||||
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
||||
@@ -119,7 +119,7 @@ class GetInternlmHandle(LocalLLMHandle):
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
|
||||
from toolbox import get_conf, update_ui, log_chat
|
||||
from toolbox import ChatBotWithCookies
|
||||
|
||||
某些文件未显示,因为此 diff 中更改的文件太多 显示更多
在新工单中引用
屏蔽一个用户