比较提交

...

79 次代码提交

作者 SHA1 备注 提交日期
binary-husky
d94a571eb5 block unstable 2025-08-23 15:59:05 +08:00
XiaoBoAI
91f28c2721 back 2025-07-21 02:18:46 +08:00
XiaoBoAI
6813ba88bb feat: 为predict函数添加文件和URL读取功能
- 添加URL检测和网页内容提取功能,支持自动提取网页文本
- 添加文件路径识别和文件内容读取功能,支持private_upload路径格式
- 集成WebTextExtractor处理网页内容提取
- 集成TextContentLoader处理本地文件读取
- 支持文件路径与问题组合的智能处理
2025-07-19 19:06:38 +08:00
binary-husky
fb5189fd96 accelerate nltk 2025-07-16 00:57:45 +08:00
binary-husky
9945340277 merge more academic plugins 2025-07-14 02:20:20 +08:00
binary-husky
55607cbe8b file dynamic load 2025-07-13 02:51:14 +08:00
binary-husky
a49085088c Merge branch 'master' into master-4.0 2025-07-13 01:30:16 +08:00
binary-husky
a7a56b5058 fix buggy gradio version 2025-06-25 01:34:33 +08:00
binary-husky
8e0332c71b fix buggy gradio version 2025-06-25 01:34:07 +08:00
binary-husky
90d1b34f5e stage document conversation 2025-06-23 01:12:04 +08:00
binary-husky
73f573092b stage academic conversation 2025-06-22 18:31:41 +08:00
binary-husky
8c21432291 use uv to build dockerfile 2025-06-04 02:24:09 +08:00
binary-husky
87b3f79ae9 setup nv 2025-06-04 01:53:29 +08:00
binary-husky
f42aad5093 implement doc_fns 2025-06-04 00:20:09 +08:00
binary-husky
725f60fba3 add context clip policy 2025-06-03 01:05:37 +08:00
binary-husky
be83907394 Merge branch 'master' of github.com:binary-husky/chatgpt_academic 2025-05-06 22:17:34 +08:00
binary-husky
eba48a0f1a improve reset conversation ui 2025-05-06 22:10:21 +08:00
binary-husky
ee1a9e7cce support qwen3 models - edit config hint 2025-04-29 11:10:49 +08:00
binary-husky
fc06be6f7a support qwen3 models 2025-04-29 11:09:53 +08:00
binary-husky
883b513b91 add can_multi_thread 2025-04-21 00:50:24 +08:00
binary-husky
24cebaf4ec add o3 and o4 models 2025-04-21 00:48:59 +08:00
binary-husky
858b5f69b0 add in-text stop btn 2025-04-15 01:08:54 +08:00
davidfir3
63c61e6204 添加gemini-2.0-flash (#2180)
Co-authored-by: 柯仕锋 <12029064@zju.edu.cn>
2025-03-25 00:13:18 +08:00
BZfei
82aac97980 阿里云百炼(原灵积)增加对deepseek-r1、deepseek-v3模型支持 (#2182)
* 阿里云百炼(原灵积)增加对deepseek-r1、deepseek-v3模型支持

* update reasoning display

---------

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
2025-03-25 00:11:55 +08:00
binary-husky
045cdb15d8 ensure display none even if css load fails 2025-03-10 23:44:47 +08:00
binary-husky
e78e8b0909 allow copy original text instead of renderend text 2025-03-09 00:04:27 +08:00
binary-husky
07974a26d0 Merge branch 'master' of github.com:binary-husky/chatgpt_academic 2025-03-08 23:10:42 +08:00
binary-husky
3e56c074cc fix gui_toolbar 2025-03-08 23:09:22 +08:00
littleolaf
72dbe856d2 添加接入 火山引擎在线大模型 内容的支持 (#2165)
* use oai adaptive bridge function to handle vol engine

* add vol engine deepseek v3

---------

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
2025-03-04 23:58:03 +08:00
Steven Moder
4a79aa6a93 typo: Fix typos and rename functions across multiple files (#2130)
* typo: Fix typos and rename functions across multiple files

This commit addresses several minor issues:
- Corrected spelling of function names (e.g., `update_ui_lastest_msg` to `update_ui_latest_msg`)
- Fixed typos in comments and variable names
- Corrected capitalization in some strings (e.g., "ArXiv" instead of "Arixv")
- Renamed some variables for consistency
- Corrected some console-related parameter names (e.g., `console_slience` to `console_silence`)

The changes span multiple files across the project, including request LLM bridges, crazy functions, and utility modules.

* fix: f-string expression part cannot include a backslash (#2139)

* raise error when the uploaded tar contain hard/soft link (#2136)

* minor bug fix

* fine tune reasoning css

* upgrade internet gpt plugin

* Update README.md

* fix GHSA-gqp5-wm97-qxcv

* typo fix

* update readme

---------

Co-authored-by: binary-husky <96192199+binary-husky@users.noreply.github.com>
Co-authored-by: binary-husky <qingxu.fu@outlook.com>
2025-03-02 02:16:10 +08:00
binary-husky
5dffe8627f fix GHSA-gqp5-wm97-qxcv 2025-03-02 01:58:45 +08:00
binary-husky
2aefef26db Update README.md 2025-02-21 19:51:09 +08:00
binary-husky
957da731db upgrade internet gpt plugin 2025-02-13 00:19:43 +08:00
binary-husky
add29eba08 fine tune reasoning css 2025-02-09 20:26:52 +08:00
binary-husky
163e59c0f3 minor bug fix 2025-02-09 19:33:02 +08:00
binary-husky
07ece29c7c raise error when the uploaded tar contain hard/soft link (#2136) 2025-02-08 20:54:01 +08:00
Steven Moder
991a903fa9 fix: f-string expression part cannot include a backslash (#2139) 2025-02-08 20:50:54 +08:00
Steven Moder
cf7c81170c fix: return 参数数量 及 返回类型考虑 (#2129) 2025-02-07 21:33:06 +08:00
barry
6dda2061dd Update bridge_openrouter.py (#2132)
fix openrouter api 400 post bug

Co-authored-by: lan <56376794+lostatnight@users.noreply.github.com>
2025-02-07 21:28:05 +08:00
binary-husky
8a0d96afd3 consider element missing cases in js 2025-02-07 01:21:21 +08:00
binary-husky
37f9b94dee add options to hide ui components 2025-02-07 00:17:36 +08:00
binary-husky
936e2f5206 update readme 2025-02-04 16:15:56 +08:00
binary-husky
7f4b87a633 update readme 2025-02-04 16:08:18 +08:00
binary-husky
2ddd1bb634 Merge branch 'memset0-master' 2025-02-04 16:03:53 +08:00
binary-husky
c68285aeac update config and version 2025-02-04 16:03:01 +08:00
Memento mori.
caaebe4296 add support for Deepseek R1 model and display CoT (#2118)
* feat: add support for R1 model and display CoT

* fix unpacking

* feat: customized font & font size

* auto hide tooltip when scoll down

* tooltip glass transparent css

* fix: Enhance API key validation in is_any_api_key function (#2113)

* support qwen2.5-max!

* update minior adjustment

---------

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
Co-authored-by: Steven Moder <java20131114@gmail.com>
2025-02-04 16:02:02 +08:00
binary-husky
39d50c1c95 update minior adjustment 2025-02-04 15:57:35 +08:00
binary-husky
25dc7bf912 Merge branch 'master' of https://github.com/memset0/gpt_academic into memset0-master 2025-01-30 22:03:31 +08:00
binary-husky
0458590a77 support qwen2.5-max! 2025-01-29 23:29:38 +08:00
Steven Moder
44fe78fff5 fix: Enhance API key validation in is_any_api_key function (#2113) 2025-01-29 21:40:30 +08:00
binary-husky
5ddd657ebc tooltip glass transparent css 2025-01-28 23:50:21 +08:00
binary-husky
9b0b2cf260 auto hide tooltip when scoll down 2025-01-28 23:32:40 +08:00
binary-husky
9f39a6571a feat: customized font & font size 2025-01-28 02:52:56 +08:00
memset0
d07e736214 fix unpacking 2025-01-25 00:00:13 +08:00
memset0
a1f7ae5b55 feat: add support for R1 model and display CoT 2025-01-24 14:43:49 +08:00
binary-husky
1213ef19e5 Merge branch 'master' of github.com:binary-husky/chatgpt_academic 2025-01-22 01:50:08 +08:00
binary-husky
aaafe2a797 fix xelatex font problem in all-cap image 2025-01-22 01:49:53 +08:00
binary-husky
2716606f0c Update README.md 2025-01-16 23:40:24 +08:00
binary-husky
286f7303be fix image display bug 2025-01-12 21:54:43 +08:00
binary-husky
7eeab9e376 fix code block display bug 2025-01-09 22:31:59 +08:00
binary-husky
4ca331fb28 prevent html rendering for input 2025-01-05 21:20:12 +08:00
binary-husky
9487829930 change max_chat_preserve = 10 2025-01-03 00:34:36 +08:00
binary-husky
a73074b89e upgrade chat checkpoint 2025-01-03 00:31:03 +08:00
Southlandi
fd93622840 修复Gemini对话错误问题(停用词数量为0的情况) (#2092) 2024-12-28 23:22:10 +08:00
whyXVI
09a82a572d Fix RuntimeError in predict_no_ui_long_connection() (#2095)
Bug fix: Fix RuntimeError in predict_no_ui_long_connection()

In the original code, calling predict_no_ui_long_connection() would trigger a RuntimeError("OpenAI拒绝了请求:" + error_msg) even when the server responded normally. The issue occurred due to incorrect handling of SSE protocol comment lines (lines starting with ":"). 

Modified the parsing logic in both `predict` and `predict_no_ui_long_connection` to handle these lines correctly, making the logic more intuitive and robust.
2024-12-28 23:21:14 +08:00
G.RQ
c53ddf65aa 修复 bug“重置”按钮报错 (#2102)
* fix 重置按钮bug

* fix version control bug

---------

Co-authored-by: binary-husky <qingxu.fu@outlook.com>
2024-12-28 23:19:25 +08:00
binary-husky
ac64a77c2d allow disable openai proxy in WHEN_TO_USE_PROXY 2024-12-28 07:14:54 +08:00
binary-husky
dae8a0affc compat bug fix 2024-12-25 01:21:58 +08:00
binary-husky
97a81e9388 fix temp issue of o1 2024-12-25 00:54:03 +08:00
binary-husky
1dd1d0ed6c fix cookie overflow bug 2024-12-25 00:33:20 +08:00
binary-husky
060af0d2e6 Merge branch 'master' of github.com:binary-husky/chatgpt_academic 2024-12-22 23:33:44 +08:00
binary-husky
a848f714b6 fix welcome card bugs 2024-12-22 23:33:22 +08:00
binary-husky
924f8e30c7 Update issue stale.yml 2024-12-22 14:16:18 +08:00
binary-husky
f40347665b github action change 2024-12-22 14:15:16 +08:00
binary-husky
734c40bbde fix non-localhost javascript error 2024-12-22 14:01:22 +08:00
binary-husky
4ec87fbb54 history ng patch 1 2024-12-21 11:27:53 +08:00
binary-husky
17b5c22e61 Merge branch 'master' of github.com:binary-husky/chatgpt_academic 2024-12-19 22:46:14 +08:00
binary-husky
c6cd04a407 promote the rank of DASHSCOPE_API_KEY 2024-12-19 22:39:14 +08:00
YIQI JIANG
f60a12f8b4 Add o1 and o1-2024-12-17 model support (#2090)
* Add o1 and o1-2024-12-17 model support

* patch api key selection

---------

Co-authored-by: 蒋翌琪 <jiangyiqi99@jiangyiqideMacBook-Pro.local>
Co-authored-by: binary-husky <qingxu.fu@outlook.com>
2024-12-19 22:32:57 +08:00
共有 199 个文件被更改,包括 29061 次插入826 次删除

查看文件

@@ -7,7 +7,7 @@
name: 'Close stale issues and PRs'
on:
schedule:
- cron: '*/5 * * * *'
- cron: '*/30 * * * *'
jobs:
stale:
@@ -19,7 +19,6 @@ jobs:
steps:
- uses: actions/stale@v8
with:
stale-issue-message: 'This issue is stale because it has been open 100 days with no activity. Remove stale label or comment or this will be closed in 1 days.'
stale-issue-message: 'This issue is stale because it has been open 100 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
days-before-stale: 100
days-before-close: 1
debug-only: true
days-before-close: 7

2
.gitignore vendored
查看文件

@@ -163,3 +163,5 @@ objdump*
TODO
experimental_mods
search_results
gg.docx
unstructured_reader.py

查看文件

@@ -3,37 +3,38 @@
# - 如何构建: 先修改 `config.py`, 然后 `docker build -t gpt-academic . `
# - 如何运行(Linux下): `docker run --rm -it --net=host gpt-academic `
# - 如何运行(其他操作系统,选择任意一个固定端口50923): `docker run --rm -it -e WEB_PORT=50923 -p 50923:50923 gpt-academic `
FROM python:3.11
FROM ghcr.io/astral-sh/uv:python3.12-bookworm
# 非必要步骤,更换pip源 (以下三行,可以删除)
RUN echo '[global]' > /etc/pip.conf && \
echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
# 语音输出功能以下两行,第一行更换阿里源,第二行安装ffmpeg,都可以删除
RUN UBUNTU_VERSION=$(awk -F= '/^VERSION_CODENAME=/{print $2}' /etc/os-release); echo "deb https://mirrors.aliyun.com/debian/ $UBUNTU_VERSION main non-free contrib" > /etc/apt/sources.list; apt-get update
# 语音输出功能以下1,2行更换阿里源,第3,4行安装ffmpeg,都可以删除
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
apt-get update
RUN apt-get install ffmpeg -y
RUN apt-get clean
# 进入工作路径(必要)
WORKDIR /gpt
# 安装大部分依赖,利用Docker缓存加速以后的构建 (以下两行,可以删除)
COPY requirements.txt ./
RUN pip3 install -r requirements.txt
RUN uv venv --python=3.12 && uv pip install --verbose -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
ENV PATH="/gpt/.venv/bin:$PATH"
RUN python -c 'import loguru'
# 装载项目文件,安装剩余依赖(必要)
COPY . .
RUN pip3 install -r requirements.txt
RUN uv venv --python=3.12 && uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
# # 非必要步骤,用于预热模块(可以删除)
RUN python -c 'from check_proxy import warm_up_modules; warm_up_modules()'
# 非必要步骤,用于预热模块(可以删除)
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
ENV CGO_ENABLED=0
# 启动(必要)
CMD ["python3", "-u", "main.py"]
CMD ["bash", "-c", "python main.py"]

查看文件

@@ -1,12 +1,14 @@
> [!IMPORTANT]
> `master主分支`最新动态(2025.3.2): 修复大量代码typo / 联网组件支持Jina的api / 增加deepseek-r1支持
> `frontier开发分支`最新动态(2024.12.9): 更新对话时间线功能,优化xelatex论文翻译
> `wiki文档`最新动态(2024.12.5): 更新ollama接入指南
>
> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/file/d/19U_hsLoMrjOlQSzYS3pzWX9fTzyusArP/view?usp=sharing)的文件服务器
> 2024.10.8: 版本3.90加入对llama-index的初步支持,版本3.80加入插件二级菜单功能详见wiki
> 2025.2.2: 三分钟快速接入最强qwen2.5-max[视频](https://www.bilibili.com/video/BV1LeFuerEG4)
> 2025.2.1: 支持自定义字体
> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/drive/folders/14kR-3V-lIbvGxri4AHc8TpiA1fqsw7SK?usp=sharing)的文件服务器
> 2024.5.1: 加入Doc2x翻译PDF论文的功能,[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
> 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型 SoVits语音克隆模块,[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。
<br>
@@ -127,20 +129,20 @@ Latex论文一键校对 | [插件] 仿Grammarly对Latex文章进行语法、拼
```mermaid
flowchart TD
A{"安装方法"} --> W1("I. 🔑直接运行 (Windows, Linux or MacOS)")
W1 --> W11["1. Python pip包管理依赖"]
W1 --> W12["2. Anaconda包管理依赖推荐⭐"]
A{"安装方法"} --> W1("I 🔑直接运行 (Windows, Linux or MacOS)")
W1 --> W11["1 Python pip包管理依赖"]
W1 --> W12["2 Anaconda包管理依赖推荐⭐"]
A --> W2["II. 🐳使用Docker (Windows, Linux or MacOS)"]
A --> W2["II 🐳使用Docker (Windows, Linux or MacOS)"]
W2 --> k1["1. 部署项目全部能力的大镜像(推荐⭐)"]
W2 --> k2["2. 仅在线模型GPT, GLM4等镜像"]
W2 --> k3["3. 在线模型 + Latex的大镜像"]
W2 --> k1["1 部署项目全部能力的大镜像(推荐⭐)"]
W2 --> k2["2 仅在线模型GPT, GLM4等镜像"]
W2 --> k3["3 在线模型 + Latex的大镜像"]
A --> W4["IV. 🚀其他部署方法"]
W4 --> C1["1. Windows/MacOS 一键安装运行脚本(推荐⭐)"]
W4 --> C2["2. Huggingface, Sealos远程部署"]
W4 --> C4["3. ... 其他 ..."]
A --> W4["IV 🚀其他部署方法"]
W4 --> C1["1 Windows/MacOS 一键安装运行脚本(推荐⭐)"]
W4 --> C2["2 Huggingface, Sealos远程部署"]
W4 --> C4["3 其他 ..."]
```
### 安装方法I直接运行 (Windows, Linux or MacOS)
@@ -426,7 +428,6 @@ timeline LR
1. `master` 分支: 主分支,稳定版
2. `frontier` 分支: 开发分支,测试版
3. 如何[接入其他大模型](request_llms/README.md)
4. 访问GPT-Academic的[在线服务并支持我们](https://github.com/binary-husky/gpt_academic/wiki/online)
### V参考与学习

查看文件

@@ -230,6 +230,48 @@ def warm_up_modules():
enc.encode("模块预热", disallowed_special=())
enc = model_info["gpt-4"]['tokenizer']
enc.encode("模块预热", disallowed_special=())
try_warm_up_vectordb()
# def try_warm_up_vectordb():
# try:
# import os
# import nltk
# target = os.path.expanduser('~/nltk_data')
# logger.info(f'模块预热: nltk punkt (从Github下载部分文件到 {target})')
# nltk.data.path.append(target)
# nltk.download('punkt', download_dir=target)
# logger.info('模块预热完成: nltk punkt')
# except:
# logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
# logger.error('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
def try_warm_up_vectordb():
import os
import nltk
target = os.path.expanduser('~/nltk_data')
nltk.data.path.append(target)
try:
# 尝试加载 punkt
logger.info(f'nltk模块预热')
nltk.data.find('tokenizers/punkt')
nltk.data.find('tokenizers/punkt_tab')
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
logger.info('nltk模块预热完成读取本地缓存')
except:
# 如果找不到,则尝试下载
try:
logger.info(f'模块预热: nltk punkt (从 Github 下载部分文件到 {target})')
from shared_utils.nltk_downloader import Downloader
_downloader = Downloader()
_downloader.download('punkt', download_dir=target)
_downloader.download('punkt_tab', download_dir=target)
_downloader.download('averaged_perceptron_tagger_eng', download_dir=target)
logger.info('nltk模块预热完成')
except Exception:
logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
def warm_up_vectordb():
"""

查看文件

@@ -7,11 +7,16 @@
Configuration reading priority: environment variable > config_private.py > config.py
"""
# [step 1]>> API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织格式如org-123456789abcdefghijklmno的,请向下翻,找 API_ORG 设置项
API_KEY = "此处填API密钥" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
# [step 1-1]>> ( 接入OpenAI模型家族 ) API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织格式如org-123456789abcdefghijklmno的,请向下翻,找 API_ORG 设置项
API_KEY = "此处填APIKEY" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
# [step 1-2]>> ( 强烈推荐!接入通义家族 & 大模型服务平台百炼 ) 接入通义千问在线大模型,api-key获取地址 https://dashscope.console.aliyun.com/
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY用于接入qwen-max,dashscope-qwen3-14b,dashscope-deepseek-r1等
# [step 2]>> 改为True应用代理,如果直接在海外服务器部署,此处不修改;如果使用本地或无地域限制的大模型时,此处也不需要修改
# [step 1-3]>> ( 接入 deepseek-reasoner, 即 deepseek-r1 ) 深度求索(DeepSeek) API KEY,默认请求地址为"https://api.deepseek.com/v1/chat/completions"
DEEPSEEK_API_KEY = ""
# [step 2]>> 改为True应用代理。如果使用本地或无地域限制的大模型时,此处不修改;如果直接在海外服务器部署,此处不修改
USE_PROXY = False
if USE_PROXY:
"""
@@ -32,11 +37,16 @@ else:
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-12-17", "o1-preview", "o1-preview-2024-09-12",
"gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
"gemini-1.5-pro", "chatglm3", "chatglm4"
"gemini-1.5-pro", "chatglm3", "chatglm4",
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
"dashscope-deepseek-r1", "dashscope-deepseek-v3",
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
]
EMBEDDING_MODEL = "text-embedding-3-small"
@@ -47,7 +57,7 @@ EMBEDDING_MODEL = "text-embedding-3-small"
# "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-flash",
# "qianfan", "deepseekcoder",
# "spark", "sparkv2", "sparkv3", "sparkv3.5", "sparkv4",
# "qwen-turbo", "qwen-plus", "qwen-max", "qwen-local",
# "qwen-turbo", "qwen-plus", "qwen-local",
# "moonshot-v1-128k", "moonshot-v1-32k", "moonshot-v1-8k",
# "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0125", "gpt-4o-2024-05-13"
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
@@ -74,7 +84,7 @@ API_URL_REDIRECT = {}
# 多线程函数插件中,默认允许多少路线程同时访问OpenAI。Free trial users的限制是每分钟3次,Pay-as-you-go users的限制是每分钟3500次
# 一言以蔽之免费5刀用户填3,OpenAI绑了信用卡的用户可以填 16 或者更高。提高限制请查询https://platform.openai.com/docs/guides/rate-limits/overview
DEFAULT_WORKER_NUM = 3
DEFAULT_WORKER_NUM = 8
# 色彩主题, 可选 ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast"]
@@ -82,6 +92,31 @@ DEFAULT_WORKER_NUM = 3
THEME = "Default"
AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gstaff/Xkcd", "NoCrypt/Miku"]
FONT = "Theme-Default-Font"
AVAIL_FONTS = [
"默认值(Theme-Default-Font)",
"宋体(SimSun)",
"黑体(SimHei)",
"楷体(KaiTi)",
"仿宋(FangSong)",
"华文细黑(STHeiti Light)",
"华文楷体(STKaiti)",
"华文仿宋(STFangsong)",
"华文宋体(STSong)",
"华文中宋(STZhongsong)",
"华文新魏(STXinwei)",
"华文隶书(STLiti)",
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
"平方萌萌哒(PING FANG MENG MNEG DA@https://chinese-fonts-cdn.deno.dev/packages/pfmmd/dist/平方萌萌哒/result.css)",
"Helvetica",
"ui-sans-serif",
"sans-serif",
"system-ui"
]
# 默认的系统提示词system prompt
INIT_SYS_PROMPT = "Serve me as a writing and programming assistant."
@@ -133,10 +168,6 @@ MULTI_QUERY_LLM_MODELS = "gpt-3.5-turbo&chatglm3"
QWEN_LOCAL_MODEL_SELECTION = "Qwen/Qwen-1_8B-Chat-Int8"
# 接入通义千问在线大模型 https://dashscope.console.aliyun.com/
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY
# 百度千帆LLM_MODEL="qianfan"
BAIDU_CLOUD_API_KEY = ''
BAIDU_CLOUD_SECRET_KEY = ''
@@ -238,8 +269,9 @@ MOONSHOT_API_KEY = ""
# 零一万物(Yi Model) API KEY
YIMODEL_API_KEY = ""
# 深度求索(DeepSeek) API KEY,默认请求地址为"https://api.deepseek.com/v1/chat/completions"
DEEPSEEK_API_KEY = ""
# 接入火山引擎的在线大模型),api-key获取地址 https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
ARK_API_KEY = "00000000-0000-0000-0000-000000000000" # 火山引擎 API KEY
# 紫东太初大模型 https://ai-maas.wair.ac.cn
@@ -303,7 +335,7 @@ ARXIV_CACHE_DIR = "gpt_log/arxiv_cache"
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
WHEN_TO_USE_PROXY = ["Connect_OpenAI", "Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
"Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]
@@ -319,6 +351,23 @@ NUM_CUSTOM_BASIC_BTN = 4
DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in range(1,5) ]
# 在互联网搜索组件中,负责将搜索结果整理成干净的Markdown
JINA_API_KEY = ""
# SEMANTIC SCHOLAR API KEY
SEMANTIC_SCHOLAR_KEY = ""
# 是否自动裁剪上下文长度(是否启动,默认不启动)
AUTO_CONTEXT_CLIP_ENABLE = False
# 目标裁剪上下文的token长度如果超过这个长度,则会自动裁剪
AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN = 30*1000
# 无条件丢弃x以上的轮数
AUTO_CONTEXT_MAX_ROUND = 64
# 在裁剪上下文时,倒数第x次对话能“最多”保留的上下文token的比例占 AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN 的多少
AUTO_CONTEXT_MAX_CLIP_RATIO = [0.80, 0.60, 0.45, 0.25, 0.20, 0.18, 0.16, 0.14, 0.12, 0.10, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01]
"""
--------------- 配置关联关系说明 ---------------

查看文件

@@ -50,6 +50,9 @@ def get_crazy_functions():
from crazy_functions.SourceCode_Comment import 注释Python项目
from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
from crazy_functions.VideoResource_GPT import 多媒体任务
from crazy_functions.Document_Conversation import 批量文件询问
from crazy_functions.Document_Conversation_Wrap import Document_Conversation_Wrap
function_plugins = {
"多媒体智能体": {
@@ -113,7 +116,7 @@ def get_crazy_functions():
"Group": "学术",
"Color": "stop",
"AsButton": True,
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Info": "ArXiv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
@@ -352,7 +355,7 @@ def get_crazy_functions():
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Info": "ArXiv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
@@ -378,7 +381,16 @@ def get_crazy_functions():
"Info": "PDF翻译中文,并重新编译PDF | 输入参数为路径",
"Function": HotReload(PDF翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": PDF_Localize # 新一代插件需要注册Class
}
},
"批量文件询问 (支持自定义总结各种文件)": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": False,
"Info": "先上传文件,点击此按钮,进行提问",
"Function": HotReload(批量文件询问),
"Class": Document_Conversation_Wrap,
},
}
function_plugins.update(
@@ -414,8 +426,6 @@ def get_crazy_functions():
# -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=-
try:
from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
@@ -434,36 +444,6 @@ def get_crazy_functions():
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
# try:
# from crazy_functions.联网的ChatGPT import 连接网络回答问题
# function_plugins.update(
# {
# "连接网络回答问题(输入问题后点击该插件,需要访问谷歌)": {
# "Group": "对话",
# "Color": "stop",
# "AsButton": False, # 加入下拉菜单中
# # "Info": "连接网络回答问题(需要访问谷歌)| 输入参数是一个问题",
# "Function": HotReload(连接网络回答问题),
# }
# }
# )
# from crazy_functions.联网的ChatGPT_bing版 import 连接bing搜索回答问题
# function_plugins.update(
# {
# "连接网络回答问题中文Bing版,输入问题后点击该插件": {
# "Group": "对话",
# "Color": "stop",
# "AsButton": False, # 加入下拉菜单中
# "Info": "连接网络回答问题需要访问中文Bing| 输入参数是一个问题",
# "Function": HotReload(连接bing搜索回答问题),
# }
# }
# )
# except:
# logger.error(trimmed_format_exc())
# logger.error("Load function plugin failed")
try:
from crazy_functions.SourceCode_Analyse import 解析任意code项目
@@ -674,22 +654,21 @@ def get_crazy_functions():
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
try:
from crazy_functions.多智能体 import 多智能体终端
function_plugins.update(
{
"AutoGen多智能体终端仅供测试": {
"Group": "智能体",
"Color": "stop",
"AsButton": False,
"Function": HotReload(多智能体终端),
}
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
# try:
# from crazy_functions.多智能体 import 多智能体终端
# function_plugins.update(
# {
# "AutoGen多智能体终端仅供测试": {
# "Group": "智能体",
# "Color": "stop",
# "AsButton": False,
# "Function": HotReload(多智能体终端),
# }
# }
# )
# except:
# logger.error(trimmed_format_exc())
# logger.error("Load function plugin failed")
try:
from crazy_functions.互动小游戏 import 随机小游戏
@@ -726,6 +705,44 @@ def get_crazy_functions():
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
# try:
# from crazy_functions.Document_Optimize import 自定义智能文档处理
# function_plugins.update(
# {
# "一键处理文档(支持自定义全文润色、降重等)": {
# "Group": "学术",
# "Color": "stop",
# "AsButton": False,
# "AdvancedArgs": True,
# "ArgsReminder": "请输入处理指令和要求(可以详细描述),如:请帮我润色文本,要求幽默点。默认调用润色指令。",
# "Info": "保留文档结构,智能处理文档内容 | 输入参数为文件路径",
# "Function": HotReload(自定义智能文档处理)
# },
# }
# )
# except:
# logger.error(trimmed_format_exc())
# logger.error("Load function plugin failed")
# try:
# from crazy_functions.Paper_Reading import 快速论文解读
# function_plugins.update(
# {
# "速读论文": {
# "Group": "学术",
# "Color": "stop",
# "AsButton": False,
# "Info": "上传一篇论文进行快速分析和解读 | 输入参数为论文路径或DOI/arXiv ID",
# "Function": HotReload(快速论文解读),
# },
# }
# )
# except:
# logger.error(trimmed_format_exc())
# logger.error("Load function plugin failed")
# try:
# from crazy_functions.高级功能函数模板 import 测试图表渲染
@@ -771,12 +788,15 @@ def get_multiplex_button_functions():
"常规对话":
"",
"多模型对话":
"查互联网后回答":
"查互联网后回答",
"多模型对话":
"询问多个GPT模型", # 映射到上面的 `询问多个GPT模型` 插件
"智能召回 RAG":
"智能召回 RAG":
"Rag智能召回", # 映射到上面的 `Rag智能召回` 插件
"多媒体查询":
"多媒体查询":
"多媒体智能体", # 映射到上面的 `多媒体智能体` 插件
}

查看文件

@@ -0,0 +1,290 @@
import re
import os
import asyncio
from typing import List, Dict, Tuple
from dataclasses import dataclass
from textwrap import dedent
from toolbox import CatchException, get_conf, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.query_analyzer import QueryAnalyzer
from crazy_functions.review_fns.handlers.review_handler import 文献综述功能
from crazy_functions.review_fns.handlers.recommend_handler import 论文推荐功能
from crazy_functions.review_fns.handlers.qa_handler import 学术问答功能
from crazy_functions.review_fns.handlers.paper_handler import 单篇论文分析功能
from crazy_functions.Conversation_To_File import write_chat_to_file
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.review_fns.handlers.latest_handler import Arxiv最新论文推荐功能
from datetime import datetime
@CatchException
def 学术对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数"""
# 初始化数据源
arxiv_source = ArxivSource()
semantic_source = SemanticScholarSource(
api_key=get_conf("SEMANTIC_SCHOLAR_KEY")
)
# 初始化处理器
handlers = {
"review": 文献综述功能(arxiv_source, semantic_source, llm_kwargs),
"recommend": 论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
"qa": 学术问答功能(arxiv_source, semantic_source, llm_kwargs),
"paper": 单篇论文分析功能(arxiv_source, semantic_source, llm_kwargs),
"latest": Arxiv最新论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
}
# 分析查询意图
chatbot.append([None, "正在分析研究主题和查询要求..."])
yield from update_ui(chatbot=chatbot, history=history)
query_analyzer = QueryAnalyzer()
search_criteria = yield from query_analyzer.analyze_query(txt, chatbot, llm_kwargs)
handler = handlers.get(search_criteria.query_type)
if not handler:
handler = handlers["qa"] # 默认使用QA处理器
# 处理查询
chatbot.append([None, f"使用{handler.__class__.__name__}处理...,可能需要您耐心等待35分钟..."])
yield from update_ui(chatbot=chatbot, history=history)
final_prompt = asyncio.run(handler.handle(
criteria=search_criteria,
chatbot=chatbot,
history=history,
system_prompt=system_prompt,
llm_kwargs=llm_kwargs,
plugin_kwargs=plugin_kwargs
))
if final_prompt:
# 检查是否是道歉提示
if "很抱歉,我们未能找到" in final_prompt:
chatbot.append([txt, final_prompt])
yield from update_ui(chatbot=chatbot, history=history)
return
# 在 final_prompt 末尾添加用户原始查询要求
final_prompt += dedent(f"""
Original user query: "{txt}"
IMPORTANT NOTE :
- Your response must directly address the user's original user query above
- While following the previous guidelines, prioritize answering what the user specifically asked
- Make sure your response format and content align with the user's expectations
- Do not translate paper titles, keep them in their original language
- Do not generate a reference list in your response - references will be handled separately
""")
# 使用最终的prompt生成回答
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=final_prompt,
inputs_show_user=txt,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=[],
sys_prompt=f"You are a helpful academic assistant. Response in Chinese by default unless specified language is required in the user's query."
)
# 1. 获取文献列表
papers_list = handler.ranked_papers # 直接使用原始论文数据
# 在新的对话中添加格式化的参考文献列表
if papers_list:
references = ""
for idx, paper in enumerate(papers_list, 1):
# 构建作者列表
authors = paper.authors[:3]
if len(paper.authors) > 3:
authors.append("et al.")
authors_str = ", ".join(authors)
# 构建期刊指标信息
metrics = []
if hasattr(paper, 'if_factor') and paper.if_factor:
metrics.append(f"IF: {paper.if_factor}")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
metrics.append(f"JCR: {paper.jcr_division}")
if hasattr(paper, 'cas_division') and paper.cas_division:
metrics.append(f"中科院分区: {paper.cas_division}")
metrics_str = f" [{', '.join(metrics)}]" if metrics else ""
# 构建DOI链接
doi_link = ""
if paper.doi:
if "arxiv.org" in str(paper.doi):
doi_url = paper.doi
else:
doi_url = f"https://doi.org/{paper.doi}"
doi_link = f" <a href='{doi_url}' target='_blank'>DOI: {paper.doi}</a>"
# 构建完整的引用
reference = f"[{idx}] {authors_str}. *{paper.title}*"
if paper.venue_name:
reference += f". {paper.venue_name}"
if paper.year:
reference += f", {paper.year}"
reference += metrics_str
if doi_link:
reference += f".{doi_link}"
reference += " \n"
references += reference
# 添加新的对话显示参考文献
chatbot.append(["参考文献如下:", references])
yield from update_ui(chatbot=chatbot, history=history)
# 2. 保存为不同格式
from .review_fns.conversation_doc.word_doc import WordFormatter
from .review_fns.conversation_doc.word2pdf import WordToPdfConverter
from .review_fns.conversation_doc.markdown_doc import MarkdownFormatter
from .review_fns.conversation_doc.html_doc import HtmlFormatter
# 创建保存目录
save_dir = get_log_folder(get_user(chatbot), plugin_name='chatscholar')
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 生成文件名
def get_safe_filename(txt, max_length=10):
# 获取文本前max_length个字符作为文件名
filename = txt[:max_length].strip()
# 移除不安全的文件名字符
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
# 如果文件名为空,使用时间戳
if not filename:
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
return filename
base_filename = get_safe_filename(txt)
result_files = [] # 收集所有生成的文件
pdf_path = None # 用于跟踪PDF是否成功生成
# 保存为Markdown
try:
md_formatter = MarkdownFormatter()
md_content = md_formatter.create_document(txt, response, papers_list)
result_file_md = write_history_to_file(
history=[md_content],
file_basename=f"markdown_{base_filename}.md"
)
result_files.append(result_file_md)
except Exception as e:
print(f"Markdown保存失败: {str(e)}")
# 保存为HTML
try:
html_formatter = HtmlFormatter()
html_content = html_formatter.create_document(txt, response, papers_list)
result_file_html = write_history_to_file(
history=[html_content],
file_basename=f"html_{base_filename}.html"
)
result_files.append(result_file_html)
except Exception as e:
print(f"HTML保存失败: {str(e)}")
# 保存为Word
try:
word_formatter = WordFormatter()
try:
doc = word_formatter.create_document(txt, response, papers_list)
except Exception as e:
print(f"Word文档内容生成失败: {str(e)}")
raise e
try:
result_file_docx = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"docx_{base_filename}.docx"
)
doc.save(result_file_docx)
result_files.append(result_file_docx)
print(f"Word文档已保存到: {result_file_docx}")
# 转换为PDF
try:
pdf_path = WordToPdfConverter.convert_to_pdf(result_file_docx)
if pdf_path:
result_files.append(pdf_path)
print(f"PDF文档已生成: {pdf_path}")
except Exception as e:
print(f"PDF转换失败: {str(e)}")
except Exception as e:
print(f"Word文档保存失败: {str(e)}")
raise e
except Exception as e:
print(f"Word格式化失败: {str(e)}")
import traceback
print(f"详细错误信息: {traceback.format_exc()}")
# 保存为BibTeX格式
try:
from .review_fns.conversation_doc.reference_formatter import ReferenceFormatter
ref_formatter = ReferenceFormatter()
bibtex_content = ref_formatter.create_document(papers_list)
# 在与其他文件相同目录下创建BibTeX文件
result_file_bib = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"references_{base_filename}.bib"
)
# 直接写入文件
with open(result_file_bib, 'w', encoding='utf-8') as f:
f.write(bibtex_content)
result_files.append(result_file_bib)
print(f"BibTeX文件已保存到: {result_file_bib}")
except Exception as e:
print(f"BibTeX格式保存失败: {str(e)}")
# 保存为EndNote格式
try:
from .review_fns.conversation_doc.endnote_doc import EndNoteFormatter
endnote_formatter = EndNoteFormatter()
endnote_content = endnote_formatter.create_document(papers_list)
# 在与其他文件相同目录下创建EndNote文件
result_file_enw = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"references_{base_filename}.enw"
)
# 直接写入文件
with open(result_file_enw, 'w', encoding='utf-8') as f:
f.write(endnote_content)
result_files.append(result_file_enw)
print(f"EndNote文件已保存到: {result_file_enw}")
except Exception as e:
print(f"EndNote格式保存失败: {str(e)}")
# 添加所有文件到下载区
success_files = []
for file in result_files:
try:
promote_file_to_downloadzone(file, chatbot=chatbot)
success_files.append(os.path.basename(file))
except Exception as e:
print(f"文件添加到下载区失败: {str(e)}")
# 更新成功提示消息
if success_files:
chatbot.append(["保存对话记录成功,bib和enw文件支持导入到EndNote、Zotero、JabRef、Mendeley等文献管理软件,HTML文件支持在浏览器中打开,里面包含详细论文源信息", "对话已保存并添加到下载区,可以在下载区找到相关文件"])
else:
chatbot.append(["保存对话记录", "所有格式的保存都失败了,请检查错误日志。"])
yield from update_ui(chatbot=chatbot, history=history)
else:
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -1,10 +1,11 @@
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
import re
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user, update_ui_latest_msg
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
from loguru import logger
f_prefix = 'GPT-Academic对话存档'
def write_chat_to_file(chatbot, history=None, file_name=None):
def write_chat_to_file_legacy(chatbot, history=None, file_name=None):
"""
将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
"""
@@ -12,6 +13,9 @@ def write_chat_to_file(chatbot, history=None, file_name=None):
import time
from themes.theme import advanced_css
if (file_name is not None) and (file_name != "") and (not file_name.endswith('.html')): file_name += '.html'
else: file_name = None
if file_name is None:
file_name = f_prefix + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.html'
fp = os.path.join(get_log_folder(get_user(chatbot), plugin_name='chat_history'), file_name)
@@ -68,6 +72,147 @@ def write_chat_to_file(chatbot, history=None, file_name=None):
promote_file_to_downloadzone(fp, rename_file=file_name, chatbot=chatbot)
return '对话历史写入:' + fp
def write_chat_to_file(chatbot, history=None, file_name=None):
"""
将对话记录history以多种格式HTML、Word、Markdown写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
Args:
chatbot: 聊天机器人对象,包含对话内容
history: 对话历史记录
file_name: 指定的文件名,如果为None则使用时间戳
Returns:
str: 提示信息,包含文件保存路径
"""
import os
import time
import asyncio
import aiofiles
from toolbox import promote_file_to_downloadzone
from crazy_functions.doc_fns.conversation_doc.excel_doc import save_chat_tables
from crazy_functions.doc_fns.conversation_doc.html_doc import HtmlFormatter
from crazy_functions.doc_fns.conversation_doc.markdown_doc import MarkdownFormatter
from crazy_functions.doc_fns.conversation_doc.word_doc import WordFormatter
from crazy_functions.doc_fns.conversation_doc.txt_doc import TxtFormatter
from crazy_functions.doc_fns.conversation_doc.word2pdf import WordToPdfConverter
async def save_html():
try:
html_formatter = HtmlFormatter(chatbot, history)
html_content = html_formatter.create_document()
html_file = os.path.join(save_dir, base_name + '.html')
async with aiofiles.open(html_file, 'w', encoding='utf8') as f:
await f.write(html_content)
return html_file
except Exception as e:
print(f"保存HTML格式失败: {str(e)}")
return None
async def save_word():
try:
word_formatter = WordFormatter()
doc = word_formatter.create_document(history)
docx_file = os.path.join(save_dir, base_name + '.docx')
# 由于python-docx不支持异步,使用线程池执行
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, doc.save, docx_file)
return docx_file
except Exception as e:
print(f"保存Word格式失败: {str(e)}")
return None
async def save_pdf(docx_file):
try:
if docx_file:
# 获取文件名和保存路径
pdf_file = os.path.join(save_dir, base_name + '.pdf')
# 在线程池中执行转换
loop = asyncio.get_event_loop()
pdf_file = await loop.run_in_executor(
None,
WordToPdfConverter.convert_to_pdf,
docx_file
# save_dir
)
return pdf_file
except Exception as e:
print(f"保存PDF格式失败: {str(e)}")
return None
async def save_markdown():
try:
md_formatter = MarkdownFormatter()
md_content = md_formatter.create_document(history)
md_file = os.path.join(save_dir, base_name + '.md')
async with aiofiles.open(md_file, 'w', encoding='utf8') as f:
await f.write(md_content)
return md_file
except Exception as e:
print(f"保存Markdown格式失败: {str(e)}")
return None
async def save_txt():
try:
txt_formatter = TxtFormatter()
txt_content = txt_formatter.create_document(history)
txt_file = os.path.join(save_dir, base_name + '.txt')
async with aiofiles.open(txt_file, 'w', encoding='utf8') as f:
await f.write(txt_content)
return txt_file
except Exception as e:
print(f"保存TXT格式失败: {str(e)}")
return None
async def main():
# 并发执行所有保存任务
html_task = asyncio.create_task(save_html())
word_task = asyncio.create_task(save_word())
md_task = asyncio.create_task(save_markdown())
txt_task = asyncio.create_task(save_txt())
# 等待所有任务完成
html_file = await html_task
docx_file = await word_task
md_file = await md_task
txt_file = await txt_task
# PDF转换需要等待word文件生成完成
pdf_file = await save_pdf(docx_file)
# 收集所有成功生成的文件
result_files = [f for f in [html_file, docx_file, md_file, txt_file, pdf_file] if f]
# 保存Excel表格
excel_files = save_chat_tables(history, save_dir, base_name)
result_files.extend(excel_files)
return result_files
# 生成时间戳
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
# 获取保存目录
save_dir = get_log_folder(get_user(chatbot), plugin_name='chat_history')
# 处理文件名
base_name = file_name if file_name else f"聊天记录_{timestamp}"
# 运行异步任务
result_files = asyncio.run(main())
# 将生成的文件添加到下载区
for file in result_files:
promote_file_to_downloadzone(file, rename_file=os.path.basename(file), chatbot=chatbot)
# 如果没有成功保存任何文件,返回错误信息
if not result_files:
return "保存对话记录失败,请检查错误日志"
ext_list = [os.path.splitext(f)[1] for f in result_files]
# 返回成功信息和文件路径
return f"对话历史已保存至以下格式文件:" + "".join(ext_list)
def gen_file_preview(file_name):
try:
with open(file_name, 'r', encoding='utf8') as f:
@@ -119,12 +264,21 @@ def 对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
user_request 当前用户的请求信息IP地址等
"""
file_name = plugin_kwargs.get("file_name", None)
if (file_name is not None) and (file_name != "") and (not file_name.endswith('.html')): file_name += '.html'
else: file_name = None
chatbot.append((None, f"[Local Message] {write_chat_to_file(chatbot, history, file_name)},您可以调用下拉菜单中的“载入对话历史存档”还原当下的对话。"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
chatbot.append((None, f"[Local Message] {write_chat_to_file_legacy(chatbot, history, file_name)},您可以调用下拉菜单中的“载入对话历史存档”还原当下的对话。"))
try:
chatbot.append((None, f"[Local Message] 正在尝试生成pdf以及word格式的对话存档,请稍等..."))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求需要一段时间,我们先及时地做一次界面更新
lastmsg = f"[Local Message] {write_chat_to_file(chatbot, history, file_name)}" \
f"您可以调用下拉菜单中的“载入对话历史会话”还原当下的对话,请注意,目前只支持html格式载入历史。" \
f"当模型回答中存在表格,将提取表格内容存储为Excel的xlsx格式,如果你提供一些数据,然后输入指令要求模型帮你整理为表格" \
f"如“请帮我将下面的数据整理为表格,再利用此插件就可以获取到Excel表格。"
yield from update_ui_latest_msg(lastmsg, chatbot, history) # 刷新界面 # 由于请求需要一段时间,我们先及时地做一次界面更新
except Exception as e:
logger.exception(f"已完成对话存档pdf和word格式的对话存档生成未成功{str(e)}")
lastmsg = "已完成对话存档pdf和word格式的对话存档生成未成功"
yield from update_ui_latest_msg(lastmsg, chatbot, history) # 刷新界面 # 由于请求需要一段时间,我们先及时地做一次界面更新
return
class Conversation_To_File_Wrap(GptAcademicPluginTemplate):
def __init__(self):

查看文件

@@ -0,0 +1,537 @@
import os
import threading
import time
from dataclasses import dataclass
from typing import List, Tuple, Dict, Generator
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.rag_fns.rag_file_support import extract_text
from request_llms.bridge_all import model_info
from toolbox import update_ui, CatchException, report_exception
from shared_utils.fastapi_server import validate_path_safety
@dataclass
class FileFragment:
"""文件片段数据类,用于组织处理单元"""
file_path: str
content: str
rel_path: str
fragment_index: int
total_fragments: int
class BatchDocumentSummarizer:
"""优化的文档总结器 - 批处理版本"""
def __init__(self, llm_kwargs: Dict, query: str, chatbot: List, history: List, system_prompt: str):
"""初始化总结器"""
self.llm_kwargs = llm_kwargs
self.query = query
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.failed_files = []
self.file_summaries_map = {}
def _get_token_limit(self) -> int:
"""获取模型token限制"""
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
return max_token * 3 // 4
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
"""创建批处理输入"""
inputs_array = []
inputs_show_user_array = []
history_array = []
for frag in fragments:
if self.query:
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)}'
f'用户要求为:{self.query}'
f'文件内容是 ```{frag.content}```')
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
else:
i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)}'
f'内容是 ```{frag.content}```')
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
return inputs_array, inputs_show_user_array, history_array
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
"""包装了超时控制的文件处理函数"""
def timeout_handler():
thread = threading.current_thread()
if hasattr(thread, '_timeout_occurred'):
thread._timeout_occurred = True
# 设置超时标记
thread = threading.current_thread()
thread._timeout_occurred = False
# 设置超时时间为30秒,给予更多处理时间
TIMEOUT_SECONDS = 30
timer = threading.Timer(TIMEOUT_SECONDS, timeout_handler)
timer.start()
try:
fp, project_folder = file_info
fragments = []
# 定期检查是否超时
def check_timeout():
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
raise TimeoutError(f"处理文件 {os.path.basename(fp)} 超时({TIMEOUT_SECONDS}秒)")
# 更新状态
mutable_status[0] = "检查文件大小"
mutable_status[1] = time.time()
check_timeout()
# 文件大小检查
if os.path.getsize(fp) > self.max_file_size:
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
mutable_status[2] = "文件过大"
return fragments
# 更新状态
mutable_status[0] = "提取文件内容"
mutable_status[1] = time.time()
# 提取内容 - 使用单独的超时控制
content = None
extract_start_time = time.time()
try:
while True:
check_timeout() # 检查全局超时
# 检查提取过程是否超时10秒
if time.time() - extract_start_time > 10:
raise TimeoutError("文件内容提取超时10秒")
try:
content = extract_text(fp)
break
except Exception as e:
if "timeout" in str(e).lower():
continue # 如果是临时超时,重试
raise # 其他错误直接抛出
except Exception as e:
self.failed_files.append((fp, f"文件读取失败:{str(e)}"))
mutable_status[2] = "读取失败"
return fragments
if content is None:
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
mutable_status[2] = "格式不支持"
return fragments
elif not content.strip():
self.failed_files.append((fp, "文件内容为空"))
mutable_status[2] = "内容为空"
return fragments
check_timeout()
# 更新状态
mutable_status[0] = "分割文本"
mutable_status[1] = time.time()
# 分割文本 - 添加超时检查
split_start_time = time.time()
try:
while True:
check_timeout() # 检查全局超时
# 检查分割过程是否超时5秒
if time.time() - split_start_time > 5:
raise TimeoutError("文本分割超时5秒")
paper_fragments = breakdown_text_to_satisfy_token_limit(
txt=content,
limit=self._get_token_limit(),
llm_model=self.llm_kwargs['llm_model']
)
break
except Exception as e:
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
mutable_status[2] = "分割失败"
return fragments
# 处理片段
rel_path = os.path.relpath(fp, project_folder)
for i, frag in enumerate(paper_fragments):
check_timeout() # 每处理一个片段检查一次超时
if frag.strip():
fragments.append(FileFragment(
file_path=fp,
content=frag,
rel_path=rel_path,
fragment_index=i,
total_fragments=len(paper_fragments)
))
mutable_status[2] = "处理完成"
return fragments
except TimeoutError as e:
self.failed_files.append((fp, str(e)))
mutable_status[2] = "处理超时"
return []
except Exception as e:
self.failed_files.append((fp, f"处理失败:{str(e)}"))
mutable_status[2] = "处理异常"
return []
finally:
timer.cancel()
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Generator, List
"""并行准备所有文件的处理片段"""
all_fragments = []
total_files = len(file_paths)
# 配置参数
self.refresh_interval = 0.2 # UI刷新间隔
self.watch_dog_patience = 5 # 看门狗超时时间
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
self.max_workers = min(32, len(file_paths)) # 最多32个线程
# 创建有超时控制的线程池
executor = ThreadPoolExecutor(max_workers=self.max_workers)
# 用于跨线程状态传递的可变列表 - 增加文件名信息
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
# 创建文件处理任务
file_infos = [(fp, project_folder) for fp in file_paths]
# 提交所有任务,使用带超时控制的处理函数
futures = [
executor.submit(
self._process_single_file_with_timeout,
file_info,
mutable_status_array[i]
) for i, file_info in enumerate(file_infos)
]
# 更新UI的计数器
cnt = 0
try:
# 监控任务执行
while True:
time.sleep(self.refresh_interval)
cnt += 1
# 检查任务完成状态
worker_done = [f.done() for f in futures]
# 更新状态显示
status_str = ""
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
# 获取文件名(去掉路径)
file_name = os.path.basename(file_path)
if worker_done[i]:
status_str += f"文件 {file_name}: {desc}\n\n"
else:
status_str += f"文件 {file_name}: {status} {desc}\n\n"
# 更新UI
self.chatbot[-1] = [
"处理进度",
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 检查是否所有任务完成
if all(worker_done):
break
finally:
# 确保线程池正确关闭
executor.shutdown(wait=False)
# 收集结果
processed_files = 0
for future in futures:
try:
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
all_fragments.extend(fragments)
processed_files += 1
except concurrent.futures.TimeoutError:
# 处理获取结果超时
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], "结果获取超时"))
continue
except Exception as e:
# 处理其他异常
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
continue
# 最终进度更新
self.chatbot.append([
"文件处理完成",
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return all_fragments
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
"""批量处理文件片段"""
from collections import defaultdict
batch_size = 64 # 每批处理的片段数
max_retries = 3 # 最大重试次数
retry_delay = 5 # 重试延迟(秒)
results = defaultdict(list)
# 按批次处理
for i in range(0, len(fragments), batch_size):
batch = fragments[i:i + batch_size]
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
sys_prompt_array = ["请总结以下内容:"] * len(batch)
# 添加重试机制
for retry in range(max_retries):
try:
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应
for j, frag in enumerate(batch):
summary = response_collection[j * 2 + 1]
if summary and summary.strip():
results[frag.rel_path].append({
'index': frag.fragment_index,
'summary': summary,
'total': frag.total_fragments
})
break # 成功处理,跳出重试循环
except Exception as e:
if retry == max_retries - 1: # 最后一次重试失败
for frag in batch:
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
else:
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
time.sleep(retry_delay)
return results
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
"""准备最终总结请求"""
if not self.file_summaries_map:
return (["无可用的文件总结"], ["生成最终总结"], [[]])
summaries = list(self.file_summaries_map.values())
if all(not summary for summary in summaries):
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
if self.plugin_kwargs.get("advanced_arg"):
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
else:
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
return ([i_say], [i_say], [summaries])
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
"""处理所有文件"""
total_files = len(file_paths)
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 1. 准备所有文件片段
# 在 process_files 函数中:
fragments = yield from self.prepare_fragments(project_folder, file_paths)
if not fragments:
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
return "没有可处理的文件内容"
# 2. 批量处理所有文件片段
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
file_summaries = yield from self._process_fragments_batch(fragments)
except Exception as e:
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
return "处理过程发生错误"
# 3. 为每个文件生成整体总结
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 处理每个文件的总结
for rel_path, summaries in file_summaries.items():
if len(summaries) > 1: # 多片段文件需要生成整体总结
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
if self.plugin_kwargs.get("advanced_arg"):
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}'
else:
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
try:
summary_texts = [s['summary'] for s in sorted_summaries]
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[i_say],
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[summary_texts],
sys_prompt_array=["你是一个优秀的助手,"],
)
self.file_summaries_map[rel_path] = response_collection[1]
except Exception as e:
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
self.file_summaries_map[rel_path] = "总结生成失败"
else: # 单片段文件直接使用其唯一的总结
self.file_summaries_map[rel_path] = summaries[0]['summary']
# 4. 生成最终总结
if total_files == 1:
return "文件数为1,此时不调用总结模块"
else:
try:
# 收集所有文件的总结用于生成最终总结
file_summaries_for_final = []
for rel_path, summary in self.file_summaries_map.items():
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
if self.plugin_kwargs.get("advanced_arg"):
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
self.plugin_kwargs['advanced_arg'])
else:
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[final_summary_prompt],
inputs_show_user_array=["生成最终总结报告"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[file_summaries_for_final],
sys_prompt_array=["总结所有文件内容。"],
max_workers=1
)
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
except Exception as e:
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
return "生成总结失败"
def save_results(self, final_summary: str):
"""保存结果到文件"""
from toolbox import promote_file_to_downloadzone, write_history_to_file
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
import os
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 创建各种格式化器
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
result_files = []
# 保存 Markdown
try:
md_content = md_formatter.create_document()
result_file_md = write_history_to_file(
history=[md_content], # 直接传入内容列表
file_basename=f"文档总结_{timestamp}.md"
)
result_files.append(result_file_md)
except:
pass
# 保存 HTML
try:
html_content = html_formatter.create_document()
result_file_html = write_history_to_file(
history=[html_content],
file_basename=f"文档总结_{timestamp}.html"
)
result_files.append(result_file_html)
except:
pass
# 保存 Word
try:
doc = word_formatter.create_document()
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
result_file_docx = os.path.join(
os.path.dirname(result_file_md),
f"文档总结_{timestamp}.docx"
)
doc.save(result_file_docx)
result_files.append(result_file_docx)
except:
pass
# 添加到下载区
for file in result_files:
promote_file_to_downloadzone(file, chatbot=self.chatbot)
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
@CatchException
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 优化版本"""
# 初始化
import glob
import re
from crazy_functions.rag_fns.rag_file_support import supports_format
from toolbox import report_exception
query = plugin_kwargs.get("advanced_arg")
summarizer = BatchDocumentSummarizer(llm_kwargs, query, chatbot, history, system_prompt)
chatbot.append(["函数插件功能", f"作者lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
yield from update_ui(chatbot=chatbot, history=history)
# 验证输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 获取文件列表
project_folder = txt
user_name = chatbot.get_user()
validate_path_safety(project_folder, user_name)
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
if not file_manifest:
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理所有文件并生成总结
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
yield from update_ui(chatbot=chatbot, history=history)
# 保存结果
summarizer.save_results(final_summary)
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -0,0 +1,36 @@
import random
from toolbox import get_conf
from crazy_functions.Document_Conversation import 批量文件询问
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
class Document_Conversation_Wrap(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第三个参数,名称`allow_cache`,参数`type`声明这是一个下拉菜单,下拉菜单上方显示`title`+`description`,下拉菜单的选项为`options`,`default_value`为下拉菜单默认值;
"""
gui_definition = {
"main_input":
ArgProperty(title="已上传的文件", description="上传文件后自动填充", default_value="", type="string").model_dump_json(),
"searxng_url":
ArgProperty(title="对材料提问", description="提问", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs:dict, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
yield from 批量文件询问(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -0,0 +1,673 @@
import os
import time
import glob
import re
import threading
from typing import Dict, List, Generator, Tuple
from dataclasses import dataclass
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format, convert_to_markdown
from request_llms.bridge_all import model_info
from toolbox import update_ui, CatchException, report_exception, promote_file_to_downloadzone, write_history_to_file
from shared_utils.fastapi_server import validate_path_safety
# 新增:导入结构化论文提取器
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor, ExtractorConfig, StructuredPaper
# 导入格式化器
from crazy_functions.paper_fns.file2file_doc import (
TxtFormatter,
MarkdownFormatter,
HtmlFormatter,
WordFormatter
)
@dataclass
class TextFragment:
"""文本片段数据类,用于组织处理单元"""
content: str
fragment_index: int
total_fragments: int
class DocumentProcessor:
"""文档处理器 - 处理单个文档并输出结果"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化处理器"""
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.processed_results = []
self.failed_fragments = []
# 新增:初始化论文结构提取器
self.paper_extractor = PaperStructureExtractor()
def _get_token_limit(self) -> int:
"""获取模型token限制,返回更小的值以确保更细粒度的分割"""
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
# 降低token限制,使每个片段更小
return max_token // 4 # 从3/4降低到1/4
def _create_batch_inputs(self, fragments: List[TextFragment]) -> Tuple[List, List, List]:
"""创建批处理输入"""
inputs_array = []
inputs_show_user_array = []
history_array = []
user_instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
for frag in fragments:
i_say = (f'请按照以下要求处理文本内容:{user_instruction}\n\n'
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
f'文本内容:\n```\n{frag.content}\n```')
i_say_show_user = f'正在处理文本片段 {frag.fragment_index + 1}/{frag.total_fragments}'
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
return inputs_array, inputs_show_user_array, history_array
def _extract_decision(self, text: str) -> str:
"""从LLM响应中提取<decision>标签内的内容"""
import re
pattern = r'<decision>(.*?)</decision>'
matches = re.findall(pattern, text, re.DOTALL)
if matches:
return matches[0].strip()
else:
# 如果没有找到标签,返回原始文本
return text.strip()
def process_file(self, file_path: str) -> Generator:
"""处理单个文件"""
self.chatbot.append(["开始处理文件", f"文件路径: {file_path}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
# 首先尝试转换为Markdown
from crazy_functions.rag_fns.rag_file_support import convert_to_markdown
file_path = convert_to_markdown(file_path)
# 1. 检查文件是否为支持的论文格式
is_paper_format = any(file_path.lower().endswith(ext) for ext in self.paper_extractor.SUPPORTED_EXTENSIONS)
if is_paper_format:
# 使用结构化提取器处理论文
return (yield from self._process_structured_paper(file_path))
else:
# 使用原有方式处理普通文档
return (yield from self._process_regular_file(file_path))
except Exception as e:
self.chatbot.append(["处理错误", f"文件处理失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
def _process_structured_paper(self, file_path: str) -> Generator:
"""处理结构化论文文件"""
# 1. 提取论文结构
self.chatbot[-1] = ["正在分析论文结构", f"文件路径: {file_path}"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
paper = self.paper_extractor.extract_paper_structure(file_path)
if not paper or not paper.sections:
self.chatbot.append(["无法提取论文结构", "将使用全文内容进行处理"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用全文内容进行段落切分
if paper and paper.full_text:
# 使用增强的分割函数进行更细致的分割
fragments = self._breakdown_section_content(paper.full_text)
# 创建文本片段对象
text_fragments = []
for i, frag in enumerate(fragments):
if frag.strip():
text_fragments.append(TextFragment(
content=frag,
fragment_index=i,
total_fragments=len(fragments)
))
# 批量处理片段
if text_fragments:
self.chatbot[-1] = ["开始处理文本", f"{len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 一次性准备所有输入
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(text_fragments)
# 使用系统提示
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(text_fragments)
# 调用LLM一次性处理所有片段
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应
for j, frag in enumerate(text_fragments):
try:
llm_response = response_collection[j * 2 + 1]
processed_text = self._extract_decision(llm_response)
if processed_text and processed_text.strip():
self.processed_results.append({
'index': frag.fragment_index,
'content': processed_text
})
else:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content
})
except Exception as e:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content
})
# 按原始顺序合并结果
self.processed_results.sort(key=lambda x: x['index'])
final_content = "\n".join([item['content'] for item in self.processed_results])
# 更新UI
success_count = len(text_fragments) - len(self.failed_fragments)
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return final_content
else:
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
else:
self.chatbot.append(["处理失败", "未能提取到论文内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 2. 准备处理章节内容(不处理标题)
self.chatbot[-1] = ["已提取论文结构", f"{len(paper.sections)} 个主要章节"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 3. 收集所有需要处理的章节内容并分割为合适大小
sections_to_process = []
section_map = {} # 用于映射处理前后的内容
def collect_section_contents(sections, parent_path=""):
"""递归收集章节内容,跳过参考文献部分"""
for i, section in enumerate(sections):
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
# 检查是否为参考文献部分,如果是则跳过
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
continue # 跳过参考文献部分
# 只处理内容非空的章节
if section.content and section.content.strip():
# 使用增强的分割函数进行更细致的分割
fragments = self._breakdown_section_content(section.content)
for fragment_idx, fragment_content in enumerate(fragments):
if fragment_content.strip():
fragment_index = len(sections_to_process)
sections_to_process.append(TextFragment(
content=fragment_content,
fragment_index=fragment_index,
total_fragments=0 # 临时值,稍后更新
))
# 保存映射关系,用于稍后更新章节内容
# 为每个片段存储原始章节和片段索引信息
section_map[fragment_index] = (current_path, section, fragment_idx, len(fragments))
# 递归处理子章节
if section.subsections:
collect_section_contents(section.subsections, current_path)
# 收集所有章节内容
collect_section_contents(paper.sections)
# 更新总片段数
total_fragments = len(sections_to_process)
for frag in sections_to_process:
frag.total_fragments = total_fragments
# 4. 如果没有内容需要处理,直接返回
if not sections_to_process:
self.chatbot.append(["处理完成", "未找到需要处理的内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 5. 批量处理章节内容
self.chatbot[-1] = ["开始处理论文内容", f"{len(sections_to_process)} 个内容片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 一次性准备所有输入
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(sections_to_process)
# 使用系统提示
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(sections_to_process)
# 调用LLM一次性处理所有片段
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应,重组章节内容
section_contents = {} # 用于重组各章节的处理后内容
for j, frag in enumerate(sections_to_process):
try:
llm_response = response_collection[j * 2 + 1]
processed_text = self._extract_decision(llm_response)
if processed_text and processed_text.strip():
# 保存处理结果
self.processed_results.append({
'index': frag.fragment_index,
'content': processed_text
})
# 存储处理后的文本片段,用于后续重组
fragment_index = frag.fragment_index
if fragment_index in section_map:
path, section, fragment_idx, total_fragments = section_map[fragment_index]
# 初始化此章节的内容容器(如果尚未创建)
if path not in section_contents:
section_contents[path] = [""] * total_fragments
# 将处理后的片段放入正确位置
section_contents[path][fragment_idx] = processed_text
else:
self.failed_fragments.append(frag)
except Exception as e:
self.failed_fragments.append(frag)
# 重组每个章节的内容
for path, fragments in section_contents.items():
section = None
for idx in section_map:
if section_map[idx][0] == path:
section = section_map[idx][1]
break
if section:
# 合并该章节的所有处理后片段
section.content = "\n".join(fragments)
# 6. 更新UI
success_count = total_fragments - len(self.failed_fragments)
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 收集参考文献部分(不进行处理)
references_sections = []
def collect_references(sections, parent_path=""):
"""递归收集参考文献部分"""
for i, section in enumerate(sections):
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
# 检查是否为参考文献部分
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
references_sections.append((current_path, section))
# 递归检查子章节
if section.subsections:
collect_references(section.subsections, current_path)
# 收集参考文献
collect_references(paper.sections)
# 7. 将处理后的结构化论文转换为Markdown
markdown_content = self.paper_extractor.generate_markdown(paper)
# 8. 返回处理后的内容
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段,参考文献部分未处理"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return markdown_content
except Exception as e:
self.chatbot.append(["结构化处理失败", f"错误: {str(e)},将尝试作为普通文件处理"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return (yield from self._process_regular_file(file_path))
def _process_regular_file(self, file_path: str) -> Generator:
"""使用原有方式处理普通文件"""
# 原有的文件处理逻辑
self.chatbot[-1] = ["正在读取文件", f"文件路径: {file_path}"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
content = extract_text(file_path)
if not content or not content.strip():
self.chatbot.append(["处理失败", "文件内容为空或无法提取内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 2. 分割文本
self.chatbot[-1] = ["正在分析文件", "将文件内容分割为适当大小的片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用增强的分割函数
fragments = self._breakdown_section_content(content)
# 3. 创建文本片段对象
text_fragments = []
for i, frag in enumerate(fragments):
if frag.strip():
text_fragments.append(TextFragment(
content=frag,
fragment_index=i,
total_fragments=len(fragments)
))
# 4. 处理所有片段
self.chatbot[-1] = ["开始处理文本", f"{len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 批量处理片段
batch_size = 8 # 每批处理的片段数
for i in range(0, len(text_fragments), batch_size):
batch = text_fragments[i:i + batch_size]
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
# 使用系统提示
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下文本")
sys_prompt_array = [f"你是一个专业的文本处理助手。请按照用户的要求:'{instruction}'处理文本。"] * len(batch)
# 调用LLM处理
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应
for j, frag in enumerate(batch):
try:
llm_response = response_collection[j * 2 + 1]
processed_text = self._extract_decision(llm_response)
if processed_text and processed_text.strip():
self.processed_results.append({
'index': frag.fragment_index,
'content': processed_text
})
else:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content # 如果处理失败,使用原始内容
})
except Exception as e:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content # 如果处理失败,使用原始内容
})
# 5. 按原始顺序合并结果
self.processed_results.sort(key=lambda x: x['index'])
final_content = "\n".join([item['content'] for item in self.processed_results])
# 6. 更新UI
success_count = len(text_fragments) - len(self.failed_fragments)
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return final_content
def save_results(self, content: str, original_file_path: str) -> List[str]:
"""保存处理结果为多种格式"""
if not content:
return []
timestamp = time.strftime("%Y%m%d_%H%M%S")
original_filename = os.path.basename(original_file_path)
filename_without_ext = os.path.splitext(original_filename)[0]
base_filename = f"{filename_without_ext}_processed_{timestamp}"
result_files = []
# 获取用户指定的处理类型
processing_type = self.plugin_kwargs.get("advanced_arg", "文本处理")
# 1. 保存为TXT
try:
txt_formatter = TxtFormatter()
txt_content = txt_formatter.create_document(content)
txt_file = write_history_to_file(
history=[txt_content],
file_basename=f"{base_filename}.txt"
)
result_files.append(txt_file)
except Exception as e:
self.chatbot.append(["警告", f"TXT格式保存失败: {str(e)}"])
# 2. 保存为Markdown
try:
md_formatter = MarkdownFormatter()
md_content = md_formatter.create_document(content, processing_type)
md_file = write_history_to_file(
history=[md_content],
file_basename=f"{base_filename}.md"
)
result_files.append(md_file)
except Exception as e:
self.chatbot.append(["警告", f"Markdown格式保存失败: {str(e)}"])
# 3. 保存为HTML
try:
html_formatter = HtmlFormatter(processing_type=processing_type)
html_content = html_formatter.create_document(content)
html_file = write_history_to_file(
history=[html_content],
file_basename=f"{base_filename}.html"
)
result_files.append(html_file)
except Exception as e:
self.chatbot.append(["警告", f"HTML格式保存失败: {str(e)}"])
# 4. 保存为Word
try:
word_formatter = WordFormatter()
doc = word_formatter.create_document(content, processing_type)
# 获取保存路径
from toolbox import get_log_folder
word_path = os.path.join(get_log_folder(), f"{base_filename}.docx")
doc.save(word_path)
# 5. 保存为PDF通过Word转换
try:
from crazy_functions.paper_fns.file2file_doc.word2pdf import WordToPdfConverter
pdf_path = WordToPdfConverter.convert_to_pdf(word_path)
result_files.append(pdf_path)
except Exception as e:
self.chatbot.append(["警告", f"PDF格式保存失败: {str(e)}"])
except Exception as e:
self.chatbot.append(["警告", f"Word格式保存失败: {str(e)}"])
# 添加到下载区
for file in result_files:
promote_file_to_downloadzone(file, chatbot=self.chatbot)
return result_files
def _breakdown_section_content(self, content: str) -> List[str]:
"""对文本内容进行分割与合并
主要按段落进行组织,只合并较小的段落以减少片段数量
保留原始段落结构,不对长段落进行强制分割
针对中英文设置不同的阈值,因为字符密度不同
"""
# 先按段落分割文本
paragraphs = content.split('\n\n')
# 检测语言类型
chinese_char_count = sum(1 for char in content if '\u4e00' <= char <= '\u9fff')
is_chinese_text = chinese_char_count / max(1, len(content)) > 0.3
# 根据语言类型设置不同的阈值(只用于合并小段落)
if is_chinese_text:
# 中文文本:一个汉字就是一个字符,信息密度高
min_chunk_size = 300 # 段落合并的最小阈值
target_size = 800 # 理想的段落大小
else:
# 英文文本:一个单词由多个字符组成,信息密度低
min_chunk_size = 600 # 段落合并的最小阈值
target_size = 1600 # 理想的段落大小
# 1. 只合并小段落,不对长段落进行分割
result_fragments = []
current_chunk = []
current_length = 0
for para in paragraphs:
# 如果段落太小且不会超过目标大小,则合并
if len(para) < min_chunk_size and current_length + len(para) <= target_size:
current_chunk.append(para)
current_length += len(para)
# 否则,创建新段落
else:
# 如果当前块非空且与当前段落无关,先保存它
if current_chunk and current_length > 0:
result_fragments.append('\n\n'.join(current_chunk))
# 当前段落作为新块
current_chunk = [para]
current_length = len(para)
# 如果当前块大小已接近目标大小,保存并开始新块
if current_length >= target_size:
result_fragments.append('\n\n'.join(current_chunk))
current_chunk = []
current_length = 0
# 保存最后一个块
if current_chunk:
result_fragments.append('\n\n'.join(current_chunk))
# 2. 处理可能过大的片段确保不超过token限制
final_fragments = []
max_token = self._get_token_limit()
for fragment in result_fragments:
# 检查fragment是否可能超出token限制
# 根据语言类型调整token估算
if is_chinese_text:
estimated_tokens = len(fragment) / 1.5 # 中文每个token约1-2个字符
else:
estimated_tokens = len(fragment) / 4 # 英文每个token约4个字符
if estimated_tokens > max_token:
# 即使可能超出限制,也尽量保持段落的完整性
# 使用breakdown_text但设置更大的限制来减少分割
larger_limit = max_token * 0.95 # 使用95%的限制
sub_fragments = breakdown_text_to_satisfy_token_limit(
txt=fragment,
limit=larger_limit,
llm_model=self.llm_kwargs['llm_model']
)
final_fragments.extend(sub_fragments)
else:
final_fragments.append(fragment)
return final_fragments
@CatchException
def 自定义智能文档处理(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 文件到文件处理"""
# 初始化
processor = DocumentProcessor(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
chatbot.append(["函数插件功能", "文件内容处理:将文档内容按照指定要求处理后输出为多种格式"])
yield from update_ui(chatbot=chatbot, history=history)
# 验证输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析路径: {txt}", b=f"找不到路径或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 验证路径安全性
user_name = chatbot.get_user()
validate_path_safety(txt, user_name)
# 获取文件列表
if os.path.isfile(txt):
# 单个文件处理
file_paths = [txt]
else:
# 目录处理 - 类似批量文件询问插件
project_folder = txt
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
# 排除压缩文件
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
file_paths = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
# 过滤支持的文件格式
file_paths = [f for f in file_paths if any(f.lower().endswith(ext) for ext in
list(processor.paper_extractor.SUPPORTED_EXTENSIONS) + ['.json', '.csv', '.xlsx', '.xls'])]
if not file_paths:
report_exception(chatbot, history, a=f"解析路径: {txt}", b="未找到支持的文件类型")
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理文件
if len(file_paths) > 1:
chatbot.append(["发现多个文件", f"共找到 {len(file_paths)} 个文件,将处理第一个文件"])
yield from update_ui(chatbot=chatbot, history=history)
# 只处理第一个文件
file_to_process = file_paths[0]
processed_content = yield from processor.process_file(file_to_process)
if processed_content:
# 保存结果
result_files = processor.save_results(processed_content, file_to_process)
if result_files:
chatbot.append(["处理完成", f"已生成 {len(result_files)} 个结果文件"])
else:
chatbot.append(["处理完成", "但未能保存任何结果文件"])
else:
chatbot.append(["处理失败", "未能生成有效的处理结果"])
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -7,7 +7,7 @@ from bs4 import BeautifulSoup
from functools import lru_cache
from itertools import zip_longest
from check_proxy import check_proxy
from toolbox import CatchException, update_ui, get_conf, update_ui_lastest_msg
from toolbox import CatchException, update_ui, get_conf, update_ui_latest_msg
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
from request_llms.bridge_all import model_info
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -49,7 +49,7 @@ def search_optimizer(
mutable = ["", time.time(), ""]
llm_kwargs["temperature"] = 0.8
try:
querys_json = predict_no_ui_long_connection(
query_json = predict_no_ui_long_connection(
inputs=query,
llm_kwargs=llm_kwargs,
history=[],
@@ -57,31 +57,31 @@ def search_optimizer(
observe_window=mutable,
)
except Exception:
querys_json = "1234"
query_json = "null"
#* 尝试解码优化后的搜索结果
querys_json = re.sub(r"```json|```", "", querys_json)
query_json = re.sub(r"```json|```", "", query_json)
try:
querys = json.loads(querys_json)
queries = json.loads(query_json)
except Exception:
#* 如果解码失败,降低温度再试一次
try:
llm_kwargs["temperature"] = 0.4
querys_json = predict_no_ui_long_connection(
query_json = predict_no_ui_long_connection(
inputs=query,
llm_kwargs=llm_kwargs,
history=[],
sys_prompt=sys_prompt,
observe_window=mutable,
)
querys_json = re.sub(r"```json|```", "", querys_json)
querys = json.loads(querys_json)
query_json = re.sub(r"```json|```", "", query_json)
queries = json.loads(query_json)
except Exception:
#* 如果再次失败,直接返回原始问题
querys = [query]
queries = [query]
links = []
success = 0
Exceptions = ""
for q in querys:
for q in queries:
try:
link = searxng_request(q, proxies, categories, searxng_url, engines=engines)
if len(link) > 0:
@@ -175,10 +175,17 @@ def scrape_text(url, proxies) -> str:
Returns:
str: The scraped text
"""
from loguru import logger
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.36',
'Content-Type': 'text/plain',
}
# 首先采用Jina进行文本提取
if get_conf("JINA_API_KEY"):
try: return jina_scrape_text(url)
except: logger.debug("Jina API 请求失败,回到旧方法")
try:
response = requests.get(url, headers=headers, proxies=proxies, timeout=8)
if response.encoding == "ISO-8859-1": response.encoding = response.apparent_encoding
@@ -193,21 +200,39 @@ def scrape_text(url, proxies) -> str:
text = "\n".join(chunk for chunk in chunks if chunk)
return text
def jina_scrape_text(url) -> str:
"jina_39727421c8fa4e4fa9bd698e5211feaaDyGeVFESNrRaepWiLT0wmHYJSh-d"
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.36',
'Content-Type': 'text/plain',
"X-Retain-Images": "none",
"Authorization": f'Bearer {get_conf("JINA_API_KEY")}'
}
response = requests.get("https://r.jina.ai/" + url, headers=headers, proxies=None, timeout=8)
if response.status_code != 200:
raise ValueError("Jina API 请求失败,开始尝试旧方法!" + response.text)
if response.encoding == "ISO-8859-1": response.encoding = response.apparent_encoding
result = response.text
result = result.replace("\\[", "[").replace("\\]", "]").replace("\\(", "(").replace("\\)", ")")
return response.text
def internet_search_with_analysis_prompt(prompt, analysis_prompt, llm_kwargs, chatbot):
from toolbox import get_conf
proxies = get_conf('proxies')
categories = 'general'
searxng_url = None # 使用默认的searxng_url
engines = None # 使用默认的搜索引擎
yield from update_ui_lastest_msg(lastmsg=f"检索中: {prompt} ...", chatbot=chatbot, history=[], delay=1)
yield from update_ui_latest_msg(lastmsg=f"检索中: {prompt} ...", chatbot=chatbot, history=[], delay=1)
urls = searxng_request(prompt, proxies, categories, searxng_url, engines=engines)
yield from update_ui_lastest_msg(lastmsg=f"依次访问搜索到的网站 ...", chatbot=chatbot, history=[], delay=1)
yield from update_ui_latest_msg(lastmsg=f"依次访问搜索到的网站 ...", chatbot=chatbot, history=[], delay=1)
if len(urls) == 0:
return None
max_search_result = 5 # 最多收纳多少个网页的结果
history = []
for index, url in enumerate(urls[:max_search_result]):
yield from update_ui_lastest_msg(lastmsg=f"依次访问搜索到的网站: {url['link']} ...", chatbot=chatbot, history=[], delay=1)
yield from update_ui_latest_msg(lastmsg=f"依次访问搜索到的网站: {url['link']} ...", chatbot=chatbot, history=[], delay=1)
res = scrape_text(url['link'], proxies)
prefix = f"{index}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
history.extend([prefix, res])
@@ -222,7 +247,7 @@ def internet_search_with_analysis_prompt(prompt, analysis_prompt, llm_kwargs, ch
llm_kwargs=llm_kwargs,
history=history,
sys_prompt="请从搜索结果中抽取信息,对最相关的两个搜索结果进行总结,然后回答问题。",
console_slience=False,
console_silence=False,
)
return gpt_say
@@ -246,23 +271,52 @@ def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
urls = search_optimizer(txt, proxies, optimizer_history, llm_kwargs, optimizer, categories, searxng_url, engines)
history = []
if len(urls) == 0:
chatbot.append((f"结论:{txt}",
"[Local Message] 受到限制,无法从searxng获取信息请尝试更换搜索引擎。"))
chatbot.append((f"结论:{txt}", "[Local Message] 受到限制,无法从searxng获取信息请尝试更换搜索引擎。"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# ------------- < 第2步依次访问网页 > -------------
from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent
max_search_result = 5 # 最多收纳多少个网页的结果
if optimizer == "开启(增强)":
max_search_result = 8
chatbot.append(["联网检索中 ...", None])
for index, url in enumerate(urls[:max_search_result]):
res = scrape_text(url['link'], proxies)
prefix = f"{index}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
history.extend([prefix, res])
res_squeeze = res.replace('\n', '...')
chatbot[-1] = [prefix + "\n\n" + res_squeeze[:500] + "......", None]
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
template = dedent("""
<details>
<summary>{TITLE}</summary>
<div class="search_result">{URL}</div>
<div class="search_result">{CONTENT}</div>
</details>
""")
buffer = ""
# 创建线程池
with ThreadPoolExecutor(max_workers=5) as executor:
# 提交任务到线程池
futures = []
for index, url in enumerate(urls[:max_search_result]):
future = executor.submit(scrape_text, url['link'], proxies)
futures.append((index, future, url))
# 处理完成的任务
for index, future, url in futures:
# 开始
prefix = f"正在加载 第{index+1}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
string_structure = template.format(TITLE=prefix, URL=url['link'], CONTENT="正在加载,请稍后 ......")
yield from update_ui_latest_msg(lastmsg=(buffer + string_structure), chatbot=chatbot, history=history, delay=0.1) # 刷新界面
# 获取结果
res = future.result()
# 显示结果
prefix = f"{index+1}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
string_structure = template.format(TITLE=prefix, URL=url['link'], CONTENT=res[:1000] + "......")
buffer += string_structure
# 更新历史
history.extend([prefix, res])
yield from update_ui_latest_msg(lastmsg=buffer, chatbot=chatbot, history=history, delay=0.1) # 刷新界面
# ------------- < 第3步ChatGPT综合 > -------------
if (optimizer != "开启(增强)"):

查看文件

@@ -38,11 +38,12 @@ class NetworkGPT_Wrap(GptAcademicPluginTemplate):
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
def execute(txt, llm_kwargs, plugin_kwargs:dict, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
if plugin_kwargs["categories"] == "网页": plugin_kwargs["categories"] = "general"
if plugin_kwargs["categories"] == "学术论文": plugin_kwargs["categories"] = "science"
if plugin_kwargs.get("categories", None) == "网页": plugin_kwargs["categories"] = "general"
elif plugin_kwargs.get("categories", None) == "学术论文": plugin_kwargs["categories"] = "science"
else: plugin_kwargs["categories"] = "general"
yield from 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -1,5 +1,5 @@
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone, check_repeat_upload, map_file_to_sha256
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
from toolbox import CatchException, report_exception, update_ui_latest_msg, zip_result, gen_time_str
from functools import partial
from loguru import logger
@@ -41,7 +41,7 @@ def switch_prompt(pfg, mode, more_requirement):
return inputs_array, sys_prompt_array
def desend_to_extracted_folder_if_exist(project_folder):
def descend_to_extracted_folder_if_exist(project_folder):
"""
Descend into the extracted folder if it exists, otherwise return the original folder.
@@ -130,7 +130,7 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
if not txt.startswith('https://arxiv.org/abs/'):
msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}"
yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
yield from update_ui_latest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
return msg, None
# <-------------- set format ------------->
arxiv_id = url_.split('/abs/')[-1]
@@ -156,16 +156,16 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
return False
if os.path.exists(dst) and allow_cache:
yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
yield from update_ui_latest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
success = True
else:
yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
yield from update_ui_latest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
success = fix_url_and_download()
yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
yield from update_ui_latest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
if not success:
yield from update_ui_lastest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
yield from update_ui_latest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
# <-------------- extract file ------------->
@@ -288,7 +288,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
return
# <-------------- if is a zip/tar file ------------->
project_folder = desend_to_extracted_folder_if_exist(project_folder)
project_folder = descend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
@@ -365,7 +365,7 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
try:
txt, arxiv_id = yield from arxiv_download(chatbot, history, txt, allow_cache)
except tarfile.ReadError as e:
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
"无法自动下载该论文的Latex源码,请前往arxiv打开此论文下载页面,点other Formats,然后download source手动下载latex源码包。接下来调用本地Latex翻译插件即可。",
chatbot=chatbot, history=history)
return
@@ -404,7 +404,7 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
return
# <-------------- if is a zip/tar file ------------->
project_folder = desend_to_extracted_folder_if_exist(project_folder)
project_folder = descend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
@@ -518,7 +518,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
# repeat, project_folder = check_repeat_upload(file_manifest[0], hash_tag)
# if repeat:
# yield from update_ui_lastest_msg(f"发现重复上传,请查收结果(压缩包)...", chatbot=chatbot, history=history)
# yield from update_ui_latest_msg(f"发现重复上传,请查收结果(压缩包)...", chatbot=chatbot, history=history)
# try:
# translate_pdf = [f for f in glob.glob(f'{project_folder}/**/merge_translate_zh.pdf', recursive=True)][0]
# promote_file_to_downloadzone(translate_pdf, rename_file=None, chatbot=chatbot)
@@ -531,7 +531,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
# report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"发现重复上传,但是无法找到相关文件")
# yield from update_ui(chatbot=chatbot, history=history)
# else:
# yield from update_ui_lastest_msg(f"未发现重复上传", chatbot=chatbot, history=history)
# yield from update_ui_latest_msg(f"未发现重复上传", chatbot=chatbot, history=history)
# <-------------- convert pdf into tex ------------->
chatbot.append([f"解析项目: {txt}", "正在将PDF转换为tex项目,请耐心等待..."])
@@ -543,7 +543,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
return False
# <-------------- translate latex file into Chinese ------------->
yield from update_ui_lastest_msg("正在tex项目将翻译为中文...", chatbot=chatbot, history=history)
yield from update_ui_latest_msg("正在tex项目将翻译为中文...", chatbot=chatbot, history=history)
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)]
if len(file_manifest) == 0:
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.tex文件: {txt}")
@@ -551,7 +551,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
return
# <-------------- if is a zip/tar file ------------->
project_folder = desend_to_extracted_folder_if_exist(project_folder)
project_folder = descend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
@@ -571,7 +571,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
switch_prompt=_switch_prompt_)
# <-------------- compile PDF ------------->
yield from update_ui_lastest_msg("正在将翻译好的项目tex项目编译为PDF...", chatbot=chatbot, history=history)
yield from update_ui_latest_msg("正在将翻译好的项目tex项目编译为PDF...", chatbot=chatbot, history=history)
success = yield from 编译Latex(chatbot, history, main_file_original='merge',
main_file_modified='merge_translate_zh', mode='translate_zh',
work_folder_original=project_folder, work_folder_modified=project_folder,

查看文件

@@ -1,5 +1,5 @@
from toolbox import CatchException, check_packages, get_conf
from toolbox import update_ui, update_ui_lastest_msg, disable_auto_promotion
from toolbox import update_ui, update_ui_latest_msg, disable_auto_promotion
from toolbox import trimmed_format_exc_markdown
from crazy_functions.crazy_utils import get_files_from_everything
from crazy_functions.pdf_fns.parse_pdf import get_avail_grobid_url
@@ -57,9 +57,9 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
yield from 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, grobid_url)
return
if method == "ClASSIC":
if method == "Classic":
# ------- 第三种方法,早期代码,效果不理想 -------
yield from update_ui_lastest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from update_ui_latest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
return
@@ -77,7 +77,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
if grobid_url is not None:
yield from 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, grobid_url)
return
yield from update_ui_lastest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from update_ui_latest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
return

查看文件

@@ -19,7 +19,7 @@ class PDF_Tran(GptAcademicPluginTemplate):
"additional_prompt":
ArgProperty(title="额外提示词", description="例如:对专有名词、翻译语气等方面的要求", default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
"pdf_parse_method":
ArgProperty(title="PDF解析方法", options=["DOC2X", "GROBID", "ClASSIC"], description="", default_value="GROBID", type="dropdown").model_dump_json(),
ArgProperty(title="PDF解析方法", options=["DOC2X", "GROBID", "Classic"], description="", default_value="GROBID", type="dropdown").model_dump_json(),
}
return gui_definition

查看文件

@@ -0,0 +1,360 @@
import os
import time
import glob
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Generator, Tuple
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from toolbox import update_ui, promote_file_to_downloadzone, write_history_to_file, CatchException, report_exception
from shared_utils.fastapi_server import validate_path_safety
from crazy_functions.paper_fns.paper_download import extract_paper_id, extract_paper_ids, get_arxiv_paper, format_arxiv_id
@dataclass
class PaperQuestion:
"""论文分析问题类"""
id: str # 问题ID
question: str # 问题内容
importance: int # 重要性 (1-5,5最高)
description: str # 问题描述
class PaperAnalyzer:
"""论文快速分析器"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化分析器"""
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.paper_content = ""
self.results = {}
# 定义论文分析问题库已合并为4个核心问题
self.questions = [
PaperQuestion(
id="research_and_methods",
question="这篇论文的主要研究问题、目标和方法是什么?请分析1)论文的核心研究问题和研究动机;2)论文提出的关键方法、模型或理论框架;3)这些方法如何解决研究问题。",
importance=5,
description="研究问题与方法"
),
PaperQuestion(
id="findings_and_innovation",
question="论文的主要发现、结论及创新点是什么?请分析1)论文的核心结果与主要发现;2)作者得出的关键结论;3)研究的创新点与对领域的贡献;4)与已有工作的区别。",
importance=4,
description="研究发现与创新"
),
PaperQuestion(
id="methodology_and_data",
question="论文使用了什么研究方法和数据?请详细分析1)研究设计与实验设置;2)数据收集方法与数据集特点;3)分析技术与评估方法;4)方法学上的合理性。",
importance=3,
description="研究方法与数据"
),
PaperQuestion(
id="limitations_and_impact",
question="论文的局限性、未来方向及潜在影响是什么?请分析1)研究的不足与限制因素;2)作者提出的未来研究方向;3)该研究对学术界和行业可能产生的影响;4)研究结果的适用范围与推广价值。",
importance=2,
description="局限性与影响"
),
]
# 按重要性排序
self.questions.sort(key=lambda q: q.importance, reverse=True)
def _load_paper(self, paper_path: str) -> Generator:
from crazy_functions.doc_fns.text_content_loader import TextContentLoader
"""加载论文内容"""
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用TextContentLoader读取文件
loader = TextContentLoader(self.chatbot, self.history)
yield from loader.execute_single_file(paper_path)
# 获取加载的内容
if len(self.history) >= 2 and self.history[-2]:
self.paper_content = self.history[-2]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return True
else:
self.chatbot.append(["错误", "无法读取论文内容,请检查文件是否有效"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return False
def _analyze_question(self, question: PaperQuestion) -> Generator:
"""分析单个问题 - 直接显示问题和答案"""
try:
# 创建分析提示
prompt = f"请基于以下论文内容回答问题:\n\n{self.paper_content}\n\n问题:{question.question}"
# 使用单线程版本的请求函数
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt,
inputs_show_user=question.question, # 显示问题本身
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[], # 空历史,确保每个问题独立分析
sys_prompt="你是一个专业的科研论文分析助手,需要仔细阅读论文内容并回答问题。请保持客观、准确,并基于论文内容提供深入分析。"
)
if response:
self.results[question.id] = response
return True
return False
except Exception as e:
self.chatbot.append(["错误", f"分析问题时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return False
def _generate_summary(self) -> Generator:
"""生成最终总结报告"""
self.chatbot.append(["生成报告", "正在整合分析结果,生成最终报告..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
summary_prompt = "请基于以下对论文的各个方面的分析,生成一份全面的论文解读报告。报告应该简明扼要地呈现论文的关键内容,并保持逻辑连贯性。"
for q in self.questions:
if q.id in self.results:
summary_prompt += f"\n\n关于{q.description}的分析:\n{self.results[q.id]}"
try:
# 使用单线程版本的请求函数,可以在前端实时显示生成结果
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=summary_prompt,
inputs_show_user="生成论文解读报告",
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[],
sys_prompt="你是一个科研论文解读专家,请将多个方面的分析整合为一份完整、连贯、有条理的报告。报告应当重点突出,层次分明,并且保持学术性和客观性。"
)
if response:
return response
return "报告生成失败"
except Exception as e:
self.chatbot.append(["错误", f"生成报告时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return "报告生成失败: " + str(e)
def save_report(self, report: str) -> Generator:
"""保存分析报告"""
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 保存为Markdown文件
try:
md_content = f"# 论文快速解读报告\n\n{report}"
for q in self.questions:
if q.id in self.results:
md_content += f"\n\n## {q.description}\n\n{self.results[q.id]}"
result_file = write_history_to_file(
history=[md_content],
file_basename=f"论文解读_{timestamp}.md"
)
if result_file and os.path.exists(result_file):
promote_file_to_downloadzone(result_file, chatbot=self.chatbot)
self.chatbot.append(["保存成功", f"解读报告已保存至: {os.path.basename(result_file)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
else:
self.chatbot.append(["警告", "保存报告成功但找不到文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
except Exception as e:
self.chatbot.append(["警告", f"保存报告失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
def analyze_paper(self, paper_path: str) -> Generator:
"""分析论文主流程"""
# 加载论文
success = yield from self._load_paper(paper_path)
if not success:
return
# 分析关键问题 - 直接询问每个问题,不显示进度信息
for question in self.questions:
yield from self._analyze_question(question)
# 生成总结报告
final_report = yield from self._generate_summary()
# 显示最终报告
# self.chatbot.append(["论文解读报告", final_report])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 保存报告
yield from self.save_report(final_report)
def _find_paper_file(path: str) -> str:
"""查找路径中的论文文件(简化版)"""
if os.path.isfile(path):
return path
# 支持的文件扩展名(按优先级排序)
extensions = ["pdf", "docx", "doc", "txt", "md", "tex"]
# 简单地遍历目录
if os.path.isdir(path):
try:
for ext in extensions:
# 手动检查每个可能的文件,而不使用glob
potential_file = os.path.join(path, f"paper.{ext}")
if os.path.exists(potential_file) and os.path.isfile(potential_file):
return potential_file
# 如果没找到特定命名的文件,检查目录中的所有文件
for file in os.listdir(path):
file_path = os.path.join(path, file)
if os.path.isfile(file_path):
file_ext = file.split('.')[-1].lower() if '.' in file else ""
if file_ext in extensions:
return file_path
except Exception:
pass # 忽略任何错误
return None
def download_paper_by_id(paper_info, chatbot, history) -> str:
"""下载论文并返回保存路径
Args:
paper_info: 元组,包含论文ID类型arxiv或doi和ID值
chatbot: 聊天机器人对象
history: 历史记录
Returns:
str: 下载的论文路径或None
"""
from crazy_functions.review_fns.data_sources.scihub_source import SciHub
id_type, paper_id = paper_info
# 创建保存目录 - 使用时间戳创建唯一文件夹
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
user_name = chatbot.get_user() if hasattr(chatbot, 'get_user') else "default"
from toolbox import get_log_folder, get_user
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = Path(save_dir)
chatbot.append([f"下载论文", f"正在下载{'arXiv' if id_type == 'arxiv' else 'DOI'} {paper_id} 的论文..."])
update_ui(chatbot=chatbot, history=history)
pdf_path = None
try:
if id_type == 'arxiv':
# 使用改进的arxiv查询方法
formatted_id = format_arxiv_id(paper_id)
paper_result = get_arxiv_paper(formatted_id)
if not paper_result:
chatbot.append([f"下载失败", f"未找到arXiv论文: {paper_id}"])
update_ui(chatbot=chatbot, history=history)
return None
# 下载PDF
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
pdf_path = str(save_path / filename)
paper_result.download_pdf(filename=pdf_path)
else: # doi
# 下载DOI
sci_hub = SciHub(
doi=paper_id,
path=save_path
)
pdf_path = sci_hub.fetch()
# 检查下载结果
if pdf_path and os.path.exists(pdf_path):
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
chatbot.append([f"下载成功", f"已成功下载论文: {os.path.basename(pdf_path)}"])
update_ui(chatbot=chatbot, history=history)
return pdf_path
else:
chatbot.append([f"下载失败", f"论文下载失败: {paper_id}"])
update_ui(chatbot=chatbot, history=history)
return None
except Exception as e:
chatbot.append([f"下载错误", f"下载论文时出错: {str(e)}"])
update_ui(chatbot=chatbot, history=history)
return None
@CatchException
def 快速论文解读(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 论文快速解读"""
# 初始化分析器
chatbot.append(["函数插件功能及使用方式", "论文快速解读:通过分析论文的关键要素,帮助您迅速理解论文内容,适用于各学科领域的科研论文。 <br><br>📋 使用方式:<br>1、直接上传PDF文件或者输入DOI号仅针对SCI hub存在的论文或arXiv ID如2501.03916<br>2、点击插件开始分析"])
yield from update_ui(chatbot=chatbot, history=history)
paper_file = None
# 检查输入是否为论文IDarxiv或DOI
paper_info = extract_paper_id(txt)
if paper_info:
# 如果是论文ID,下载论文
chatbot.append(["检测到论文ID", f"检测到{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'} ID: {paper_info[1]},准备下载论文..."])
yield from update_ui(chatbot=chatbot, history=history)
# 下载论文 - 完全重新实现
paper_file = download_paper_by_id(paper_info, chatbot, history)
if not paper_file:
report_exception(chatbot, history, a=f"下载论文失败", b=f"无法下载{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'}论文: {paper_info[1]}")
yield from update_ui(chatbot=chatbot, history=history)
return
else:
# 检查输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析论文: {txt}", b=f"找不到文件或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 验证路径安全性
user_name = chatbot.get_user()
validate_path_safety(txt, user_name)
# 查找论文文件
paper_file = _find_paper_file(txt)
if not paper_file:
report_exception(chatbot, history, a=f"解析论文", b=f"在路径 {txt} 中未找到支持的论文文件")
yield from update_ui(chatbot=chatbot, history=history)
return
yield from update_ui(chatbot=chatbot, history=history)
# 增加调试信息,检查paper_file的类型和值
chatbot.append(["文件类型检查", f"paper_file类型: {type(paper_file)}, 值: {paper_file}"])
yield from update_ui(chatbot=chatbot, history=history)
chatbot.pop() # 移除调试信息
# 确保paper_file是字符串
if paper_file is not None and not isinstance(paper_file, str):
# 尝试转换为字符串
try:
paper_file = str(paper_file)
except:
report_exception(chatbot, history, a=f"类型错误", b=f"论文路径不是有效的字符串: {type(paper_file)}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 分析论文
chatbot.append(["开始分析", f"正在分析论文: {os.path.basename(paper_file)}"])
yield from update_ui(chatbot=chatbot, history=history)
analyzer = PaperAnalyzer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
yield from analyzer.analyze_paper(paper_file)

查看文件

@@ -4,7 +4,7 @@ from typing import List
from shared_utils.fastapi_server import validate_path_safety
from toolbox import report_exception
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_latest_msg
from shared_utils.fastapi_server import validate_path_safety
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
@@ -92,7 +92,7 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
rag_worker.purge_vector_store()
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
yield from update_ui_latest_msg('已清空', chatbot, history, delay=0) # 刷新界面
return
# 3. Normal Q&A processing
@@ -109,10 +109,10 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
# 5. If input is clipped, add input to vector store before retrieve
if input_is_clipped_flag:
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
yield from update_ui_latest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
# Save input to vector store
rag_worker.add_text_to_vector_store(txt_origin)
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
yield from update_ui_latest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
if len(txt_origin) > REMEMBER_PREVIEW:
HALF = REMEMBER_PREVIEW // 2
@@ -142,7 +142,7 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
)
# 8. Remember Q&A
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...',
chatbot, history, delay=0.5
)
@@ -150,4 +150,4 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
history.extend([i_say, model_say])
# 9. Final UI Update
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip)
yield from update_ui_latest_msg(model_say, chatbot, history, delay=0, msg=tip)

查看文件

@@ -1,5 +1,5 @@
import pickle, os, random
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_latest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -9,7 +9,7 @@ from loguru import logger
from typing import List
SOCIAL_NETWOK_WORKER_REGISTER = {}
SOCIAL_NETWORK_WORKER_REGISTER = {}
class SocialNetwork():
def __init__(self):
@@ -78,7 +78,7 @@ class SocialNetworkWorker(SaveAndLoad):
for f in friend.friends_list:
self.add_friend(f)
msg = f"成功添加{len(friend.friends_list)}个联系人: {str(friend.friends_list)}"
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0)
yield from update_ui_latest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0)
def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
@@ -104,12 +104,12 @@ class SocialNetworkWorker(SaveAndLoad):
}
try:
Explaination = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()])
Explanation = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()])
class UserSociaIntention(BaseModel):
intention_type: str = Field(
description=
f"The type of user intention. You must choose from {self.tools_to_select.keys()}.\n\n"
f"Explaination:\n{Explaination}",
f"Explanation:\n{Explanation}",
default="SocialAdvice"
)
pydantic_cls_instance, err_msg = select_tool(
@@ -118,7 +118,7 @@ class SocialNetworkWorker(SaveAndLoad):
pydantic_cls=UserSociaIntention
)
except Exception as e:
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
lastmsg=f"无法理解用户意图 {err_msg}",
chatbot=chatbot,
history=history,
@@ -150,10 +150,10 @@ def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt,
# 1. we retrieve worker from global context
user_name = chatbot.get_user()
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
if user_name in SOCIAL_NETWORK_WORKER_REGISTER:
social_network_worker = SOCIAL_NETWORK_WORKER_REGISTER[user_name]
else:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
social_network_worker = SOCIAL_NETWORK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
user_name,
llm_kwargs,
checkpoint_dir=checkpoint_dir,

查看文件

@@ -1,5 +1,5 @@
import os, copy, time
from toolbox import CatchException, report_exception, update_ui, zip_result, promote_file_to_downloadzone, update_ui_lastest_msg, get_conf, generate_file_link
from toolbox import CatchException, report_exception, update_ui, zip_result, promote_file_to_downloadzone, update_ui_latest_msg, get_conf, generate_file_link
from shared_utils.fastapi_server import validate_path_safety
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
@@ -117,7 +117,7 @@ def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
logger.error(f"文件: {fp} 的注释结果未能成功")
file_links = generate_file_link(preview_html_list)
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
f"当前任务: <br/>{'<br/>'.join(tasks)}.<br/>" +
f"剩余源文件数量: {remain}.<br/>" +
f"已完成的文件: {sum(worker_done)}.<br/>" +

查看文件

@@ -7,7 +7,7 @@ from bs4 import BeautifulSoup
from functools import lru_cache
from itertools import zip_longest
from check_proxy import check_proxy
from toolbox import CatchException, update_ui, get_conf, promote_file_to_downloadzone, update_ui_lastest_msg, generate_file_link
from toolbox import CatchException, update_ui, get_conf, promote_file_to_downloadzone, update_ui_latest_msg, generate_file_link
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
from request_llms.bridge_all import model_info
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -46,7 +46,7 @@ def download_video(bvid, user_name, chatbot, history):
# pause a while
tic_time = 8
for i in range(tic_time):
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
lastmsg=f"即将下载音频。等待{tic_time-i}秒后自动继续, 点击“停止”键取消此操作。",
chatbot=chatbot, history=[], delay=1)
@@ -61,13 +61,13 @@ def download_video(bvid, user_name, chatbot, history):
# preview
preview_list = [promote_file_to_downloadzone(fp) for fp in downloaded_files]
file_links = generate_file_link(preview_list)
yield from update_ui_lastest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
yield from update_ui_latest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
chatbot.append((None, f"即将下载视频。"))
# pause a while
tic_time = 16
for i in range(tic_time):
yield from update_ui_lastest_msg(
yield from update_ui_latest_msg(
lastmsg=f"即将下载视频。等待{tic_time-i}秒后自动继续, 点击“停止”键取消此操作。",
chatbot=chatbot, history=[], delay=1)
@@ -78,7 +78,7 @@ def download_video(bvid, user_name, chatbot, history):
# preview
preview_list = [promote_file_to_downloadzone(fp) for fp in downloaded_files_part2]
file_links = generate_file_link(preview_list)
yield from update_ui_lastest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
yield from update_ui_latest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
# return
return downloaded_files + downloaded_files_part2
@@ -110,7 +110,7 @@ def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
# 结构化生成
internet_search_keyword = user_wish
yield from update_ui_lastest_msg(lastmsg=f"发起互联网检索: {internet_search_keyword} ...", chatbot=chatbot, history=[], delay=1)
yield from update_ui_latest_msg(lastmsg=f"发起互联网检索: {internet_search_keyword} ...", chatbot=chatbot, history=[], delay=1)
from crazy_functions.Internet_GPT import internet_search_with_analysis_prompt
result = yield from internet_search_with_analysis_prompt(
prompt=internet_search_keyword,
@@ -119,7 +119,7 @@ def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
chatbot=chatbot
)
yield from update_ui_lastest_msg(lastmsg=f"互联网检索结论: {result} \n\n 正在生成进一步检索方案 ...", chatbot=chatbot, history=[], delay=1)
yield from update_ui_latest_msg(lastmsg=f"互联网检索结论: {result} \n\n 正在生成进一步检索方案 ...", chatbot=chatbot, history=[], delay=1)
rf_req = dedent(f"""
The user wish to get the following resource:
{user_wish}
@@ -132,7 +132,7 @@ def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
rf_req = dedent(f"""
The user wish to get the following resource:
{user_wish}
Generate reseach keywords (less than 5 keywords) accordingly.
Generate research keywords (less than 5 keywords) accordingly.
""")
gpt_json_io = GptJsonIO(Query)
inputs = rf_req + gpt_json_io.format_instructions
@@ -146,12 +146,12 @@ def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 获取候选资源
candadate_dictionary: dict = get_video_resource(video_engine_keywords)
candadate_dictionary_as_str = json.dumps(candadate_dictionary, ensure_ascii=False, indent=4)
candidate_dictionary: dict = get_video_resource(video_engine_keywords)
candidate_dictionary_as_str = json.dumps(candidate_dictionary, ensure_ascii=False, indent=4)
# 展示候选资源
candadate_display = "\n".join([f"{i+1}. {it['title']}" for i, it in enumerate(candadate_dictionary)])
chatbot.append((None, f"候选:\n\n{candadate_display}"))
candidate_display = "\n".join([f"{i+1}. {it['title']}" for i, it in enumerate(candidate_dictionary)])
chatbot.append((None, f"候选:\n\n{candidate_display}"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 结构化生成
@@ -160,7 +160,7 @@ def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
{user_wish}
Select the most relevant and suitable video resource from the following search results:
{candadate_dictionary_as_str}
{candidate_dictionary_as_str}
Note:
1. The first several search video results are more likely to satisfy the user's wish.

查看文件

@@ -1,5 +1,5 @@
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, ProxyNetworkActivate
from toolbox import report_exception, get_log_folder, update_ui_lastest_msg, Singleton
from toolbox import report_exception, get_log_folder, update_ui_latest_msg, Singleton
from crazy_functions.agent_fns.pipe import PluginMultiprocessManager, PipeCom
from crazy_functions.agent_fns.general import AutoGenGeneral

查看文件

@@ -8,7 +8,7 @@ class EchoDemo(PluginMultiprocessManager):
while True:
msg = self.child_conn.recv() # PipeCom
if msg.cmd == "user_input":
# wait futher user input
# wait father user input
self.child_conn.send(PipeCom("show", msg.content))
wait_success = self.subprocess_worker_wait_user_feedback(wait_msg="我准备好处理下一个问题了.")
if not wait_success:

查看文件

@@ -27,7 +27,7 @@ def gpt_academic_generate_oai_reply(
llm_kwargs=llm_config,
history=history,
sys_prompt=self._oai_system_message[0]['content'],
console_slience=True
console_silence=True
)
assumed_done = reply.endswith('\nTERMINATE')
return True, reply

查看文件

@@ -10,7 +10,7 @@ from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_
# TODO: 解决缩进问题
find_function_end_prompt = '''
Below is a page of code that you need to read. This page may not yet complete, you job is to split this page to sperate functions, class functions etc.
Below is a page of code that you need to read. This page may not yet complete, you job is to split this page to separate functions, class functions etc.
- Provide the line number where the first visible function ends.
- Provide the line number where the next visible function begins.
- If there are no other functions in this page, you should simply return the line number of the last line.
@@ -59,7 +59,7 @@ OUTPUT:
revise_funtion_prompt = '''
revise_function_prompt = '''
You need to read the following code, and revise the source code ({FILE_BASENAME}) according to following instructions:
1. You should analyze the purpose of the functions (if there are any).
2. You need to add docstring for the provided functions (if there are any).
@@ -117,7 +117,7 @@ def zip_result(folder):
'''
revise_funtion_prompt_chinese = '''
revise_function_prompt_chinese = '''
您需要阅读以下代码,并根据以下说明修订源代码({FILE_BASENAME}):
1. 如果源代码中包含函数的话, 你应该分析给定函数实现了什么功能
2. 如果源代码中包含函数的话, 你需要为函数添加docstring, docstring必须使用中文
@@ -188,9 +188,9 @@ class PythonCodeComment():
self.language = language
self.observe_window_update = observe_window_update
if self.language == "chinese":
self.core_prompt = revise_funtion_prompt_chinese
self.core_prompt = revise_function_prompt_chinese
else:
self.core_prompt = revise_funtion_prompt
self.core_prompt = revise_function_prompt
self.path = None
self.file_basename = None
self.file_brief = ""
@@ -222,7 +222,7 @@ class PythonCodeComment():
history=[],
sys_prompt="",
observe_window=[],
console_slience=True
console_silence=True
)
def extract_number(text):
@@ -316,7 +316,7 @@ class PythonCodeComment():
def tag_code(self, fn, hint):
code = fn
_, n_indent = self.dedent(code)
indent_reminder = "" if n_indent == 0 else "(Reminder: as you can see, this piece of code has indent made up with {n_indent} whitespace, please preseve them in the OUTPUT.)"
indent_reminder = "" if n_indent == 0 else "(Reminder: as you can see, this piece of code has indent made up with {n_indent} whitespace, please preserve them in the OUTPUT.)"
brief_reminder = "" if self.file_brief == "" else f"({self.file_basename} abstract: {self.file_brief})"
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
self.llm_kwargs['temperature'] = 0
@@ -333,7 +333,7 @@ class PythonCodeComment():
history=[],
sys_prompt="",
observe_window=[],
console_slience=True
console_silence=True
)
def get_code_block(reply):
@@ -400,7 +400,7 @@ class PythonCodeComment():
return revised
def begin_comment_source_code(self, chatbot=None, history=None):
# from toolbox import update_ui_lastest_msg
# from toolbox import update_ui_latest_msg
assert self.path is not None
assert '.py' in self.path # must be python source code
# write_target = self.path + '.revised.py'
@@ -409,10 +409,10 @@ class PythonCodeComment():
# with open(self.path + '.revised.py', 'w+', encoding='utf8') as f:
while True:
try:
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
# yield from update_ui_latest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
next_batch, line_no_start, line_no_end = self.get_next_batch()
self.observe_window_update(f"正在处理{self.file_basename} - {line_no_start}/{len(self.full_context)}\n")
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
# yield from update_ui_latest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
hint = None
MAX_ATTEMPT = 2

查看文件

@@ -1,7 +1,7 @@
import os
import threading
from loguru import logger
from shared_utils.char_visual_effect import scolling_visual_effect
from shared_utils.char_visual_effect import scrolling_visual_effect
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
@@ -256,7 +256,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
# 【第一种情况】:顺利完成
gpt_say = predict_no_ui_long_connection(
inputs=inputs, llm_kwargs=llm_kwargs, history=history,
sys_prompt=sys_prompt, observe_window=mutable[index], console_slience=True
sys_prompt=sys_prompt, observe_window=mutable[index], console_silence=True
)
mutable[index][2] = "已成功"
return gpt_say
@@ -326,7 +326,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
mutable[thread_index][1] = time.time()
# 在前端打印些好玩的东西
for thread_index, _ in enumerate(worker_done):
print_something_really_funny = f"[ ...`{scolling_visual_effect(mutable[thread_index][0], scroller_max_len)}`... ]"
print_something_really_funny = f"[ ...`{scrolling_visual_effect(mutable[thread_index][0], scroller_max_len)}`... ]"
observe_win.append(print_something_really_funny)
# 在前端打印些好玩的东西
stat_str = ''.join([f'`{mutable[thread_index][2]}`: {obs}\n\n'
@@ -389,11 +389,11 @@ def read_and_clean_pdf_text(fp):
"""
提取文本块主字体
"""
fsize_statiscs = {}
fsize_statistics = {}
for wtf in l['spans']:
if wtf['size'] not in fsize_statiscs: fsize_statiscs[wtf['size']] = 0
fsize_statiscs[wtf['size']] += len(wtf['text'])
return max(fsize_statiscs, key=fsize_statiscs.get)
if wtf['size'] not in fsize_statistics: fsize_statistics[wtf['size']] = 0
fsize_statistics[wtf['size']] += len(wtf['text'])
return max(fsize_statistics, key=fsize_statistics.get)
def ffsize_same(a,b):
"""
@@ -433,11 +433,11 @@ def read_and_clean_pdf_text(fp):
############################## <第 2 步,获取正文主字体> ##################################
try:
fsize_statiscs = {}
fsize_statistics = {}
for span in meta_span:
if span[1] not in fsize_statiscs: fsize_statiscs[span[1]] = 0
fsize_statiscs[span[1]] += span[2]
main_fsize = max(fsize_statiscs, key=fsize_statiscs.get)
if span[1] not in fsize_statistics: fsize_statistics[span[1]] = 0
fsize_statistics[span[1]] += span[2]
main_fsize = max(fsize_statistics, key=fsize_statistics.get)
if REMOVE_FOOT_NOTE:
give_up_fize_threshold = main_fsize * REMOVE_FOOT_FFSIZE_PERCENT
except:
@@ -610,9 +610,9 @@ class nougat_interface():
def NOUGAT_parse_pdf(self, fp, chatbot, history):
from toolbox import update_ui_lastest_msg
from toolbox import update_ui_latest_msg
yield from update_ui_lastest_msg("正在解析论文, 请稍候。进度:正在排队, 等待线程锁...",
yield from update_ui_latest_msg("正在解析论文, 请稍候。进度:正在排队, 等待线程锁...",
chatbot=chatbot, history=history, delay=0)
self.threadLock.acquire()
import glob, threading, os
@@ -620,7 +620,7 @@ class nougat_interface():
dst = os.path.join(get_log_folder(plugin_name='nougat'), gen_time_str())
os.makedirs(dst)
yield from update_ui_lastest_msg("正在解析论文, 请稍候。进度正在加载NOUGAT... 提示首次运行需要花费较长时间下载NOUGAT参数",
yield from update_ui_latest_msg("正在解析论文, 请稍候。进度正在加载NOUGAT... 提示首次运行需要花费较长时间下载NOUGAT参数",
chatbot=chatbot, history=history, delay=0)
command = ['nougat', '--out', os.path.abspath(dst), os.path.abspath(fp)]
self.nougat_with_timeout(command, cwd=os.getcwd(), timeout=3600)

查看文件

@@ -0,0 +1,812 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from docx import Document
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.oxml.ns import qn
from docx.shared import Inches, Cm
from docx.shared import Pt, RGBColor, Inches
from typing import Dict, List, Tuple
import markdown
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
class DocumentFormatter(ABC):
"""文档格式化基类,定义文档格式化的基本接口"""
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
self.final_summary = final_summary
self.file_summaries_map = file_summaries_map
self.failed_files = failed_files
@abstractmethod
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
pass
@abstractmethod
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
pass
@abstractmethod
def create_document(self) -> str:
"""创建完整文档"""
pass
class WordFormatter(DocumentFormatter):
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.doc = Document()
self._setup_document()
self._create_styles()
# 初始化三级标题编号系统
self.numbers = {
1: 0, # 一级标题编号
2: 0, # 二级标题编号
3: 0 # 三级标题编号
}
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("该文档由GPT-academic生成")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(14)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
style.paragraph_format.first_line_indent = Pt(28)
# 创建各级标题样式
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
"""创建标题样式"""
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
style.font.name = font_name
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
style.font.size = Pt(font_size)
style.font.bold = True
style.paragraph_format.alignment = alignment
style.paragraph_format.space_before = Pt(12)
style.paragraph_format.space_after = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
return style
def _get_heading_number(self, level: int) -> str:
"""
生成标题编号
Args:
level: 标题级别 (0-3)
Returns:
str: 格式化的标题编号
"""
if level == 0: # 主标题不需要编号
return ""
self.numbers[level] += 1 # 增加当前级别的编号
# 重置下级标题编号
for i in range(level + 1, 4):
self.numbers[i] = 0
# 根据级别返回不同格式的编号
if level == 1:
return f"{self.numbers[1]}. "
elif level == 2:
return f"{self.numbers[1]}.{self.numbers[2]} "
elif level == 3:
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
return ""
def _add_heading(self, text: str, level: int):
"""
添加带编号的标题
Args:
text: 标题文本
level: 标题级别 (0-3)
"""
style_map = {
0: 'Title_Custom',
1: 'Heading1_Custom',
2: 'Heading2_Custom',
3: 'Heading3_Custom'
}
number = self._get_heading_number(level)
paragraph = self.doc.add_paragraph(style=style_map[level])
if number:
number_run = paragraph.add_run(number)
font_size = 22 if level == 1 else (18 if level == 2 else 16)
self._get_run_style(number_run, '黑体', font_size, True)
text_run = paragraph.add_run(text)
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
self._get_run_style(text_run, '黑体', font_size, True)
# 主标题添加日期
if level == 0:
date_paragraph = self.doc.add_paragraph()
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d'))
self._get_run_style(date_run, '仿宋', 16, False)
return paragraph
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
"""设置文本运行对象的样式"""
run.font.name = font_name
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
run.font.size = Pt(font_size)
run.font.bold = bold
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
result = []
if not self.failed_files:
return "\n".join(result)
result.append("处理失败文件:")
for fp, reason in self.failed_files:
result.append(f"{os.path.basename(fp)}: {reason}")
self._add_heading("处理失败文件", 1)
for fp, reason in self.failed_files:
self._add_content(f"{os.path.basename(fp)}: {reason}", indent=False)
self.doc.add_paragraph()
return "\n".join(result)
def _add_content(self, text: str, indent: bool = True):
"""添加正文内容,使用convert_markdown_to_word处理文本"""
# 使用convert_markdown_to_word处理markdown文本
processed_text = convert_markdown_to_word(text)
paragraph = self.doc.add_paragraph(processed_text, style='Normal_Custom')
if not indent:
paragraph.paragraph_format.first_line_indent = Pt(0)
return paragraph
def format_file_summaries(self) -> str:
"""
格式化文件总结内容,确保正确的标题层级并处理markdown文本
"""
result = []
# 首先对文件路径进行分组整理
file_groups = {}
for path in sorted(self.file_summaries_map.keys()):
dir_path = os.path.dirname(path)
if dir_path not in file_groups:
file_groups[dir_path] = []
file_groups[dir_path].append(path)
# 处理没有目录的文件
root_files = file_groups.get("", [])
if root_files:
for path in sorted(root_files):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 无目录的文件作为二级标题
self._add_heading(f"📄 {file_name}", 2)
# 使用convert_markdown_to_word处理文件内容
self._add_content(convert_markdown_to_word(self.file_summaries_map[path]))
self.doc.add_paragraph()
# 处理有目录的文件
for dir_path in sorted(file_groups.keys()):
if dir_path == "": # 跳过已处理的根目录文件
continue
# 添加目录作为二级标题
result.append(f"\n📁 {dir_path}")
self._add_heading(f"📁 {dir_path}", 2)
# 该目录下的所有文件作为三级标题
for path in sorted(file_groups[dir_path]):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 添加文件名作为三级标题
self._add_heading(f"📄 {file_name}", 3)
# 使用convert_markdown_to_word处理文件内容
self._add_content(convert_markdown_to_word(self.file_summaries_map[path]))
self.doc.add_paragraph()
return "\n".join(result)
def create_document(self):
"""创建完整Word文档并返回文档对象"""
# 重置所有编号
for level in self.numbers:
self.numbers[level] = 0
# 添加主标题
self._add_heading("文档总结报告", 0)
self.doc.add_paragraph()
# 添加总体摘要,使用convert_markdown_to_word处理
self._add_heading("总体摘要", 1)
self._add_content(convert_markdown_to_word(self.final_summary))
self.doc.add_paragraph()
# 添加失败文件列表(如果有)
if self.failed_files:
self.format_failed_files()
# 添加文件详细总结
self._add_heading("各文件详细总结", 1)
self.format_file_summaries()
return self.doc
def save_as_pdf(self, word_path, pdf_path=None):
"""将生成的Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径,如果转换失败则返回None
"""
from crazy_functions.doc_fns.conversation_doc.word2pdf import WordToPdfConverter
try:
pdf_path = WordToPdfConverter.convert_to_pdf(word_path, pdf_path)
return pdf_path
except Exception as e:
print(f"PDF转换失败: {str(e)}")
return None
class MarkdownFormatter(DocumentFormatter):
"""Markdown格式文档生成器"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
formatted_text = ["\n## ⚠️ 处理失败的文件"]
for fp, reason in self.failed_files:
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
formatted_text.append("\n---")
return "\n".join(formatted_text)
def format_file_summaries(self) -> str:
formatted_text = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_text.append(f"\n## 📁 {dir_path}")
current_dir = dir_path
file_name = os.path.basename(path)
formatted_text.append(f"\n### 📄 {file_name}")
formatted_text.append(self.file_summaries_map[path])
formatted_text.append("\n---")
return "\n".join(formatted_text)
def create_document(self) -> str:
document = [
"# 📑 文档总结报告",
"\n## 总体摘要",
self.final_summary
]
if self.failed_files:
document.append(self.format_failed_files())
document.extend([
"\n# 📚 各文件详细总结",
self.format_file_summaries()
])
return "\n".join(document)
class HtmlFormatter(DocumentFormatter):
"""HTML格式文档生成器 - 优化版"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.md = markdown.Markdown(extensions=['extra','codehilite', 'tables','nl2br'])
self.css_styles = """
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes slideIn {
from { transform: translateX(-20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
}
:root {
/* Enhanced color palette */
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--text-light: #64748b;
--border-color: #e2e8f0;
--error-color: #ef4444;
--error-light: #fef2f2;
--success-color: #22c55e;
--warning-color: #f59e0b;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
--hover-shadow: 0 20px 25px -5px rgb(0 0 0 / 0.1), 0 8px 10px -6px rgb(0 0 0 / 0.1);
/* Typography */
--heading-font: "Plus Jakarta Sans", system-ui, sans-serif;
--body-font: "Inter", system-ui, sans-serif;
}
body {
font-family: var(--body-font);
line-height: 1.8;
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
font-size: 16px;
-webkit-font-smoothing: antialiased;
}
.container {
background: white;
padding: 3rem;
border-radius: 24px;
box-shadow: var(--card-shadow);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
animation: fadeIn 0.6s ease-out;
border: 1px solid var(--border-color);
}
.container:hover {
box-shadow: var(--hover-shadow);
transform: translateY(-2px);
}
h1, h2, h3 {
font-family: var(--heading-font);
font-weight: 600;
}
h1 {
color: var(--primary-color);
font-size: 2.8em;
text-align: center;
margin: 2rem 0 3rem;
padding-bottom: 1.5rem;
border-bottom: 3px solid var(--primary-color);
letter-spacing: -0.03em;
position: relative;
display: flex;
align-items: center;
justify-content: center;
gap: 1rem;
}
h1::after {
content: '';
position: absolute;
bottom: -3px;
left: 50%;
transform: translateX(-50%);
width: 120px;
height: 3px;
background: linear-gradient(90deg, var(--primary-color), var(--primary-light));
border-radius: 3px;
transition: width 0.3s ease;
}
h1:hover::after {
width: 180px;
}
h2 {
color: var(--secondary-color);
font-size: 1.9em;
margin: 2.5rem 0 1.5rem;
padding-left: 1.2rem;
border-left: 4px solid var(--primary-color);
letter-spacing: -0.02em;
display: flex;
align-items: center;
gap: 1rem;
transition: all 0.3s ease;
}
h2:hover {
color: var(--primary-color);
transform: translateX(5px);
}
h3 {
color: var(--text-color);
font-size: 1.5em;
margin: 2rem 0 1rem;
padding-bottom: 0.8rem;
border-bottom: 2px solid var(--border-color);
transition: all 0.3s ease;
display: flex;
align-items: center;
gap: 0.8rem;
}
h3:hover {
color: var(--primary-color);
border-bottom-color: var(--primary-color);
}
.summary {
background: var(--primary-light);
padding: 2.5rem;
border-radius: 16px;
margin: 2.5rem 0;
box-shadow: 0 4px 6px -1px rgba(37, 99, 235, 0.1);
position: relative;
overflow: hidden;
transition: transform 0.3s ease, box-shadow 0.3s ease;
animation: slideIn 0.5s ease-out;
}
.summary:hover {
transform: translateY(-3px);
box-shadow: 0 8px 12px -2px rgba(37, 99, 235, 0.15);
}
.summary::before {
content: '';
position: absolute;
top: 0;
left: 0;
width: 4px;
height: 100%;
background: linear-gradient(to bottom, var(--primary-color), rgba(37, 99, 235, 0.6));
}
.summary p {
margin: 1.2rem 0;
line-height: 1.9;
color: var(--text-color);
transition: color 0.3s ease;
}
.summary:hover p {
color: var(--secondary-color);
}
.details {
margin-top: 3.5rem;
padding-top: 2.5rem;
border-top: 2px dashed var(--border-color);
animation: fadeIn 0.8s ease-out;
}
.failed-files {
background: var(--error-light);
padding: 2rem;
border-radius: 16px;
margin: 3rem 0;
border-left: 4px solid var(--error-color);
position: relative;
transition: all 0.3s ease;
animation: slideIn 0.5s ease-out;
}
.failed-files:hover {
transform: translateX(5px);
box-shadow: 0 8px 15px -3px rgba(239, 68, 68, 0.1);
}
.failed-files h2 {
color: var(--error-color);
border-left: none;
padding-left: 0;
}
.failed-files ul {
margin: 1.8rem 0;
padding-left: 1.2rem;
list-style-type: none;
}
.failed-files li {
margin: 1.2rem 0;
padding: 1.2rem 1.8rem;
background: rgba(239, 68, 68, 0.08);
border-radius: 12px;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
}
.failed-files li:hover {
transform: translateX(8px);
background: rgba(239, 68, 68, 0.12);
}
.directory-section {
margin: 3.5rem 0;
padding: 2rem;
background: var(--background-color);
border-radius: 16px;
position: relative;
transition: all 0.3s ease;
animation: fadeIn 0.6s ease-out;
}
.directory-section:hover {
background: white;
box-shadow: var(--card-shadow);
}
.file-summary {
background: white;
padding: 2rem;
margin: 1.8rem 0;
border-radius: 16px;
box-shadow: var(--card-shadow);
border-left: 4px solid var(--border-color);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
overflow: hidden;
}
.file-summary:hover {
border-left-color: var(--primary-color);
transform: translateX(8px) translateY(-2px);
box-shadow: var(--hover-shadow);
}
.file-summary {
background: white;
padding: 2rem;
margin: 1.8rem 0;
border-radius: 16px;
box-shadow: var(--card-shadow);
border-left: 4px solid var(--border-color);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
}
.file-summary:hover {
border-left-color: var(--primary-color);
transform: translateX(8px) translateY(-2px);
box-shadow: var(--hover-shadow);
}
.icon {
display: inline-flex;
align-items: center;
justify-content: center;
width: 32px;
height: 32px;
border-radius: 8px;
background: var(--primary-light);
color: var(--primary-color);
font-size: 1.2em;
transition: all 0.3s ease;
}
.file-summary:hover .icon,
.directory-section:hover .icon {
transform: scale(1.1);
background: var(--primary-color);
color: white;
}
/* Smooth scrolling */
html {
scroll-behavior: smooth;
}
/* Selection style */
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
/* Print styles */
@media print {
body {
background: white;
}
.container {
box-shadow: none;
padding: 0;
}
.file-summary, .failed-files {
break-inside: avoid;
box-shadow: none;
}
.icon {
display: none;
}
}
/* Responsive design */
@media (max-width: 768px) {
body {
padding: 1rem;
font-size: 15px;
}
.container {
padding: 1.5rem;
}
h1 {
font-size: 2.2em;
margin: 1.5rem 0 2rem;
}
h2 {
font-size: 1.7em;
}
h3 {
font-size: 1.4em;
}
.summary, .failed-files, .directory-section {
padding: 1.5rem;
}
.file-summary {
padding: 1.2rem;
}
.icon {
width: 28px;
height: 28px;
}
}
/* Dark mode support */
@media (prefers-color-scheme: dark) {
:root {
--primary-light: rgba(37, 99, 235, 0.15);
--background-color: #0f172a;
--text-color: #e2e8f0;
--text-light: #94a3b8;
--border-color: #1e293b;
--error-light: rgba(239, 68, 68, 0.15);
}
.container, .file-summary {
background: #1e293b;
}
.directory-section {
background: #0f172a;
}
.directory-section:hover {
background: #1e293b;
}
}
"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
failed_files_html = ['<div class="failed-files">']
failed_files_html.append('<h2><span class="icon">⚠️</span> 处理失败的文件</h2>')
failed_files_html.append("<ul>")
for fp, reason in self.failed_files:
failed_files_html.append(
f'<li><strong>📄 {os.path.basename(fp)}</strong><br><span style="color: var(--text-light)">{reason}</span></li>'
)
failed_files_html.append("</ul></div>")
return "\n".join(failed_files_html)
def format_file_summaries(self) -> str:
formatted_html = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_html.append('<div class="directory-section">')
formatted_html.append(f'<h2><span class="icon">📁</span> {dir_path}</h2>')
formatted_html.append('</div>')
current_dir = dir_path
file_name = os.path.basename(path)
formatted_html.append('<div class="file-summary">')
formatted_html.append(f'<h3><span class="icon">📄</span> {file_name}</h3>')
formatted_html.append(self.md.convert(self.file_summaries_map[path]))
formatted_html.append('</div>')
return "\n".join(formatted_html)
def create_document(self) -> str:
"""生成HTML文档
Returns:
str: 完整的HTML文档字符串
"""
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>文档总结报告</title>
<link href="https://cdnjs.cloudflare.com/ajax/libs/inter/3.19.3/inter.css" rel="stylesheet">
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600&display=swap" rel="stylesheet">
<style>{self.css_styles}</style>
</head>
<body>
<div class="container">
<h1><span class="icon">📑</span> 文档总结报告</h1>
<div class="summary">
<h2><span class="icon">📋</span> 总体摘要</h2>
<p>{self.md.convert(self.final_summary)}</p>
</div>
{self.format_failed_files()}
<div class="details">
<h2><span class="icon">📚</span> 各文件详细总结</h2>
{self.format_file_summaries()}
</div>
</div>
</body>
</html>
"""

查看文件

查看文件

@@ -0,0 +1,812 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from docx import Document
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.oxml.ns import qn
from docx.shared import Inches, Cm
from docx.shared import Pt, RGBColor, Inches
from typing import Dict, List, Tuple
import markdown
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
class DocumentFormatter(ABC):
"""文档格式化基类,定义文档格式化的基本接口"""
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
self.final_summary = final_summary
self.file_summaries_map = file_summaries_map
self.failed_files = failed_files
@abstractmethod
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
pass
@abstractmethod
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
pass
@abstractmethod
def create_document(self) -> str:
"""创建完整文档"""
pass
class WordFormatter(DocumentFormatter):
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.doc = Document()
self._setup_document()
self._create_styles()
# 初始化三级标题编号系统
self.numbers = {
1: 0, # 一级标题编号
2: 0, # 二级标题编号
3: 0 # 三级标题编号
}
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("该文档由GPT-academic生成")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(14)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
style.paragraph_format.first_line_indent = Pt(28)
# 创建各级标题样式
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
"""创建标题样式"""
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
style.font.name = font_name
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
style.font.size = Pt(font_size)
style.font.bold = True
style.paragraph_format.alignment = alignment
style.paragraph_format.space_before = Pt(12)
style.paragraph_format.space_after = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
return style
def _get_heading_number(self, level: int) -> str:
"""
生成标题编号
Args:
level: 标题级别 (0-3)
Returns:
str: 格式化的标题编号
"""
if level == 0: # 主标题不需要编号
return ""
self.numbers[level] += 1 # 增加当前级别的编号
# 重置下级标题编号
for i in range(level + 1, 4):
self.numbers[i] = 0
# 根据级别返回不同格式的编号
if level == 1:
return f"{self.numbers[1]}. "
elif level == 2:
return f"{self.numbers[1]}.{self.numbers[2]} "
elif level == 3:
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
return ""
def _add_heading(self, text: str, level: int):
"""
添加带编号的标题
Args:
text: 标题文本
level: 标题级别 (0-3)
"""
style_map = {
0: 'Title_Custom',
1: 'Heading1_Custom',
2: 'Heading2_Custom',
3: 'Heading3_Custom'
}
number = self._get_heading_number(level)
paragraph = self.doc.add_paragraph(style=style_map[level])
if number:
number_run = paragraph.add_run(number)
font_size = 22 if level == 1 else (18 if level == 2 else 16)
self._get_run_style(number_run, '黑体', font_size, True)
text_run = paragraph.add_run(text)
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
self._get_run_style(text_run, '黑体', font_size, True)
# 主标题添加日期
if level == 0:
date_paragraph = self.doc.add_paragraph()
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d'))
self._get_run_style(date_run, '仿宋', 16, False)
return paragraph
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
"""设置文本运行对象的样式"""
run.font.name = font_name
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
run.font.size = Pt(font_size)
run.font.bold = bold
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
result = []
if not self.failed_files:
return "\n".join(result)
result.append("处理失败文件:")
for fp, reason in self.failed_files:
result.append(f"{os.path.basename(fp)}: {reason}")
self._add_heading("处理失败文件", 1)
for fp, reason in self.failed_files:
self._add_content(f"{os.path.basename(fp)}: {reason}", indent=False)
self.doc.add_paragraph()
return "\n".join(result)
def _add_content(self, text: str, indent: bool = True):
"""添加正文内容,使用convert_markdown_to_word处理文本"""
# 使用convert_markdown_to_word处理markdown文本
processed_text = convert_markdown_to_word(text)
paragraph = self.doc.add_paragraph(processed_text, style='Normal_Custom')
if not indent:
paragraph.paragraph_format.first_line_indent = Pt(0)
return paragraph
def format_file_summaries(self) -> str:
"""
格式化文件总结内容,确保正确的标题层级并处理markdown文本
"""
result = []
# 首先对文件路径进行分组整理
file_groups = {}
for path in sorted(self.file_summaries_map.keys()):
dir_path = os.path.dirname(path)
if dir_path not in file_groups:
file_groups[dir_path] = []
file_groups[dir_path].append(path)
# 处理没有目录的文件
root_files = file_groups.get("", [])
if root_files:
for path in sorted(root_files):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 无目录的文件作为二级标题
self._add_heading(f"📄 {file_name}", 2)
# 使用convert_markdown_to_word处理文件内容
self._add_content(convert_markdown_to_word(self.file_summaries_map[path]))
self.doc.add_paragraph()
# 处理有目录的文件
for dir_path in sorted(file_groups.keys()):
if dir_path == "": # 跳过已处理的根目录文件
continue
# 添加目录作为二级标题
result.append(f"\n📁 {dir_path}")
self._add_heading(f"📁 {dir_path}", 2)
# 该目录下的所有文件作为三级标题
for path in sorted(file_groups[dir_path]):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 添加文件名作为三级标题
self._add_heading(f"📄 {file_name}", 3)
# 使用convert_markdown_to_word处理文件内容
self._add_content(convert_markdown_to_word(self.file_summaries_map[path]))
self.doc.add_paragraph()
return "\n".join(result)
def create_document(self):
"""创建完整Word文档并返回文档对象"""
# 重置所有编号
for level in self.numbers:
self.numbers[level] = 0
# 添加主标题
self._add_heading("文档总结报告", 0)
self.doc.add_paragraph()
# 添加总体摘要,使用convert_markdown_to_word处理
self._add_heading("总体摘要", 1)
self._add_content(convert_markdown_to_word(self.final_summary))
self.doc.add_paragraph()
# 添加失败文件列表(如果有)
if self.failed_files:
self.format_failed_files()
# 添加文件详细总结
self._add_heading("各文件详细总结", 1)
self.format_file_summaries()
return self.doc
def save_as_pdf(self, word_path, pdf_path=None):
"""将生成的Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径,如果转换失败则返回None
"""
from crazy_functions.doc_fns.conversation_doc.word2pdf import WordToPdfConverter
try:
pdf_path = WordToPdfConverter.convert_to_pdf(word_path, pdf_path)
return pdf_path
except Exception as e:
print(f"PDF转换失败: {str(e)}")
return None
class MarkdownFormatter(DocumentFormatter):
"""Markdown格式文档生成器"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
formatted_text = ["\n## ⚠️ 处理失败的文件"]
for fp, reason in self.failed_files:
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
formatted_text.append("\n---")
return "\n".join(formatted_text)
def format_file_summaries(self) -> str:
formatted_text = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_text.append(f"\n## 📁 {dir_path}")
current_dir = dir_path
file_name = os.path.basename(path)
formatted_text.append(f"\n### 📄 {file_name}")
formatted_text.append(self.file_summaries_map[path])
formatted_text.append("\n---")
return "\n".join(formatted_text)
def create_document(self) -> str:
document = [
"# 📑 文档总结报告",
"\n## 总体摘要",
self.final_summary
]
if self.failed_files:
document.append(self.format_failed_files())
document.extend([
"\n# 📚 各文件详细总结",
self.format_file_summaries()
])
return "\n".join(document)
class HtmlFormatter(DocumentFormatter):
"""HTML格式文档生成器 - 优化版"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.md = markdown.Markdown(extensions=['extra','codehilite', 'tables','nl2br'])
self.css_styles = """
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes slideIn {
from { transform: translateX(-20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
}
:root {
/* Enhanced color palette */
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--text-light: #64748b;
--border-color: #e2e8f0;
--error-color: #ef4444;
--error-light: #fef2f2;
--success-color: #22c55e;
--warning-color: #f59e0b;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
--hover-shadow: 0 20px 25px -5px rgb(0 0 0 / 0.1), 0 8px 10px -6px rgb(0 0 0 / 0.1);
/* Typography */
--heading-font: "Plus Jakarta Sans", system-ui, sans-serif;
--body-font: "Inter", system-ui, sans-serif;
}
body {
font-family: var(--body-font);
line-height: 1.8;
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
font-size: 16px;
-webkit-font-smoothing: antialiased;
}
.container {
background: white;
padding: 3rem;
border-radius: 24px;
box-shadow: var(--card-shadow);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
animation: fadeIn 0.6s ease-out;
border: 1px solid var(--border-color);
}
.container:hover {
box-shadow: var(--hover-shadow);
transform: translateY(-2px);
}
h1, h2, h3 {
font-family: var(--heading-font);
font-weight: 600;
}
h1 {
color: var(--primary-color);
font-size: 2.8em;
text-align: center;
margin: 2rem 0 3rem;
padding-bottom: 1.5rem;
border-bottom: 3px solid var(--primary-color);
letter-spacing: -0.03em;
position: relative;
display: flex;
align-items: center;
justify-content: center;
gap: 1rem;
}
h1::after {
content: '';
position: absolute;
bottom: -3px;
left: 50%;
transform: translateX(-50%);
width: 120px;
height: 3px;
background: linear-gradient(90deg, var(--primary-color), var(--primary-light));
border-radius: 3px;
transition: width 0.3s ease;
}
h1:hover::after {
width: 180px;
}
h2 {
color: var(--secondary-color);
font-size: 1.9em;
margin: 2.5rem 0 1.5rem;
padding-left: 1.2rem;
border-left: 4px solid var(--primary-color);
letter-spacing: -0.02em;
display: flex;
align-items: center;
gap: 1rem;
transition: all 0.3s ease;
}
h2:hover {
color: var(--primary-color);
transform: translateX(5px);
}
h3 {
color: var(--text-color);
font-size: 1.5em;
margin: 2rem 0 1rem;
padding-bottom: 0.8rem;
border-bottom: 2px solid var(--border-color);
transition: all 0.3s ease;
display: flex;
align-items: center;
gap: 0.8rem;
}
h3:hover {
color: var(--primary-color);
border-bottom-color: var(--primary-color);
}
.summary {
background: var(--primary-light);
padding: 2.5rem;
border-radius: 16px;
margin: 2.5rem 0;
box-shadow: 0 4px 6px -1px rgba(37, 99, 235, 0.1);
position: relative;
overflow: hidden;
transition: transform 0.3s ease, box-shadow 0.3s ease;
animation: slideIn 0.5s ease-out;
}
.summary:hover {
transform: translateY(-3px);
box-shadow: 0 8px 12px -2px rgba(37, 99, 235, 0.15);
}
.summary::before {
content: '';
position: absolute;
top: 0;
left: 0;
width: 4px;
height: 100%;
background: linear-gradient(to bottom, var(--primary-color), rgba(37, 99, 235, 0.6));
}
.summary p {
margin: 1.2rem 0;
line-height: 1.9;
color: var(--text-color);
transition: color 0.3s ease;
}
.summary:hover p {
color: var(--secondary-color);
}
.details {
margin-top: 3.5rem;
padding-top: 2.5rem;
border-top: 2px dashed var(--border-color);
animation: fadeIn 0.8s ease-out;
}
.failed-files {
background: var(--error-light);
padding: 2rem;
border-radius: 16px;
margin: 3rem 0;
border-left: 4px solid var(--error-color);
position: relative;
transition: all 0.3s ease;
animation: slideIn 0.5s ease-out;
}
.failed-files:hover {
transform: translateX(5px);
box-shadow: 0 8px 15px -3px rgba(239, 68, 68, 0.1);
}
.failed-files h2 {
color: var(--error-color);
border-left: none;
padding-left: 0;
}
.failed-files ul {
margin: 1.8rem 0;
padding-left: 1.2rem;
list-style-type: none;
}
.failed-files li {
margin: 1.2rem 0;
padding: 1.2rem 1.8rem;
background: rgba(239, 68, 68, 0.08);
border-radius: 12px;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
}
.failed-files li:hover {
transform: translateX(8px);
background: rgba(239, 68, 68, 0.12);
}
.directory-section {
margin: 3.5rem 0;
padding: 2rem;
background: var(--background-color);
border-radius: 16px;
position: relative;
transition: all 0.3s ease;
animation: fadeIn 0.6s ease-out;
}
.directory-section:hover {
background: white;
box-shadow: var(--card-shadow);
}
.file-summary {
background: white;
padding: 2rem;
margin: 1.8rem 0;
border-radius: 16px;
box-shadow: var(--card-shadow);
border-left: 4px solid var(--border-color);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
overflow: hidden;
}
.file-summary:hover {
border-left-color: var(--primary-color);
transform: translateX(8px) translateY(-2px);
box-shadow: var(--hover-shadow);
}
.file-summary {
background: white;
padding: 2rem;
margin: 1.8rem 0;
border-radius: 16px;
box-shadow: var(--card-shadow);
border-left: 4px solid var(--border-color);
transition: all 0.4s cubic-bezier(0.4, 0, 0.2, 1);
position: relative;
}
.file-summary:hover {
border-left-color: var(--primary-color);
transform: translateX(8px) translateY(-2px);
box-shadow: var(--hover-shadow);
}
.icon {
display: inline-flex;
align-items: center;
justify-content: center;
width: 32px;
height: 32px;
border-radius: 8px;
background: var(--primary-light);
color: var(--primary-color);
font-size: 1.2em;
transition: all 0.3s ease;
}
.file-summary:hover .icon,
.directory-section:hover .icon {
transform: scale(1.1);
background: var(--primary-color);
color: white;
}
/* Smooth scrolling */
html {
scroll-behavior: smooth;
}
/* Selection style */
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
/* Print styles */
@media print {
body {
background: white;
}
.container {
box-shadow: none;
padding: 0;
}
.file-summary, .failed-files {
break-inside: avoid;
box-shadow: none;
}
.icon {
display: none;
}
}
/* Responsive design */
@media (max-width: 768px) {
body {
padding: 1rem;
font-size: 15px;
}
.container {
padding: 1.5rem;
}
h1 {
font-size: 2.2em;
margin: 1.5rem 0 2rem;
}
h2 {
font-size: 1.7em;
}
h3 {
font-size: 1.4em;
}
.summary, .failed-files, .directory-section {
padding: 1.5rem;
}
.file-summary {
padding: 1.2rem;
}
.icon {
width: 28px;
height: 28px;
}
}
/* Dark mode support */
@media (prefers-color-scheme: dark) {
:root {
--primary-light: rgba(37, 99, 235, 0.15);
--background-color: #0f172a;
--text-color: #e2e8f0;
--text-light: #94a3b8;
--border-color: #1e293b;
--error-light: rgba(239, 68, 68, 0.15);
}
.container, .file-summary {
background: #1e293b;
}
.directory-section {
background: #0f172a;
}
.directory-section:hover {
background: #1e293b;
}
}
"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
failed_files_html = ['<div class="failed-files">']
failed_files_html.append('<h2><span class="icon">⚠️</span> 处理失败的文件</h2>')
failed_files_html.append("<ul>")
for fp, reason in self.failed_files:
failed_files_html.append(
f'<li><strong>📄 {os.path.basename(fp)}</strong><br><span style="color: var(--text-light)">{reason}</span></li>'
)
failed_files_html.append("</ul></div>")
return "\n".join(failed_files_html)
def format_file_summaries(self) -> str:
formatted_html = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_html.append('<div class="directory-section">')
formatted_html.append(f'<h2><span class="icon">📁</span> {dir_path}</h2>')
formatted_html.append('</div>')
current_dir = dir_path
file_name = os.path.basename(path)
formatted_html.append('<div class="file-summary">')
formatted_html.append(f'<h3><span class="icon">📄</span> {file_name}</h3>')
formatted_html.append(self.md.convert(self.file_summaries_map[path]))
formatted_html.append('</div>')
return "\n".join(formatted_html)
def create_document(self) -> str:
"""生成HTML文档
Returns:
str: 完整的HTML文档字符串
"""
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>文档总结报告</title>
<link href="https://cdnjs.cloudflare.com/ajax/libs/inter/3.19.3/inter.css" rel="stylesheet">
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600&display=swap" rel="stylesheet">
<style>{self.css_styles}</style>
</head>
<body>
<div class="container">
<h1><span class="icon">📑</span> 文档总结报告</h1>
<div class="summary">
<h2><span class="icon">📋</span> 总体摘要</h2>
<p>{self.md.convert(self.final_summary)}</p>
</div>
{self.format_failed_files()}
<div class="details">
<h2><span class="icon">📚</span> 各文件详细总结</h2>
{self.format_file_summaries()}
</div>
</div>
</body>
</html>
"""

查看文件

@@ -0,0 +1,237 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
from dataclasses import dataclass
from enum import Enum, auto
import logging
from datetime import datetime
# 设置日志
logger = logging.getLogger(__name__)
# 自定义异常类定义
class FoldingError(Exception):
"""折叠相关的自定义异常基类"""
pass
class FormattingError(FoldingError):
"""格式化过程中的错误"""
pass
class MetadataError(FoldingError):
"""元数据相关的错误"""
pass
class ValidationError(FoldingError):
"""验证错误"""
pass
class FoldingStyle(Enum):
"""折叠样式枚举"""
SIMPLE = auto() # 简单折叠
DETAILED = auto() # 详细折叠(带有额外信息)
NESTED = auto() # 嵌套折叠
@dataclass
class FoldingOptions:
"""折叠选项配置"""
style: FoldingStyle = FoldingStyle.DETAILED
code_language: Optional[str] = None # 代码块的语言
show_timestamp: bool = False # 是否显示时间戳
indent_level: int = 0 # 缩进级别
custom_css: Optional[str] = None # 自定义CSS类
T = TypeVar('T') # 用于泛型类型
class BaseMetadata(ABC):
"""元数据基类"""
@abstractmethod
def validate(self) -> bool:
"""验证元数据的有效性"""
pass
def _validate_non_empty_str(self, value: Optional[str]) -> bool:
"""验证字符串非空"""
return bool(value and value.strip())
@dataclass
class FileMetadata(BaseMetadata):
"""文件元数据"""
rel_path: str
size: float
last_modified: Optional[datetime] = None
mime_type: Optional[str] = None
encoding: str = 'utf-8'
def validate(self) -> bool:
"""验证文件元数据的有效性"""
try:
if not self._validate_non_empty_str(self.rel_path):
return False
if self.size < 0:
return False
return True
except Exception as e:
logger.error(f"File metadata validation error: {str(e)}")
return False
class ContentFormatter(ABC, Generic[T]):
"""内容格式化抽象基类
支持泛型类型参数,可以指定具体的元数据类型。
"""
@abstractmethod
def format(self,
content: str,
metadata: T,
options: Optional[FoldingOptions] = None) -> str:
"""格式化内容
Args:
content: 需要格式化的内容
metadata: 类型化的元数据
options: 折叠选项
Returns:
str: 格式化后的内容
Raises:
FormattingError: 格式化过程中的错误
"""
pass
def _create_summary(self, metadata: T) -> str:
"""创建折叠摘要,可被子类重写"""
return str(metadata)
def _format_content_block(self,
content: str,
options: Optional[FoldingOptions]) -> str:
"""格式化内容块,处理代码块等特殊格式"""
if not options:
return content
if options.code_language:
return f"```{options.code_language}\n{content}\n```"
return content
def _add_indent(self, text: str, level: int) -> str:
"""添加缩进"""
if level <= 0:
return text
indent = " " * level
return "\n".join(indent + line for line in text.splitlines())
class FileContentFormatter(ContentFormatter[FileMetadata]):
"""文件内容格式化器"""
def format(self,
content: str,
metadata: FileMetadata,
options: Optional[FoldingOptions] = None) -> str:
"""格式化文件内容"""
if not metadata.validate():
raise MetadataError("Invalid file metadata")
try:
options = options or FoldingOptions()
# 构建摘要信息
summary_parts = [
f"{metadata.rel_path} ({metadata.size:.2f}MB)",
f"Type: {metadata.mime_type}" if metadata.mime_type else None,
(f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
if metadata.last_modified and options.show_timestamp else None)
]
summary = " | ".join(filter(None, summary_parts))
# 构建HTML类
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
# 格式化内容
formatted_content = self._format_content_block(content, options)
# 组装最终结果
result = (
f'<details{css_class}><summary>{summary}</summary>\n\n'
f'{formatted_content}\n\n'
f'</details>\n\n'
)
return self._add_indent(result, options.indent_level)
except Exception as e:
logger.error(f"Error formatting file content: {str(e)}")
raise FormattingError(f"Failed to format file content: {str(e)}")
class ContentFoldingManager:
"""内容折叠管理器"""
def __init__(self):
"""初始化折叠管理器"""
self._formatters: Dict[str, ContentFormatter] = {}
self._register_default_formatters()
def _register_default_formatters(self) -> None:
"""注册默认的格式化器"""
self.register_formatter('file', FileContentFormatter())
def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
"""注册新的格式化器"""
if not isinstance(formatter, ContentFormatter):
raise TypeError("Formatter must implement ContentFormatter interface")
self._formatters[name] = formatter
def _guess_language(self, extension: str) -> Optional[str]:
"""根据文件扩展名猜测编程语言"""
extension = extension.lower().lstrip('.')
language_map = {
'py': 'python',
'js': 'javascript',
'java': 'java',
'cpp': 'cpp',
'cs': 'csharp',
'html': 'html',
'css': 'css',
'md': 'markdown',
'json': 'json',
'xml': 'xml',
'sql': 'sql',
'sh': 'bash',
'yaml': 'yaml',
'yml': 'yaml',
'txt': None # 纯文本不需要语言标识
}
return language_map.get(extension)
def format_content(self,
content: str,
formatter_type: str,
metadata: Union[FileMetadata],
options: Optional[FoldingOptions] = None) -> str:
"""格式化内容"""
formatter = self._formatters.get(formatter_type)
if not formatter:
raise KeyError(f"No formatter registered for type: {formatter_type}")
if not isinstance(metadata, FileMetadata):
raise TypeError("Invalid metadata type")
return formatter.format(content, metadata, options)

查看文件

@@ -0,0 +1,211 @@
import re
import os
import pandas as pd
from datetime import datetime
from openpyxl import Workbook
class ExcelTableFormatter:
"""聊天记录中Markdown表格转Excel生成器"""
def __init__(self):
"""初始化Excel文档对象"""
self.workbook = Workbook()
self._table_count = 0
self._current_sheet = None
def _normalize_table_row(self, row):
"""标准化表格行,处理不同的分隔符情况"""
row = row.strip()
if row.startswith('|'):
row = row[1:]
if row.endswith('|'):
row = row[:-1]
return [cell.strip() for cell in row.split('|')]
def _is_separator_row(self, row):
"""检查是否是分隔行(由 - 或 : 组成)"""
clean_row = re.sub(r'[\s|]', '', row)
return bool(re.match(r'^[-:]+$', clean_row))
def _extract_tables_from_text(self, text):
"""从文本中提取所有表格内容"""
if not isinstance(text, str):
return []
tables = []
current_table = []
is_in_table = False
for line in text.split('\n'):
line = line.strip()
if not line:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
continue
if '|' in line:
if not is_in_table:
is_in_table = True
current_table.append(line)
else:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
if is_in_table and current_table and len(current_table) >= 2:
tables.append(current_table)
return tables
def _parse_table(self, table_lines):
"""解析表格内容为结构化数据"""
try:
headers = self._normalize_table_row(table_lines[0])
separator_index = next(
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
1
)
data_rows = []
for line in table_lines[separator_index + 1:]:
cells = self._normalize_table_row(line)
# 确保单元格数量与表头一致
while len(cells) < len(headers):
cells.append('')
cells = cells[:len(headers)]
data_rows.append(cells)
if headers and data_rows:
return {
'headers': headers,
'data': data_rows
}
except Exception as e:
print(f"解析表格时发生错误: {str(e)}")
return None
def _create_sheet(self, question_num, table_num):
"""创建新的工作表"""
sheet_name = f'Q{question_num}_T{table_num}'
if len(sheet_name) > 31:
sheet_name = f'Table{self._table_count}'
if sheet_name in self.workbook.sheetnames:
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
return self.workbook.create_sheet(title=sheet_name)
def create_document(self, history):
"""
处理聊天历史中的所有表格并创建Excel文档
Args:
history: 聊天历史列表
Returns:
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
"""
has_tables = False
# 删除默认创建的工作表
default_sheet = self.workbook['Sheet']
self.workbook.remove(default_sheet)
# 遍历所有回答
for i in range(1, len(history), 2):
answer = history[i]
tables = self._extract_tables_from_text(answer)
for table_lines in tables:
parsed_table = self._parse_table(table_lines)
if parsed_table:
self._table_count += 1
sheet = self._create_sheet(i // 2 + 1, self._table_count)
# 写入表头
for col, header in enumerate(parsed_table['headers'], 1):
sheet.cell(row=1, column=col, value=header)
# 写入数据
for row_idx, row_data in enumerate(parsed_table['data'], 2):
for col_idx, value in enumerate(row_data, 1):
sheet.cell(row=row_idx, column=col_idx, value=value)
has_tables = True
return self.workbook if has_tables else None
def save_chat_tables(history, save_dir, base_name):
"""
保存聊天历史中的表格到Excel文件
Args:
history: 聊天历史列表
save_dir: 保存目录
base_name: 基础文件名
Returns:
list: 保存的文件路径列表
"""
result_files = []
try:
# 创建Excel格式
excel_formatter = ExcelTableFormatter()
workbook = excel_formatter.create_document(history)
if workbook is not None:
# 确保保存目录存在
os.makedirs(save_dir, exist_ok=True)
# 生成Excel文件路径
excel_file = os.path.join(save_dir, base_name + '.xlsx')
# 保存Excel文件
workbook.save(excel_file)
result_files.append(excel_file)
print(f"已保存表格到Excel文件: {excel_file}")
except Exception as e:
print(f"保存Excel格式失败: {str(e)}")
return result_files
# 使用示例
if __name__ == "__main__":
# 示例聊天历史
history = [
"问题1",
"""这是第一个表格:
| A | B | C |
|---|---|---|
| 1 | 2 | 3 |""",
"问题2",
"这是没有表格的回答",
"问题3",
"""回答包含多个表格:
| Name | Age |
|------|-----|
| Tom | 20 |
第二个表格:
| X | Y |
|---|---|
| 1 | 2 |"""
]
# 保存表格
save_dir = "output"
base_name = "chat_tables"
saved_files = save_chat_tables(history, save_dir, base_name)

查看文件

@@ -0,0 +1,190 @@
class HtmlFormatter:
"""聊天记录HTML格式生成器"""
def __init__(self, chatbot, history):
self.chatbot = chatbot
self.history = history
self.css_styles = """
:root {
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--border-color: #e2e8f0;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
line-height: 1.8;
margin: 0;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 2rem;
border-radius: 16px;
box-shadow: var(--card-shadow);
}
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes slideIn {
from { transform: translateX(-20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
}
.container {
animation: fadeIn 0.6s ease-out;
}
.QaBox {
animation: slideIn 0.5s ease-out;
transition: all 0.3s ease;
}
.QaBox:hover {
transform: translateX(5px);
}
.Question, .Answer, .historyBox {
transition: all 0.3s ease;
}
.chat-title {
color: var(--primary-color);
font-size: 2em;
text-align: center;
margin: 1rem 0 2rem;
padding-bottom: 1rem;
border-bottom: 2px solid var(--primary-color);
}
.chat-body {
display: flex;
flex-direction: column;
gap: 1.5rem;
margin: 2rem 0;
}
.QaBox {
background: white;
padding: 1.5rem;
border-radius: 8px;
border-left: 4px solid var(--primary-color);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
margin-bottom: 1.5rem;
}
.Question {
color: var(--secondary-color);
font-weight: 500;
margin-bottom: 1rem;
}
.Answer {
color: var(--text-color);
background: var(--primary-light);
padding: 1rem;
border-radius: 6px;
}
.history-section {
margin-top: 3rem;
padding-top: 2rem;
border-top: 2px solid var(--border-color);
}
.history-title {
color: var(--secondary-color);
font-size: 1.5em;
margin-bottom: 1.5rem;
text-align: center;
}
.historyBox {
background: white;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 6px;
border: 1px solid var(--border-color);
}
@media (prefers-color-scheme: dark) {
:root {
--background-color: #0f172a;
--text-color: #e2e8f0;
--border-color: #1e293b;
}
.container, .QaBox {
background: #1e293b;
}
}
"""
def format_chat_content(self) -> str:
"""格式化聊天内容"""
chat_content = []
for q, a in self.chatbot:
question = str(q) if q is not None else ""
answer = str(a) if a is not None else ""
chat_content.append(f'''
<div class="QaBox">
<div class="Question">{question}</div>
<div class="Answer">{answer}</div>
</div>
''')
return "\n".join(chat_content)
def format_history_content(self) -> str:
"""格式化历史记录内容"""
if not self.history:
return ""
history_content = []
for entry in self.history:
history_content.append(f'''
<div class="historyBox">
<div class="entry">{entry}</div>
</div>
''')
return "\n".join(history_content)
def create_document(self) -> str:
"""生成完整的HTML文档
Returns:
str: 完整的HTML文档字符串
"""
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>对话存档</title>
<style>{self.css_styles}</style>
</head>
<body>
<div class="container">
<h1 class="chat-title">对话存档</h1>
<div class="chat-body">
{self.format_chat_content()}
</div>
</div>
</body>
</html>
"""

查看文件

@@ -0,0 +1,39 @@
class MarkdownFormatter:
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
def __init__(self):
self.content = []
def _add_content(self, text: str):
"""添加正文内容"""
if text:
self.content.append(f"\n{text}\n")
def create_document(self, history: list) -> str:
"""
创建完整的Markdown文档
Args:
history: 历史记录列表,偶数位置为问题,奇数位置为答案
Returns:
str: 生成的Markdown文本
"""
self.content = []
# 处理问答对
for i in range(0, len(history), 2):
question = history[i]
answer = history[i + 1]
# 添加问题
self.content.append(f"\n### 问题 {i//2 + 1}")
self._add_content(question)
# 添加回答
self.content.append(f"\n### 回答 {i//2 + 1}")
self._add_content(answer)
# 添加分隔线
self.content.append("\n---\n")
return "\n".join(self.content)

查看文件

@@ -0,0 +1,172 @@
from datetime import datetime
import os
import re
from reportlab.pdfbase import pdfmetrics
from reportlab.pdfbase.ttfonts import TTFont
def convert_markdown_to_pdf(markdown_text):
"""将Markdown文本转换为PDF格式的纯文本"""
if not markdown_text:
return ""
# 标准化换行符
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
# 处理标题、粗体、斜体
markdown_text = re.sub(r'^#\s+(.+)$', r'\1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text)
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text)
# 处理列表
markdown_text = re.sub(r'^\s*[-*+]\s+(.+?)(?=\n|$)', r'\1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'^\s*\d+\.\s+(.+?)(?=\n|$)', r'\1', markdown_text, flags=re.MULTILINE)
# 处理链接
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1', markdown_text)
# 处理段落
markdown_text = re.sub(r'\n{2,}', '\n', markdown_text)
markdown_text = re.sub(r'(?<!\n)(?<!^)(?<!•\s)(?<!\d\.\s)\n(?![\s•\d])', '\n\n', markdown_text, flags=re.MULTILINE)
# 清理空白
markdown_text = re.sub(r' +', ' ', markdown_text)
markdown_text = re.sub(r'(?m)^\s+|\s+$', '', markdown_text)
return markdown_text.strip()
class PDFFormatter:
"""聊天记录PDF文档生成器 - 使用 Noto Sans CJK 字体"""
def __init__(self):
self._init_reportlab()
self._register_fonts()
self.styles = self._get_reportlab_lib()['getSampleStyleSheet']()
self._create_styles()
def _init_reportlab(self):
"""初始化 ReportLab 相关组件"""
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import cm
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
self._lib = {
'A4': A4,
'getSampleStyleSheet': getSampleStyleSheet,
'ParagraphStyle': ParagraphStyle,
'cm': cm
}
self._platypus = {
'SimpleDocTemplate': SimpleDocTemplate,
'Paragraph': Paragraph,
'Spacer': Spacer
}
def _get_reportlab_lib(self):
return self._lib
def _get_reportlab_platypus(self):
return self._platypus
def _register_fonts(self):
"""注册 Noto Sans CJK 字体"""
possible_font_paths = [
'/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',
'/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc',
'/usr/share/fonts/noto/NotoSansCJK-Regular.ttc'
]
font_registered = False
for path in possible_font_paths:
if os.path.exists(path):
try:
pdfmetrics.registerFont(TTFont('NotoSansCJK', path))
font_registered = True
break
except:
continue
if not font_registered:
print("Warning: Could not find Noto Sans CJK font. Using fallback font.")
self.font_name = 'Helvetica'
else:
self.font_name = 'NotoSansCJK'
def _create_styles(self):
"""创建文档样式"""
ParagraphStyle = self._lib['ParagraphStyle']
# 标题样式
self.styles.add(ParagraphStyle(
name='Title_Custom',
fontName=self.font_name,
fontSize=24,
leading=38,
alignment=1,
spaceAfter=32
))
# 日期样式
self.styles.add(ParagraphStyle(
name='Date_Style',
fontName=self.font_name,
fontSize=16,
leading=20,
alignment=1,
spaceAfter=20
))
# 问题样式
self.styles.add(ParagraphStyle(
name='Question_Style',
fontName=self.font_name,
fontSize=12,
leading=18,
leftIndent=28,
spaceAfter=6
))
# 回答样式
self.styles.add(ParagraphStyle(
name='Answer_Style',
fontName=self.font_name,
fontSize=12,
leading=18,
leftIndent=28,
spaceAfter=12
))
def create_document(self, history, output_path):
"""生成PDF文档"""
# 创建PDF文档
doc = self._platypus['SimpleDocTemplate'](
output_path,
pagesize=self._lib['A4'],
rightMargin=2.6 * self._lib['cm'],
leftMargin=2.8 * self._lib['cm'],
topMargin=3.7 * self._lib['cm'],
bottomMargin=3.5 * self._lib['cm']
)
# 构建内容
story = []
Paragraph = self._platypus['Paragraph']
# 添加对话内容
for i in range(0, len(history), 2):
question = history[i]
answer = convert_markdown_to_pdf(history[i + 1]) if i + 1 < len(history) else ""
if question:
q_text = f'问题 {i // 2 + 1}{str(question)}'
story.append(Paragraph(q_text, self.styles['Question_Style']))
if answer:
a_text = f'回答 {i // 2 + 1}{str(answer)}'
story.append(Paragraph(a_text, self.styles['Answer_Style']))
# 构建PDF
doc.build(story)
return doc

查看文件

@@ -0,0 +1,79 @@
import re
def convert_markdown_to_txt(markdown_text):
"""Convert markdown text to plain text while preserving formatting"""
# Standardize line endings
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
# 1. Handle headers but keep their formatting instead of removing them
markdown_text = re.sub(r'^#\s+(.+)$', r'# \1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'^##\s+(.+)$', r'## \1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'^###\s+(.+)$', r'### \1', markdown_text, flags=re.MULTILINE)
# 2. Handle bold and italic - simply remove markers
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text)
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text)
# 3. Handle lists but preserve formatting
markdown_text = re.sub(r'^\s*[-*+]\s+(.+?)(?=\n|$)', r'\1', markdown_text, flags=re.MULTILINE)
# 4. Handle links - keep only the text
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1 (\2)', markdown_text)
# 5. Handle HTML links - convert to user-friendly format
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)',
markdown_text)
# 6. Preserve paragraph breaks
markdown_text = re.sub(r'\n{3,}', '\n\n', markdown_text) # normalize multiple newlines to double newlines
# 7. Clean up extra spaces but maintain indentation
markdown_text = re.sub(r' +', ' ', markdown_text)
return markdown_text.strip()
class TxtFormatter:
"""Chat history TXT document generator"""
def __init__(self):
self.content = []
self._setup_document()
def _setup_document(self):
"""Initialize document with header"""
self.content.append("=" * 50)
self.content.append("GPT-Academic对话记录".center(48))
self.content.append("=" * 50)
def _format_header(self):
"""Create document header with current date"""
from datetime import datetime
date_str = datetime.now().strftime('%Y年%m月%d')
return [
date_str.center(48),
"\n" # Add blank line after date
]
def create_document(self, history):
"""Generate document from chat history"""
# Add header with date
self.content.extend(self._format_header())
# Add conversation content
for i in range(0, len(history), 2):
question = history[i]
answer = convert_markdown_to_txt(history[i + 1]) if i + 1 < len(history) else ""
if question:
self.content.append(f"问题 {i // 2 + 1}{str(question)}")
self.content.append("") # Add blank line
if answer:
self.content.append(f"回答 {i // 2 + 1}{str(answer)}")
self.content.append("") # Add blank line
# Join all content with newlines
return "\n".join(self.content)

查看文件

@@ -0,0 +1,155 @@
from docx2pdf import convert
import os
import platform
import subprocess
from typing import Union
from pathlib import Path
from datetime import datetime
class WordToPdfConverter:
"""Word文档转PDF转换器"""
@staticmethod
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
"""
将Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径
异常:
如果转换失败,将抛出相应异常
"""
try:
# 确保输入路径是Path对象
word_path = Path(word_path)
# 如果未指定pdf_path,则使用与word文档相同的名称
if pdf_path is None:
pdf_path = word_path.with_suffix('.pdf')
else:
pdf_path = Path(pdf_path)
# 检查操作系统
if platform.system() == 'Linux':
# Linux系统需要安装libreoffice
which_result = subprocess.run(['which', 'libreoffice'], capture_output=True, text=True)
if which_result.returncode != 0:
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
print(f"开始转换Word文档: {word_path} 到 PDF")
# 使用subprocess代替os.system
result = subprocess.run(
['libreoffice', '--headless', '--convert-to', 'pdf:writer_pdf_Export',
str(word_path), '--outdir', str(pdf_path.parent)],
capture_output=True, text=True
)
if result.returncode != 0:
error_msg = result.stderr or "未知错误"
print(f"LibreOffice转换失败,错误信息: {error_msg}")
raise RuntimeError(f"LibreOffice转换失败: {error_msg}")
print(f"LibreOffice转换输出: {result.stdout}")
# 如果输出路径与默认生成的不同,则重命名
default_pdf = word_path.with_suffix('.pdf')
if default_pdf != pdf_path and default_pdf.exists():
os.rename(default_pdf, pdf_path)
print(f"已将PDF从 {default_pdf} 重命名为 {pdf_path}")
# 验证PDF是否成功生成
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
raise RuntimeError("PDF生成失败或文件为空")
print(f"PDF转换成功,文件大小: {pdf_path.stat().st_size} 字节")
else:
# Windows和MacOS使用docx2pdf
print(f"使用docx2pdf转换 {word_path}{pdf_path}")
convert(word_path, pdf_path)
# 验证PDF是否成功生成
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
raise RuntimeError("PDF生成失败或文件为空")
print(f"PDF转换成功,文件大小: {pdf_path.stat().st_size} 字节")
return str(pdf_path)
except Exception as e:
print(f"PDF转换异常: {str(e)}")
raise Exception(f"转换PDF失败: {str(e)}")
@staticmethod
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
"""
批量转换目录下的所有Word文档
参数:
word_dir: 包含Word文档的目录路径
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
返回:
生成的PDF文件路径列表
"""
word_dir = Path(word_dir)
if pdf_dir:
pdf_dir = Path(pdf_dir)
pdf_dir.mkdir(parents=True, exist_ok=True)
converted_files = []
for word_file in word_dir.glob("*.docx"):
try:
if pdf_dir:
pdf_path = pdf_dir / word_file.with_suffix('.pdf').name
else:
pdf_path = word_file.with_suffix('.pdf')
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
converted_files.append(pdf_file)
except Exception as e:
print(f"转换 {word_file} 失败: {str(e)}")
return converted_files
@staticmethod
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
"""
将docx对象直接转换为PDF
参数:
doc: python-docx的Document对象
output_dir: 可选,输出目录。如果未指定,将使用当前目录
返回:
生成的PDF文件路径
"""
try:
# 设置临时文件路径和输出路径
output_dir = Path(output_dir) if output_dir else Path.cwd()
output_dir.mkdir(parents=True, exist_ok=True)
# 生成临时word文件
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
doc.save(temp_docx)
# 转换为PDF
pdf_path = temp_docx.with_suffix('.pdf')
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
# 删除临时word文件
temp_docx.unlink()
return str(pdf_path)
except Exception as e:
if temp_docx.exists():
temp_docx.unlink()
raise Exception(f"转换PDF失败: {str(e)}")

查看文件

@@ -0,0 +1,177 @@
import re
from docx import Document
from docx.shared import Cm, Pt
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.enum.style import WD_STYLE_TYPE
from docx.oxml.ns import qn
from datetime import datetime
def convert_markdown_to_word(markdown_text):
# 0. 首先标准化所有换行符为\n
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
# 1. 处理标题 - 支持更多级别的标题,使用更精确的正则
# 保留标题标记,以便后续处理时还能识别出标题级别
markdown_text = re.sub(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', r'\1 \2', markdown_text, flags=re.MULTILINE)
# 2. 处理粗体、斜体和加粗斜体
markdown_text = re.sub(r'\*\*\*(.+?)\*\*\*', r'\1', markdown_text) # 加粗斜体
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text) # 加粗
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text) # 斜体
markdown_text = re.sub(r'_(.+?)_', r'\1', markdown_text) # 下划线斜体
markdown_text = re.sub(r'__(.+?)__', r'\1', markdown_text) # 下划线加粗
# 3. 处理代码块 - 不移除,而是简化格式
# 多行代码块
markdown_text = re.sub(r'```(?:\w+)?\n([\s\S]*?)```', r'[代码块]\n\1[/代码块]', markdown_text)
# 单行代码
markdown_text = re.sub(r'`([^`]+)`', r'[代码]\1[/代码]', markdown_text)
# 4. 处理列表 - 保留列表结构
# 匹配无序列表
markdown_text = re.sub(r'^(\s*)[-*+]\s+(.+?)$', r'\1• \2', markdown_text, flags=re.MULTILINE)
# 5. 处理Markdown链接
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+?)\s*(?:"[^"]*")?\)', r'\1 (\2)', markdown_text)
# 6. 处理HTML链接
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)',
markdown_text)
# 7. 处理图片
markdown_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'[图片:\1]', markdown_text)
return markdown_text
class WordFormatter:
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
def __init__(self):
self.doc = Document()
self._setup_document()
self._create_styles()
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("GPT-Academic对话记录")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(12) # 调整为12磅
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
# 创建问题样式
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
question_style.font.name = '黑体'
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
question_style.font.size = Pt(14) # 调整为14磅
question_style.font.bold = True
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
question_style.paragraph_format.space_after = Pt(6)
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建回答样式
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
answer_style.font.name = '仿宋'
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
answer_style.font.size = Pt(12) # 调整为12磅
answer_style.paragraph_format.space_before = Pt(6)
answer_style.paragraph_format.space_after = Pt(12)
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建标题样式
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
title_style.font.name = '黑体' # 改用黑体
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
title_style.font.size = Pt(22) # 调整为22磅
title_style.font.bold = True
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
title_style.paragraph_format.space_before = Pt(0)
title_style.paragraph_format.space_after = Pt(24)
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
# 添加参考文献样式
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_style.font.name = '宋体'
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
ref_style.paragraph_format.space_before = Pt(3)
ref_style.paragraph_format.space_after = Pt(3)
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
ref_style.paragraph_format.left_indent = Pt(21)
ref_style.paragraph_format.first_line_indent = Pt(-21)
# 添加参考文献标题样式
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_title_style.font.name = '黑体'
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
ref_title_style.font.size = Pt(16)
ref_title_style.font.bold = True
ref_title_style.paragraph_format.space_before = Pt(24)
ref_title_style.paragraph_format.space_after = Pt(12)
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
def create_document(self, history):
"""写入聊天历史"""
# 添加标题
title_para = self.doc.add_paragraph(style='Title_Custom')
title_run = title_para.add_run('GPT-Academic 对话记录')
# 添加日期
date_para = self.doc.add_paragraph()
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d'))
date_run.font.name = '仿宋'
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
date_run.font.size = Pt(16)
self.doc.add_paragraph() # 添加空行
# 添加对话内容
for i in range(0, len(history), 2):
question = history[i]
answer = convert_markdown_to_word(history[i + 1])
if question:
q_para = self.doc.add_paragraph(style='Question_Style')
q_para.add_run(f'问题 {i//2 + 1}').bold = True
q_para.add_run(str(question))
if answer:
a_para = self.doc.add_paragraph(style='Answer_Style')
a_para.add_run(f'回答 {i//2 + 1}').bold = True
a_para.add_run(str(answer))
return self.doc

查看文件

@@ -0,0 +1,4 @@
import nltk
nltk.data.path.append('~/nltk_data')
nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
nltk.download('punkt', download_dir='~/nltk_data')

查看文件

@@ -0,0 +1,286 @@
from __future__ import annotations
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional, List, Set, Dict, Union, Iterator, Tuple
from dataclasses import dataclass, field
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import chardet
from functools import lru_cache
import os
@dataclass
class ExtractorConfig:
"""提取器配置类"""
encoding: str = 'auto'
na_filter: bool = True
skip_blank_lines: bool = True
chunk_size: int = 10000
max_workers: int = 4
preserve_format: bool = True
read_all_sheets: bool = True # 新增:是否读取所有工作表
text_cleanup: Dict[str, bool] = field(default_factory=lambda: {
'remove_extra_spaces': True,
'normalize_whitespace': False,
'remove_special_chars': False,
'lowercase': False
})
class ExcelTextExtractor:
"""增强的Excel格式文件文本内容提取器"""
SUPPORTED_EXTENSIONS: Set[str] = {
'.xlsx', '.xls', '.csv', '.tsv', '.xlsm', '.xltx', '.xltm', '.ods'
}
def __init__(self, config: Optional[ExtractorConfig] = None):
self.config = config or ExtractorConfig()
self._setup_logging()
self._detect_encoding = lru_cache(maxsize=128)(self._detect_encoding)
def _setup_logging(self) -> None:
"""配置日志记录器"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
fh = logging.FileHandler('excel_extractor.log')
fh.setLevel(logging.ERROR)
self.logger.addHandler(fh)
def _detect_encoding(self, file_path: Path) -> str:
if self.config.encoding != 'auto':
return self.config.encoding
try:
with open(file_path, 'rb') as f:
raw_data = f.read(10000)
result = chardet.detect(raw_data)
return result['encoding'] or 'utf-8'
except Exception as e:
self.logger.warning(f"Encoding detection failed: {e}. Using utf-8")
return 'utf-8'
def _validate_file(self, file_path: Union[str, Path]) -> Path:
path = Path(file_path).resolve()
if not path.exists():
raise ValueError(f"File not found: {path}")
if not path.is_file():
raise ValueError(f"Not a file: {path}")
if not os.access(path, os.R_OK):
raise PermissionError(f"No read permission: {path}")
if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported format: {path.suffix}. "
f"Supported: {', '.join(sorted(self.SUPPORTED_EXTENSIONS))}"
)
return path
def _format_value(self, value: Any) -> str:
if pd.isna(value) or value is None:
return ''
if isinstance(value, (int, float)):
return str(value)
return str(value).strip()
def _process_chunk(self, chunk: pd.DataFrame, columns: Optional[List[str]] = None, sheet_name: str = '') -> str:
"""处理数据块,新增sheet_name参数"""
try:
if columns:
chunk = chunk[columns]
if self.config.preserve_format:
formatted_chunk = chunk.applymap(self._format_value)
rows = []
# 添加工作表名称作为标题
if sheet_name:
rows.append(f"[Sheet: {sheet_name}]")
# 添加表头
headers = [str(col) for col in formatted_chunk.columns]
rows.append('\t'.join(headers))
# 添加数据行
for _, row in formatted_chunk.iterrows():
rows.append('\t'.join(row.values))
return '\n'.join(rows)
else:
flat_values = (
chunk.astype(str)
.replace({'nan': '', 'None': '', 'NaN': ''})
.values.flatten()
)
return ' '.join(v for v in flat_values if v)
except Exception as e:
self.logger.error(f"Error processing chunk: {e}")
raise
def _read_file(self, file_path: Path) -> Union[pd.DataFrame, Iterator[pd.DataFrame], Dict[str, pd.DataFrame]]:
"""读取文件,支持多工作表"""
try:
encoding = self._detect_encoding(file_path)
if file_path.suffix.lower() in {'.csv', '.tsv'}:
sep = '\t' if file_path.suffix.lower() == '.tsv' else ','
# 对大文件使用分块读取
if file_path.stat().st_size > self.config.chunk_size * 1024:
return pd.read_csv(
file_path,
encoding=encoding,
na_filter=self.config.na_filter,
skip_blank_lines=self.config.skip_blank_lines,
sep=sep,
chunksize=self.config.chunk_size,
on_bad_lines='warn'
)
else:
return pd.read_csv(
file_path,
encoding=encoding,
na_filter=self.config.na_filter,
skip_blank_lines=self.config.skip_blank_lines,
sep=sep
)
else:
# Excel文件处理,支持多工作表
if self.config.read_all_sheets:
# 读取所有工作表
return pd.read_excel(
file_path,
na_filter=self.config.na_filter,
keep_default_na=self.config.na_filter,
engine='openpyxl',
sheet_name=None # None表示读取所有工作表
)
else:
# 只读取第一个工作表
return pd.read_excel(
file_path,
na_filter=self.config.na_filter,
keep_default_na=self.config.na_filter,
engine='openpyxl',
sheet_name=0 # 读取第一个工作表
)
except Exception as e:
self.logger.error(f"Error reading file {file_path}: {e}")
raise
def extract_text(
self,
file_path: Union[str, Path],
columns: Optional[List[str]] = None,
separator: str = '\n'
) -> str:
"""提取文本,支持多工作表"""
try:
path = self._validate_file(file_path)
self.logger.info(f"Processing: {path}")
reader = self._read_file(path)
texts = []
# 处理Excel多工作表
if isinstance(reader, dict):
for sheet_name, df in reader.items():
sheet_text = self._process_chunk(df, columns, sheet_name)
if sheet_text:
texts.append(sheet_text)
return separator.join(texts)
# 处理单个DataFrame
elif isinstance(reader, pd.DataFrame):
return self._process_chunk(reader, columns)
# 处理DataFrame迭代器
else:
with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
futures = {
executor.submit(self._process_chunk, chunk, columns): i
for i, chunk in enumerate(reader)
}
chunk_texts = []
for future in as_completed(futures):
try:
text = future.result()
if text:
chunk_texts.append((futures[future], text))
except Exception as e:
self.logger.error(f"Error in chunk {futures[future]}: {e}")
# 按块的顺序排序
chunk_texts.sort(key=lambda x: x[0])
texts = [text for _, text in chunk_texts]
# 合并文本,保留格式
if texts and self.config.preserve_format:
result = texts[0] # 第一块包含表头
if len(texts) > 1:
# 跳过后续块的表头行
for text in texts[1:]:
result += '\n' + '\n'.join(text.split('\n')[1:])
return result
else:
return separator.join(texts)
except Exception as e:
self.logger.error(f"Extraction failed: {e}")
raise
@staticmethod
def get_supported_formats() -> List[str]:
"""获取支持的文件格式列表"""
return sorted(ExcelTextExtractor.SUPPORTED_EXTENSIONS)
def main():
"""主函数:演示用法"""
config = ExtractorConfig(
encoding='auto',
preserve_format=True,
read_all_sheets=True, # 启用多工作表读取
text_cleanup={
'remove_extra_spaces': True,
'normalize_whitespace': False,
'remove_special_chars': False,
'lowercase': False
}
)
extractor = ExcelTextExtractor(config)
try:
sample_file = 'example.xlsx'
if Path(sample_file).exists():
text = extractor.extract_text(
sample_file,
columns=['title', 'content']
)
print("提取的文本:")
print(text)
else:
print(f"示例文件 {sample_file} 不存在")
print("\n支持的格式:", extractor.get_supported_formats())
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
main()

查看文件

@@ -0,0 +1,359 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Set, Dict, Union, List
from dataclasses import dataclass, field
import logging
import os
import re
import subprocess
import tempfile
import shutil
@dataclass
class MarkdownConverterConfig:
"""PDF 到 Markdown 转换器配置类
Attributes:
extract_images: 是否提取图片
extract_tables: 是否尝试保留表格结构
extract_code_blocks: 是否识别代码块
extract_math: 是否转换数学公式
output_dir: 输出目录路径
image_dir: 图片保存目录路径
paragraph_separator: 段落之间的分隔符
text_cleanup: 文本清理选项字典
docintel_endpoint: Document Intelligence端点URL (可选)
enable_plugins: 是否启用插件
llm_client: LLM客户端对象 (例如OpenAI client)
llm_model: 要使用的LLM模型名称
"""
extract_images: bool = True
extract_tables: bool = True
extract_code_blocks: bool = True
extract_math: bool = True
output_dir: str = ""
image_dir: str = "images"
paragraph_separator: str = '\n\n'
text_cleanup: Dict[str, bool] = field(default_factory=lambda: {
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
})
docintel_endpoint: str = ""
enable_plugins: bool = False
llm_client: Optional[object] = None
llm_model: str = ""
class MarkdownConverter:
"""PDF 到 Markdown 转换器
使用 markitdown 库实现 PDF 到 Markdown 的转换,支持多种配置选项。
"""
SUPPORTED_EXTENSIONS: Set[str] = {
'.pdf',
}
def __init__(self, config: Optional[MarkdownConverterConfig] = None):
"""初始化转换器
Args:
config: 转换器配置对象,如果为None则使用默认配置
"""
self.config = config or MarkdownConverterConfig()
self._setup_logging()
# 检查是否安装了 markitdown
self._check_markitdown_installation()
def _setup_logging(self) -> None:
"""配置日志记录器"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
# 添加文件处理器
fh = logging.FileHandler('markdown_converter.log')
fh.setLevel(logging.ERROR)
self.logger.addHandler(fh)
def _check_markitdown_installation(self) -> None:
"""检查是否安装了 markitdown"""
try:
# 尝试导入 markitdown 库
from markitdown import MarkItDown
self.logger.info("markitdown 库已安装")
except ImportError:
self.logger.warning("markitdown 库未安装,尝试安装...")
try:
subprocess.check_call(["pip", "install", "markitdown"])
self.logger.info("markitdown 库安装成功")
from markitdown import MarkItDown
except (subprocess.SubprocessError, ImportError):
self.logger.error("无法安装 markitdown 库,请手动安装")
self.markitdown_available = False
return
self.markitdown_available = True
def _validate_file(self, file_path: Union[str, Path], max_size_mb: int = 100) -> Path:
"""验证文件
Args:
file_path: 文件路径
max_size_mb: 允许的最大文件大小(MB)
Returns:
Path: 验证后的Path对象
Raises:
ValueError: 文件不存在、格式不支持或大小超限
PermissionError: 没有读取权限
"""
path = Path(file_path).resolve()
if not path.exists():
raise ValueError(f"文件不存在: {path}")
if not path.is_file():
raise ValueError(f"不是一个文件: {path}")
if not os.access(path, os.R_OK):
raise PermissionError(f"没有读取权限: {path}")
file_size_mb = path.stat().st_size / (1024 * 1024)
if file_size_mb > max_size_mb:
raise ValueError(
f"文件大小 ({file_size_mb:.1f}MB) 超过限制 {max_size_mb}MB"
)
if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"不支持的格式: {path.suffix}. "
f"支持的格式: {', '.join(sorted(self.SUPPORTED_EXTENSIONS))}"
)
return path
def _cleanup_text(self, text: str) -> str:
"""清理文本
Args:
text: 原始文本
Returns:
str: 清理后的文本
"""
if self.config.text_cleanup['remove_extra_spaces']:
text = ' '.join(text.split())
if self.config.text_cleanup['normalize_whitespace']:
text = text.replace('\t', ' ').replace('\r', '\n')
if self.config.text_cleanup['lowercase']:
text = text.lower()
return text.strip()
@staticmethod
def get_supported_formats() -> List[str]:
"""获取支持的文件格式列表"""
return sorted(MarkdownConverter.SUPPORTED_EXTENSIONS)
def convert_to_markdown(
self,
file_path: Union[str, Path],
output_path: Optional[Union[str, Path]] = None
) -> str:
"""将 PDF 转换为 Markdown
Args:
file_path: PDF 文件路径
output_path: 输出 Markdown 文件路径,如果为 None 则返回内容而不保存
Returns:
str: 转换后的 Markdown 内容
Raises:
Exception: 转换过程中的错误
"""
try:
path = self._validate_file(file_path)
self.logger.info(f"处理: {path}")
if not self.markitdown_available:
raise ImportError("markitdown 库未安装,无法进行转换")
# 导入 markitdown 库
from markitdown import MarkItDown
# 准备输出目录
if output_path:
output_path = Path(output_path)
output_dir = output_path.parent
output_dir.mkdir(parents=True, exist_ok=True)
else:
# 创建临时目录作为输出目录
temp_dir = tempfile.mkdtemp()
output_dir = Path(temp_dir)
output_path = output_dir / f"{path.stem}.md"
# 图片目录
image_dir = output_dir / self.config.image_dir
image_dir.mkdir(parents=True, exist_ok=True)
# 创建 MarkItDown 实例并进行转换
if self.config.docintel_endpoint:
md = MarkItDown(docintel_endpoint=self.config.docintel_endpoint)
elif self.config.llm_client and self.config.llm_model:
md = MarkItDown(
enable_plugins=self.config.enable_plugins,
llm_client=self.config.llm_client,
llm_model=self.config.llm_model
)
else:
md = MarkItDown(enable_plugins=self.config.enable_plugins)
# 执行转换
result = md.convert(str(path))
markdown_content = result.text_content
# 清理文本
markdown_content = self._cleanup_text(markdown_content)
# 如果需要保存到文件
if output_path:
with open(output_path, 'w', encoding='utf-8') as f:
f.write(markdown_content)
self.logger.info(f"转换成功,输出到: {output_path}")
return markdown_content
except Exception as e:
self.logger.error(f"转换失败: {e}")
raise
finally:
# 如果使用了临时目录且没有指定输出路径,则清理临时目录
if 'temp_dir' in locals() and not output_path:
shutil.rmtree(temp_dir, ignore_errors=True)
def convert_to_markdown_and_save(
self,
file_path: Union[str, Path],
output_path: Union[str, Path]
) -> Path:
"""将 PDF 转换为 Markdown 并保存到指定路径
Args:
file_path: PDF 文件路径
output_path: 输出 Markdown 文件路径
Returns:
Path: 输出文件的 Path 对象
Raises:
Exception: 转换过程中的错误
"""
self.convert_to_markdown(file_path, output_path)
return Path(output_path)
def batch_convert(
self,
file_paths: List[Union[str, Path]],
output_dir: Union[str, Path]
) -> List[Path]:
"""批量转换多个 PDF 文件为 Markdown
Args:
file_paths: PDF 文件路径列表
output_dir: 输出目录路径
Returns:
List[Path]: 输出文件路径列表
Raises:
Exception: 转换过程中的错误
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_paths = []
for file_path in file_paths:
path = Path(file_path)
output_path = output_dir / f"{path.stem}.md"
try:
self.convert_to_markdown(file_path, output_path)
output_paths.append(output_path)
self.logger.info(f"成功转换: {path} -> {output_path}")
except Exception as e:
self.logger.error(f"转换失败 {path}: {e}")
return output_paths
def main():
"""主函数:演示用法"""
# 配置
config = MarkdownConverterConfig(
extract_images=True,
extract_tables=True,
extract_code_blocks=True,
extract_math=True,
enable_plugins=False,
text_cleanup={
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
}
)
# 创建转换器
converter = MarkdownConverter(config)
# 使用示例
try:
# 替换为实际的文件路径
sample_file = './crazy_functions/doc_fns/read_fns/paper/2501.12599v1.pdf'
if Path(sample_file).exists():
# 转换为 Markdown 并打印内容
markdown_content = converter.convert_to_markdown(sample_file)
print("转换后的 Markdown 内容:")
print(markdown_content[:500] + "...") # 只打印前500个字符
# 转换并保存到文件
output_file = f"./output_{Path(sample_file).stem}.md"
output_path = converter.convert_to_markdown_and_save(sample_file, output_file)
print(f"\n已保存到: {output_path}")
# 使用LLM增强的示例 (需要添加相应的导入和配置)
# try:
# from openai import OpenAI
# client = OpenAI()
# llm_config = MarkdownConverterConfig(
# llm_client=client,
# llm_model="gpt-4o"
# )
# llm_converter = MarkdownConverter(llm_config)
# llm_result = llm_converter.convert_to_markdown("example.jpg")
# print("LLM增强的结果:")
# print(llm_result[:500] + "...")
# except ImportError:
# print("未安装OpenAI库,跳过LLM示例")
else:
print(f"示例文件 {sample_file} 不存在")
print("\n支持的格式:", converter.get_supported_formats())
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
main()

查看文件

@@ -0,0 +1,493 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Set, Dict, Union, List
from dataclasses import dataclass, field
import logging
import os
import re
from unstructured.partition.auto import partition
from unstructured.documents.elements import (
Text, Title, NarrativeText, ListItem, Table,
Footer, Header, PageBreak, Image, Address
)
@dataclass
class PaperMetadata:
"""论文元数据类"""
title: str = ""
authors: List[str] = field(default_factory=list)
affiliations: List[str] = field(default_factory=list)
journal: str = ""
volume: str = ""
issue: str = ""
year: str = ""
doi: str = ""
date: str = ""
publisher: str = ""
conference: str = ""
abstract: str = ""
keywords: List[str] = field(default_factory=list)
@dataclass
class ExtractorConfig:
"""元数据提取器配置类"""
paragraph_separator: str = '\n\n'
text_cleanup: Dict[str, bool] = field(default_factory=lambda: {
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
})
class PaperMetadataExtractor:
"""论文元数据提取器
使用unstructured库从多种文档格式中提取论文的标题、作者、摘要等元数据信息。
"""
SUPPORTED_EXTENSIONS: Set[str] = {
'.pdf', '.docx', '.doc', '.txt', '.ppt', '.pptx',
'.xlsx', '.xls', '.md', '.org', '.odt', '.rst',
'.rtf', '.epub', '.html', '.xml', '.json'
}
# 定义论文各部分的关键词模式
SECTION_PATTERNS = {
'abstract': r'\b(摘要|abstract|summary|概要|résumé|zusammenfassung|аннотация)\b',
'keywords': r'\b(关键词|keywords|key\s+words|关键字|mots[- ]clés|schlüsselwörter|ключевые слова)\b',
}
def __init__(self, config: Optional[ExtractorConfig] = None):
"""初始化提取器
Args:
config: 提取器配置对象,如果为None则使用默认配置
"""
self.config = config or ExtractorConfig()
self._setup_logging()
def _setup_logging(self) -> None:
"""配置日志记录器"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
# 添加文件处理器
fh = logging.FileHandler('paper_metadata_extractor.log')
fh.setLevel(logging.ERROR)
self.logger.addHandler(fh)
def _validate_file(self, file_path: Union[str, Path], max_size_mb: int = 100) -> Path:
"""验证文件
Args:
file_path: 文件路径
max_size_mb: 允许的最大文件大小(MB)
Returns:
Path: 验证后的Path对象
Raises:
ValueError: 文件不存在、格式不支持或大小超限
PermissionError: 没有读取权限
"""
path = Path(file_path).resolve()
if not path.exists():
raise ValueError(f"文件不存在: {path}")
if not path.is_file():
raise ValueError(f"不是文件: {path}")
if not os.access(path, os.R_OK):
raise PermissionError(f"没有读取权限: {path}")
file_size_mb = path.stat().st_size / (1024 * 1024)
if file_size_mb > max_size_mb:
raise ValueError(
f"文件大小 ({file_size_mb:.1f}MB) 超过限制 {max_size_mb}MB"
)
if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"不支持的文件格式: {path.suffix}. "
f"支持的格式: {', '.join(sorted(self.SUPPORTED_EXTENSIONS))}"
)
return path
def _cleanup_text(self, text: str) -> str:
"""清理文本
Args:
text: 原始文本
Returns:
str: 清理后的文本
"""
if self.config.text_cleanup['remove_extra_spaces']:
text = ' '.join(text.split())
if self.config.text_cleanup['normalize_whitespace']:
text = text.replace('\t', ' ').replace('\r', '\n')
if self.config.text_cleanup['lowercase']:
text = text.lower()
return text.strip()
@staticmethod
def get_supported_formats() -> List[str]:
"""获取支持的文件格式列表"""
return sorted(PaperMetadataExtractor.SUPPORTED_EXTENSIONS)
def extract_metadata(self, file_path: Union[str, Path], strategy: str = "fast") -> PaperMetadata:
"""提取论文元数据
Args:
file_path: 文件路径
strategy: 提取策略 ("fast""accurate")
Returns:
PaperMetadata: 提取的论文元数据
Raises:
Exception: 提取过程中的错误
"""
try:
path = self._validate_file(file_path)
self.logger.info(f"正在处理: {path}")
# 使用unstructured库分解文档
elements = partition(
str(path),
strategy=strategy,
include_metadata=True,
nlp=False,
)
# 提取元数据
metadata = PaperMetadata()
# 提取标题和作者
self._extract_title_and_authors(elements, metadata)
# 提取摘要和关键词
self._extract_abstract_and_keywords(elements, metadata)
# 提取其他元数据
self._extract_additional_metadata(elements, metadata)
return metadata
except Exception as e:
self.logger.error(f"元数据提取失败: {e}")
raise
def _extract_title_and_authors(self, elements, metadata: PaperMetadata) -> None:
"""从文档中提取标题和作者信息 - 改进版"""
# 收集所有潜在的标题候选
title_candidates = []
all_text = []
raw_text = []
# 首先收集文档前30个元素的文本,用于辅助判断
for i, element in enumerate(elements[:30]):
if isinstance(element, (Text, Title, NarrativeText)):
text = str(element).strip()
if text:
all_text.append(text)
raw_text.append(text)
# 打印出原始文本,用于调试
print("原始文本前10行:")
for i, text in enumerate(raw_text[:10]):
print(f"{i}: {text}")
# 1. 尝试查找连续的标题片段并合并它们
i = 0
while i < len(all_text) - 1:
current = all_text[i]
next_text = all_text[i + 1]
# 检查是否存在标题分割情况:一行以冒号结尾,下一行像是标题的延续
if current.endswith(':') and len(current) < 50 and len(next_text) > 5 and next_text[0].isupper():
# 合并这两行文本
combined_title = f"{current} {next_text}"
# 查找合并前的文本并替换
all_text[i] = combined_title
all_text.pop(i + 1)
# 给合并后的标题很高的分数
title_candidates.append((combined_title, 15, i))
else:
i += 1
# 2. 首先尝试从标题元素中查找
for i, element in enumerate(elements[:15]): # 只检查前15个元素
if isinstance(element, Title):
title_text = str(element).strip()
# 排除常见的非标题内容
if title_text.lower() not in ['abstract', '摘要', 'introduction', '引言']:
# 计算标题分数(越高越可能是真正的标题)
score = self._evaluate_title_candidate(title_text, i, element)
title_candidates.append((title_text, score, i))
# 3. 特别处理常见的论文标题格式
for i, text in enumerate(all_text[:15]):
# 特别检查"KIMI K1.5:"类型的前缀标题
if re.match(r'^[A-Z][A-Z0-9\s\.]+(\s+K\d+(\.\d+)?)?:', text):
score = 12 # 给予很高的分数
title_candidates.append((text, score, i))
# 如果下一行也是全大写,很可能是标题的延续
if i+1 < len(all_text) and all_text[i+1].isupper() and len(all_text[i+1]) > 10:
combined_title = f"{text} {all_text[i+1]}"
title_candidates.append((combined_title, 15, i)) # 给合并标题更高分数
# 匹配全大写的标题行
elif text.isupper() and len(text) > 10 and len(text) < 100:
score = 10 - i * 0.5 # 越靠前越可能是标题
title_candidates.append((text, score, i))
# 对标题候选按分数排序并选取最佳候选
if title_candidates:
title_candidates.sort(key=lambda x: x[1], reverse=True)
metadata.title = title_candidates[0][0]
title_position = title_candidates[0][2]
print(f"所有标题候选: {title_candidates[:3]}")
else:
# 如果没有找到合适的标题,使用一个备选策略
for text in all_text[:10]:
if text.isupper() and len(text) > 10 and len(text) < 200: # 大写且适当长度的文本
metadata.title = text
break
title_position = 0
# 提取作者信息 - 改进后的作者提取逻辑
author_candidates = []
# 1. 特别处理"TECHNICAL REPORT OF"之后的行,通常是作者或团队
for i, text in enumerate(all_text):
if "TECHNICAL REPORT" in text.upper() and i+1 < len(all_text):
team_text = all_text[i+1].strip()
if re.search(r'\b(team|group|lab)\b', team_text, re.IGNORECASE):
author_candidates.append((team_text, 15))
# 2. 查找包含Team的文本
for text in all_text[:20]:
if "Team" in text and len(text) < 30:
# 这很可能是团队名
author_candidates.append((text, 12))
# 添加作者到元数据
if author_candidates:
# 按分数排序
author_candidates.sort(key=lambda x: x[1], reverse=True)
# 去重
seen_authors = set()
for author, _ in author_candidates:
if author.lower() not in seen_authors and not author.isdigit():
seen_authors.add(author.lower())
metadata.authors.append(author)
# 如果没有找到作者,尝试查找隶属机构信息中的团队名称
if not metadata.authors:
for text in all_text[:20]:
if re.search(r'\b(team|group|lab|laboratory|研究组|团队)\b', text, re.IGNORECASE):
if len(text) < 50: # 避免太长的文本
metadata.authors.append(text.strip())
break
# 提取隶属机构信息
for i, element in enumerate(elements[:30]):
element_text = str(element).strip()
if re.search(r'(university|institute|department|school|laboratory|college|center|centre|\d{5,}|^[a-zA-Z]+@|学院|大学|研究所|研究院)', element_text, re.IGNORECASE):
# 可能是隶属机构
if element_text not in metadata.affiliations and len(element_text) > 10:
metadata.affiliations.append(element_text)
def _evaluate_title_candidate(self, text, position, element):
"""评估标题候选项的可能性分数"""
score = 0
# 位置因素:越靠前越可能是标题
score += max(0, 10 - position) * 0.5
# 长度因素:标题通常不会太短也不会太长
if 10 <= len(text) <= 150:
score += 3
elif len(text) < 10:
score -= 2
elif len(text) > 150:
score -= 3
# 格式因素
if text.isupper(): # 全大写可能是标题
score += 2
if re.match(r'^[A-Z]', text): # 首字母大写
score += 1
if ':' in text: # 标题常包含冒号
score += 1.5
# 内容因素
if re.search(r'\b(scaling|learning|model|approach|method|system|framework|analysis)\b', text.lower()):
score += 2 # 包含常见的学术论文关键词
# 避免误判
if re.match(r'^\d+$', text): # 纯数字
score -= 10
if re.search(r'^(http|www|doi)', text.lower()): # URL或DOI
score -= 5
if len(text.split()) <= 2 and len(text) < 15: # 太短的短语
score -= 3
# 元数据因素(如果有)
if hasattr(element, 'metadata') and element.metadata:
# 修复正确处理ElementMetadata对象
try:
# 尝试通过getattr安全地获取属性
font_size = getattr(element.metadata, 'font_size', None)
if font_size is not None and font_size > 14: # 假设标准字体大小是12
score += 3
font_weight = getattr(element.metadata, 'font_weight', None)
if font_weight == 'bold':
score += 2 # 粗体加分
except (AttributeError, TypeError):
# 如果metadata的访问方式不正确,尝试其他可能的访问方式
try:
metadata_dict = element.metadata.__dict__ if hasattr(element.metadata, '__dict__') else {}
if 'font_size' in metadata_dict and metadata_dict['font_size'] > 14:
score += 3
if 'font_weight' in metadata_dict and metadata_dict['font_weight'] == 'bold':
score += 2
except Exception:
# 如果所有尝试都失败,忽略元数据处理
pass
return score
def _extract_abstract_and_keywords(self, elements, metadata: PaperMetadata) -> None:
"""从文档中提取摘要和关键词"""
abstract_found = False
keywords_found = False
abstract_text = []
for i, element in enumerate(elements):
element_text = str(element).strip().lower()
# 寻找摘要部分
if not abstract_found and (
isinstance(element, Title) and
re.search(self.SECTION_PATTERNS['abstract'], element_text, re.IGNORECASE)
):
abstract_found = True
continue
# 如果找到摘要部分,收集内容直到遇到关键词部分或新章节
if abstract_found and not keywords_found:
# 检查是否遇到关键词部分或新章节
if (
isinstance(element, Title) or
re.search(self.SECTION_PATTERNS['keywords'], element_text, re.IGNORECASE) or
re.match(r'\b(introduction|引言|method|方法)\b', element_text, re.IGNORECASE)
):
keywords_found = re.search(self.SECTION_PATTERNS['keywords'], element_text, re.IGNORECASE)
abstract_found = False # 停止收集摘要
else:
# 收集摘要文本
if isinstance(element, (Text, NarrativeText)) and element_text:
abstract_text.append(element_text)
# 如果找到关键词部分,提取关键词
if keywords_found and not abstract_found and not metadata.keywords:
if isinstance(element, (Text, NarrativeText)):
# 清除可能的"关键词:"/"Keywords:"前缀
cleaned_text = re.sub(r'^\s*(关键词|keywords|key\s+words)\s*[:]\s*', '', element_text, flags=re.IGNORECASE)
# 尝试按不同分隔符分割
for separator in [';', '', ',', '']:
if separator in cleaned_text:
metadata.keywords = [k.strip() for k in cleaned_text.split(separator) if k.strip()]
break
# 如果未能分割,将整个文本作为一个关键词
if not metadata.keywords and cleaned_text:
metadata.keywords = [cleaned_text]
keywords_found = False # 已提取关键词,停止处理
# 设置摘要文本
if abstract_text:
metadata.abstract = self.config.paragraph_separator.join(abstract_text)
def _extract_additional_metadata(self, elements, metadata: PaperMetadata) -> None:
"""提取其他元数据信息"""
for element in elements[:30]: # 只检查文档前部分
element_text = str(element).strip()
# 尝试匹配DOI
doi_match = re.search(r'(doi|DOI):\s*(10\.\d{4,}\/[a-zA-Z0-9.-]+)', element_text)
if doi_match and not metadata.doi:
metadata.doi = doi_match.group(2)
# 尝试匹配日期
date_match = re.search(r'(published|received|accepted|submitted):\s*(\d{1,2}\s+[a-zA-Z]+\s+\d{4}|\d{4}[-/]\d{1,2}[-/]\d{1,2})', element_text, re.IGNORECASE)
if date_match and not metadata.date:
metadata.date = date_match.group(2)
# 尝试匹配年份
year_match = re.search(r'\b(19|20)\d{2}\b', element_text)
if year_match and not metadata.year:
metadata.year = year_match.group(0)
# 尝试匹配期刊/会议名称
journal_match = re.search(r'(journal|conference):\s*([^,;.]+)', element_text, re.IGNORECASE)
if journal_match:
if "journal" in journal_match.group(1).lower() and not metadata.journal:
metadata.journal = journal_match.group(2).strip()
elif not metadata.conference:
metadata.conference = journal_match.group(2).strip()
def main():
"""主函数:演示用法"""
# 创建提取器
extractor = PaperMetadataExtractor()
# 使用示例
try:
# 替换为实际的文件路径
sample_file = '/Users/boyin.liu/Documents/示例文档/论文/3.pdf'
if Path(sample_file).exists():
metadata = extractor.extract_metadata(sample_file)
print("提取的元数据:")
print(f"标题: {metadata.title}")
print(f"作者: {', '.join(metadata.authors)}")
print(f"机构: {', '.join(metadata.affiliations)}")
print(f"摘要: {metadata.abstract[:200]}...")
print(f"关键词: {', '.join(metadata.keywords)}")
print(f"DOI: {metadata.doi}")
print(f"日期: {metadata.date}")
print(f"年份: {metadata.year}")
print(f"期刊: {metadata.journal}")
print(f"会议: {metadata.conference}")
else:
print(f"示例文件 {sample_file} 不存在")
print("\n支持的格式:", extractor.get_supported_formats())
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
main()

查看文件

@@ -0,0 +1,86 @@
from pathlib import Path
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor
def extract_and_save_as_markdown(paper_path, output_path=None):
"""
提取论文结构并保存为Markdown格式
参数:
paper_path: 论文文件路径
output_path: 输出的Markdown文件路径,如果不指定,将使用与输入相同的文件名但扩展名为.md
返回:
保存的Markdown文件路径
"""
# 创建提取器
extractor = PaperStructureExtractor()
# 解析文件路径
paper_path = Path(paper_path)
# 如果未指定输出路径,使用相同文件名但扩展名为.md
if output_path is None:
output_path = paper_path.with_suffix('.md')
else:
output_path = Path(output_path)
# 确保输出目录存在
output_path.parent.mkdir(parents=True, exist_ok=True)
print(f"正在处理论文: {paper_path}")
try:
# 提取论文结构
paper = extractor.extract_paper_structure(paper_path)
# 生成Markdown内容
markdown_content = extractor.generate_markdown(paper)
# 保存到文件
with open(output_path, 'w', encoding='utf-8') as f:
f.write(markdown_content)
print(f"已成功保存Markdown文件: {output_path}")
# 打印摘要信息
print("\n论文摘要信息:")
print(f"标题: {paper.metadata.title}")
print(f"作者: {', '.join(paper.metadata.authors)}")
print(f"关键词: {', '.join(paper.keywords)}")
print(f"章节数: {len(paper.sections)}")
print(f"图表数: {len(paper.figures)}")
print(f"表格数: {len(paper.tables)}")
print(f"公式数: {len(paper.formulas)}")
print(f"参考文献数: {len(paper.references)}")
return output_path
except Exception as e:
print(f"处理论文时出错: {e}")
import traceback
traceback.print_exc()
return None
# 使用示例
if __name__ == "__main__":
# 替换为实际的论文文件路径
sample_paper = "crazy_functions/doc_fns/read_fns/paper/2501.12599v1.pdf"
# 可以指定输出路径,也可以使用默认路径
# output_file = "/path/to/output/paper_structure.md"
# extract_and_save_as_markdown(sample_paper, output_file)
# 使用默认输出路径(与输入文件同名但扩展名为.md
extract_and_save_as_markdown(sample_paper)
# # 批量处理多个论文的示例
# paper_dir = Path("/path/to/papers/folder")
# output_dir = Path("/path/to/output/folder")
#
# # 确保输出目录存在
# output_dir.mkdir(parents=True, exist_ok=True)
#
# # 处理目录中的所有PDF文件
# for paper_file in paper_dir.glob("*.pdf"):
# output_file = output_dir / f"{paper_file.stem}.md"
# extract_and_save_as_markdown(paper_file, output_file)

查看文件

@@ -0,0 +1,275 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Set, Dict, Union, List
from dataclasses import dataclass, field
import logging
import os
from unstructured.partition.auto import partition
from unstructured.documents.elements import (
Text, Title, NarrativeText, ListItem, Table,
Footer, Header, PageBreak, Image, Address
)
@dataclass
class TextExtractorConfig:
"""通用文档提取器配置类
Attributes:
extract_headers_footers: 是否提取页眉页脚
extract_tables: 是否提取表格内容
extract_lists: 是否提取列表内容
extract_titles: 是否提取标题
paragraph_separator: 段落之间的分隔符
text_cleanup: 文本清理选项字典
"""
extract_headers_footers: bool = False
extract_tables: bool = True
extract_lists: bool = True
extract_titles: bool = True
paragraph_separator: str = '\n\n'
text_cleanup: Dict[str, bool] = field(default_factory=lambda: {
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
})
class UnstructuredTextExtractor:
"""通用文档文本内容提取器
使用 unstructured 库支持多种文档格式的文本提取,提供统一的接口和配置选项。
"""
SUPPORTED_EXTENSIONS: Set[str] = {
# 文档格式
'.pdf', '.docx', '.doc', '.txt',
# 演示文稿
'.ppt', '.pptx',
# 电子表格
'.xlsx', '.xls', '.csv',
# 图片
'.png', '.jpg', '.jpeg', '.tiff',
# 邮件
'.eml', '.msg', '.p7s',
# Markdown
".md",
# Org Mode
".org",
# Open Office
".odt",
# reStructured Text
".rst",
# Rich Text
".rtf",
# TSV
".tsv",
# EPUB
'.epub',
# 其他格式
'.html', '.xml', '.json',
}
def __init__(self, config: Optional[TextExtractorConfig] = None):
"""初始化提取器
Args:
config: 提取器配置对象,如果为None则使用默认配置
"""
self.config = config or TextExtractorConfig()
self._setup_logging()
def _setup_logging(self) -> None:
"""配置日志记录器"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
# 添加文件处理器
fh = logging.FileHandler('text_extractor.log')
fh.setLevel(logging.ERROR)
self.logger.addHandler(fh)
def _validate_file(self, file_path: Union[str, Path], max_size_mb: int = 100) -> Path:
"""验证文件
Args:
file_path: 文件路径
max_size_mb: 允许的最大文件大小(MB)
Returns:
Path: 验证后的Path对象
Raises:
ValueError: 文件不存在、格式不支持或大小超限
PermissionError: 没有读取权限
"""
path = Path(file_path).resolve()
if not path.exists():
raise ValueError(f"File not found: {path}")
if not path.is_file():
raise ValueError(f"Not a file: {path}")
if not os.access(path, os.R_OK):
raise PermissionError(f"No read permission: {path}")
file_size_mb = path.stat().st_size / (1024 * 1024)
if file_size_mb > max_size_mb:
raise ValueError(
f"File size ({file_size_mb:.1f}MB) exceeds limit of {max_size_mb}MB"
)
if path.suffix.lower() not in self.SUPPORTED_EXTENSIONS:
raise ValueError(
f"Unsupported format: {path.suffix}. "
f"Supported: {', '.join(sorted(self.SUPPORTED_EXTENSIONS))}"
)
return path
def _cleanup_text(self, text: str) -> str:
"""清理文本
Args:
text: 原始文本
Returns:
str: 清理后的文本
"""
if self.config.text_cleanup['remove_extra_spaces']:
text = ' '.join(text.split())
if self.config.text_cleanup['normalize_whitespace']:
text = text.replace('\t', ' ').replace('\r', '\n')
if self.config.text_cleanup['lowercase']:
text = text.lower()
return text.strip()
def _should_extract_element(self, element) -> bool:
"""判断是否应该提取某个元素
Args:
element: 文档元素
Returns:
bool: 是否应该提取
"""
if isinstance(element, (Text, NarrativeText)):
return True
if isinstance(element, Title) and self.config.extract_titles:
return True
if isinstance(element, ListItem) and self.config.extract_lists:
return True
if isinstance(element, Table) and self.config.extract_tables:
return True
if isinstance(element, (Header, Footer)) and self.config.extract_headers_footers:
return True
return False
@staticmethod
def get_supported_formats() -> List[str]:
"""获取支持的文件格式列表"""
return sorted(UnstructuredTextExtractor.SUPPORTED_EXTENSIONS)
def extract_text(
self,
file_path: Union[str, Path],
strategy: str = "fast"
) -> str:
"""提取文本
Args:
file_path: 文件路径
strategy: 提取策略 ("fast""accurate")
Returns:
str: 提取的文本内容
Raises:
Exception: 提取过程中的错误
"""
try:
path = self._validate_file(file_path)
self.logger.info(f"Processing: {path}")
# 修改这里:添加 nlp=False 参数来禁用 NLTK
elements = partition(
str(path),
strategy=strategy,
include_metadata=True,
nlp=True,
)
# 其余代码保持不变
text_parts = []
for element in elements:
if self._should_extract_element(element):
text = str(element)
cleaned_text = self._cleanup_text(text)
if cleaned_text:
if isinstance(element, (Header, Footer)):
prefix = "[Header] " if isinstance(element, Header) else "[Footer] "
text_parts.append(f"{prefix}{cleaned_text}")
else:
text_parts.append(cleaned_text)
return self.config.paragraph_separator.join(text_parts)
except Exception as e:
self.logger.error(f"Extraction failed: {e}")
raise
def main():
"""主函数:演示用法"""
# 配置
config = TextExtractorConfig(
extract_headers_footers=True,
extract_tables=True,
extract_lists=True,
extract_titles=True,
text_cleanup={
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
}
)
# 创建提取器
extractor = UnstructuredTextExtractor(config)
# 使用示例
try:
# 替换为实际的文件路径
sample_file = './crazy_functions/doc_fns/read_fns/paper/2501.12599v1.pdf'
if Path(sample_file).exists() or True:
text = extractor.extract_text(sample_file)
print("提取的文本:")
print(text)
else:
print(f"示例文件 {sample_file} 不存在")
print("\n支持的格式:", extractor.get_supported_formats())
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
main()

查看文件

@@ -0,0 +1,219 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
from urllib.parse import urlparse
import logging
import trafilatura
import requests
from pathlib import Path
@dataclass
class WebExtractorConfig:
"""网页内容提取器配置类
Attributes:
extract_comments: 是否提取评论
extract_tables: 是否提取表格
extract_links: 是否保留链接信息
paragraph_separator: 段落分隔符
timeout: 网络请求超时时间(秒)
max_retries: 最大重试次数
user_agent: 自定义User-Agent
text_cleanup: 文本清理选项
"""
extract_comments: bool = False
extract_tables: bool = True
extract_links: bool = False
paragraph_separator: str = '\n\n'
timeout: int = 10
max_retries: int = 3
user_agent: str = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
text_cleanup: Dict[str, bool] = field(default_factory=lambda: {
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
})
class WebTextExtractor:
"""网页文本内容提取器
使用trafilatura库提取网页中的主要文本内容,去除广告、导航等无关内容。
"""
def __init__(self, config: Optional[WebExtractorConfig] = None):
"""初始化提取器
Args:
config: 提取器配置对象,如果为None则使用默认配置
"""
self.config = config or WebExtractorConfig()
self._setup_logging()
def _setup_logging(self) -> None:
"""配置日志记录器"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
# 添加文件处理器
fh = logging.FileHandler('web_extractor.log')
fh.setLevel(logging.ERROR)
self.logger.addHandler(fh)
def _validate_url(self, url: str) -> bool:
"""验证URL格式是否有效
Args:
url: 网页URL
Returns:
bool: URL是否有效
"""
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except Exception:
return False
def _download_webpage(self, url: str) -> Optional[str]:
"""下载网页内容
Args:
url: 网页URL
Returns:
Optional[str]: 网页HTML内容,失败返回None
Raises:
Exception: 下载失败时抛出异常
"""
headers = {'User-Agent': self.config.user_agent}
for attempt in range(self.config.max_retries):
try:
response = requests.get(
url,
headers=headers,
timeout=self.config.timeout
)
response.raise_for_status()
return response.text
except requests.RequestException as e:
self.logger.warning(f"Attempt {attempt + 1} failed: {e}")
if attempt == self.config.max_retries - 1:
raise Exception(f"Failed to download webpage after {self.config.max_retries} attempts: {e}")
return None
def _cleanup_text(self, text: str) -> str:
"""清理文本
Args:
text: 原始文本
Returns:
str: 清理后的文本
"""
if not text:
return ""
if self.config.text_cleanup['remove_extra_spaces']:
text = ' '.join(text.split())
if self.config.text_cleanup['normalize_whitespace']:
text = text.replace('\t', ' ').replace('\r', '\n')
if self.config.text_cleanup['lowercase']:
text = text.lower()
return text.strip()
def extract_text(self, url: str) -> str:
"""提取网页文本内容
Args:
url: 网页URL
Returns:
str: 提取的文本内容
Raises:
ValueError: URL无效时抛出
Exception: 提取失败时抛出
"""
try:
if not self._validate_url(url):
raise ValueError(f"Invalid URL: {url}")
self.logger.info(f"Processing URL: {url}")
# 下载网页
html_content = self._download_webpage(url)
if not html_content:
raise Exception("Failed to download webpage")
# 配置trafilatura提取选项
extract_config = {
'include_comments': self.config.extract_comments,
'include_tables': self.config.extract_tables,
'include_links': self.config.extract_links,
'no_fallback': False, # 允许使用后备提取器
}
# 提取文本
extracted_text = trafilatura.extract(
html_content,
**extract_config
)
if not extracted_text:
raise Exception("No content could be extracted")
# 清理文本
cleaned_text = self._cleanup_text(extracted_text)
return cleaned_text
except Exception as e:
self.logger.error(f"Extraction failed: {e}")
raise
def main():
"""主函数:演示用法"""
# 配置
config = WebExtractorConfig(
extract_comments=False,
extract_tables=True,
extract_links=False,
timeout=10,
text_cleanup={
'remove_extra_spaces': True,
'normalize_whitespace': True,
'remove_special_chars': False,
'lowercase': False
}
)
# 创建提取器
extractor = WebTextExtractor(config)
# 使用示例
try:
# 替换为实际的URL
sample_url = 'https://arxiv.org/abs/2412.00036'
text = extractor.extract_text(sample_url)
print("提取的文本:")
print(text)
except Exception as e:
print(f"错误: {e}")
if __name__ == "__main__":
main()

查看文件

@@ -0,0 +1,451 @@
import os
import re
import glob
import time
import queue
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Generator, Tuple, Set, Optional, Dict
from dataclasses import dataclass
from loguru import logger
from toolbox import update_ui
from crazy_functions.rag_fns.rag_file_support import extract_text
from crazy_functions.doc_fns.content_folder import ContentFoldingManager, FileMetadata, FoldingOptions, FoldingStyle, FoldingError
from shared_utils.fastapi_server import validate_path_safety
from datetime import datetime
import mimetypes
@dataclass
class FileInfo:
"""文件信息数据类"""
path: str # 完整路径
rel_path: str # 相对路径
size: float # 文件大小(MB)
extension: str # 文件扩展名
last_modified: str # 最后修改时间
class TextContentLoader:
"""优化版本的文本内容加载器 - 保持原有接口"""
# 压缩文件扩展名
COMPRESSED_EXTENSIONS: Set[str] = {'.zip', '.rar', '.7z', '.tar', '.gz', '.bz2', '.xz'}
# 系统配置
MAX_FILE_SIZE: int = 100 * 1024 * 1024 # 最大文件大小100MB
MAX_TOTAL_SIZE: int = 100 * 1024 * 1024 # 最大总大小100MB
MAX_FILES: int = 100 # 最大文件数量
CHUNK_SIZE: int = 1024 * 1024 # 文件读取块大小1MB
MAX_WORKERS: int = min(32, (os.cpu_count() or 1) * 4) # 最大工作线程数
BATCH_SIZE: int = 5 # 批处理大小
def __init__(self, chatbot: List, history: List):
"""初始化加载器"""
self.chatbot = chatbot
self.history = history
self.failed_files: List[Tuple[str, str]] = []
self.processed_size: int = 0
self.start_time: float = 0
self.file_cache: Dict[str, str] = {}
self._lock = threading.Lock()
self.executor = ThreadPoolExecutor(max_workers=self.MAX_WORKERS)
self.results_queue = queue.Queue()
self.folding_manager = ContentFoldingManager()
def _create_file_info(self, entry: os.DirEntry, root_path: str) -> FileInfo:
"""优化的文件信息创建
Args:
entry: 目录入口对象
root_path: 根路径
Returns:
FileInfo: 文件信息对象
"""
try:
stats = entry.stat() # 使用缓存的文件状态
return FileInfo(
path=entry.path,
rel_path=os.path.relpath(entry.path, root_path),
size=stats.st_size / (1024 * 1024),
extension=os.path.splitext(entry.path)[1].lower(),
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(stats.st_mtime))
)
except (OSError, ValueError) as e:
return None
def _process_file_batch(self, file_batch: List[FileInfo]) -> List[Tuple[FileInfo, Optional[str]]]:
"""批量处理文件
Args:
file_batch: 要处理的文件信息列表
Returns:
List[Tuple[FileInfo, Optional[str]]]: 处理结果列表
"""
results = []
futures = {}
for file_info in file_batch:
if file_info.path in self.file_cache:
results.append((file_info, self.file_cache[file_info.path]))
continue
if file_info.size * 1024 * 1024 > self.MAX_FILE_SIZE:
with self._lock:
self.failed_files.append(
(file_info.rel_path,
f"文件过大({file_info.size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB")
)
continue
future = self.executor.submit(self._read_file_content, file_info)
futures[future] = file_info
for future in as_completed(futures):
file_info = futures[future]
try:
content = future.result()
if content:
with self._lock:
self.file_cache[file_info.path] = content
self.processed_size += file_info.size * 1024 * 1024
results.append((file_info, content))
except Exception as e:
with self._lock:
self.failed_files.append((file_info.rel_path, f"读取失败: {str(e)}"))
return results
def _read_file_content(self, file_info: FileInfo) -> Optional[str]:
"""读取单个文件内容
Args:
file_info: 文件信息对象
Returns:
Optional[str]: 文件内容
"""
try:
content = extract_text(file_info.path)
if not content or not content.strip():
return None
return content
except Exception as e:
logger.exception(f"读取文件失败: {str(e)}")
raise Exception(f"读取文件失败: {str(e)}")
def _is_valid_file(self, file_path: str) -> bool:
"""检查文件是否有效
Args:
file_path: 文件路径
Returns:
bool: 是否为有效文件
"""
if not os.path.isfile(file_path):
return False
extension = os.path.splitext(file_path)[1].lower()
if (extension in self.COMPRESSED_EXTENSIONS or
os.path.basename(file_path).startswith('.') or
not os.access(file_path, os.R_OK)):
return False
# 只要文件可以访问且不在排除列表中就认为是有效的
return True
def _collect_files(self, path: str) -> List[FileInfo]:
"""收集文件信息
Args:
path: 目标路径
Returns:
List[FileInfo]: 有效文件信息列表
"""
files = []
total_size = 0
# 处理单个文件的情况
if os.path.isfile(path):
if self._is_valid_file(path):
file_info = self._create_file_info(os.DirEntry(os.path.dirname(path)), os.path.dirname(path))
if file_info:
return [file_info]
return []
# 处理目录的情况
try:
# 使用os.walk来递归遍历目录
for root, _, filenames in os.walk(path):
for filename in filenames:
if len(files) >= self.MAX_FILES:
self.failed_files.append((filename, f"超出最大文件数限制({self.MAX_FILES})"))
continue
file_path = os.path.join(root, filename)
if not self._is_valid_file(file_path):
continue
try:
stats = os.stat(file_path)
file_size = stats.st_size / (1024 * 1024) # 转换为MB
if file_size * 1024 * 1024 > self.MAX_FILE_SIZE:
self.failed_files.append((file_path,
f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB"))
continue
if total_size + file_size * 1024 * 1024 > self.MAX_TOTAL_SIZE:
self.failed_files.append((file_path, "超出总大小限制"))
continue
file_info = FileInfo(
path=file_path,
rel_path=os.path.relpath(file_path, path),
size=file_size,
extension=os.path.splitext(file_path)[1].lower(),
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(stats.st_mtime))
)
total_size += file_size * 1024 * 1024
files.append(file_info)
except Exception as e:
self.failed_files.append((file_path, f"处理文件失败: {str(e)}"))
continue
except Exception as e:
self.failed_files.append(("目录扫描", f"扫描失败: {str(e)}"))
return []
return sorted(files, key=lambda x: x.rel_path)
def _format_content_with_fold(self, file_info, content: str) -> str:
"""使用折叠管理器格式化文件内容"""
try:
metadata = FileMetadata(
rel_path=file_info.rel_path,
size=file_info.size,
last_modified=datetime.fromtimestamp(
os.path.getmtime(file_info.path)
),
mime_type=mimetypes.guess_type(file_info.path)[0]
)
options = FoldingOptions(
style=FoldingStyle.DETAILED,
code_language=self.folding_manager._guess_language(
os.path.splitext(file_info.path)[1]
),
show_timestamp=True
)
return self.folding_manager.format_content(
content=content,
formatter_type='file',
metadata=metadata,
options=options
)
except Exception as e:
return f"Error formatting content: {str(e)}"
def _format_content_for_llm(self, file_infos: List[FileInfo], contents: List[str]) -> str:
"""格式化用于LLM的内容
Args:
file_infos: 文件信息列表
contents: 内容列表
Returns:
str: 格式化后的内容
"""
if len(file_infos) != len(contents):
raise ValueError("文件信息和内容数量不匹配")
result = [
"以下是多个文件的内容集合。每个文件的内容都以 '===== 文件 {序号}: {文件名} =====' 开始,",
"'===== 文件 {序号} 结束 =====' 结束。你可以根据这些分隔符来识别不同文件的内容。\n\n"
]
for idx, (file_info, content) in enumerate(zip(file_infos, contents), 1):
result.extend([
f"===== 文件 {idx}: {file_info.rel_path} =====",
"文件内容:",
content.strip(),
f"===== 文件 {idx} 结束 =====\n"
])
return "\n".join(result)
def execute(self, txt: str) -> Generator:
"""执行文本加载和显示 - 保持原有接口
Args:
txt: 目标路径
Yields:
Generator: UI更新生成器
"""
try:
# 首先显示正在处理的提示信息
self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
user_name = self.chatbot.get_user()
validate_path_safety(txt, user_name)
self.start_time = time.time()
self.processed_size = 0
self.failed_files.clear()
successful_files = []
successful_contents = []
# 收集文件
files = self._collect_files(txt)
if not files:
# 移除之前的提示信息
self.chatbot.pop()
self.chatbot.append(["提示", "未找到任何有效文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
# 批量处理文件
content_blocks = []
for i in range(0, len(files), self.BATCH_SIZE):
batch = files[i:i + self.BATCH_SIZE]
results = self._process_file_batch(batch)
for file_info, content in results:
if content:
content_blocks.append(self._format_content_with_fold(file_info, content))
successful_files.append(file_info)
successful_contents.append(content)
# 显示文件内容,替换之前的提示信息
if content_blocks:
# 移除之前的提示信息
self.chatbot.pop()
self.chatbot.append(["文件内容", "\n".join(content_blocks)])
self.history.extend([
self._format_content_for_llm(successful_files, successful_contents),
"我已经接收到你上传的文件的内容,请提问"
])
yield from update_ui(chatbot=self.chatbot, history=self.history)
yield from update_ui(chatbot=self.chatbot, history=self.history)
except Exception as e:
# 发生错误时,移除之前的提示信息
if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示":
self.chatbot.pop()
self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
finally:
self.executor.shutdown(wait=False)
self.file_cache.clear()
def execute_single_file(self, file_path: str) -> Generator:
"""执行单个文件的加载和显示
Args:
file_path: 文件路径
Yields:
Generator: UI更新生成器
"""
try:
# 首先显示正在处理的提示信息
self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
user_name = self.chatbot.get_user()
validate_path_safety(file_path, user_name)
self.start_time = time.time()
self.processed_size = 0
self.failed_files.clear()
# 验证文件是否存在且可读
if not os.path.isfile(file_path):
self.chatbot.pop()
self.chatbot.append(["错误", f"指定路径不是文件: {file_path}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
if not self._is_valid_file(file_path):
self.chatbot.pop()
self.chatbot.append(["错误", f"无效的文件类型或无法读取: {file_path}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
# 创建文件信息
try:
stats = os.stat(file_path)
file_size = stats.st_size / (1024 * 1024) # 转换为MB
if file_size * 1024 * 1024 > self.MAX_FILE_SIZE:
self.chatbot.pop()
self.chatbot.append(["错误", f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
file_info = FileInfo(
path=file_path,
rel_path=os.path.basename(file_path),
size=file_size,
extension=os.path.splitext(file_path)[1].lower(),
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(stats.st_mtime))
)
except Exception as e:
self.chatbot.pop()
self.chatbot.append(["错误", f"处理文件失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
# 读取文件内容
try:
content = self._read_file_content(file_info)
if not content:
self.chatbot.pop()
self.chatbot.append(["提示", f"文件内容为空或无法提取: {file_path}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
except Exception as e:
self.chatbot.pop()
self.chatbot.append(["错误", f"读取文件失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return
# 格式化内容并更新UI
formatted_content = self._format_content_with_fold(file_info, content)
# 移除之前的提示信息
self.chatbot.pop()
self.chatbot.append(["文件内容", formatted_content])
# 更新历史记录,便于LLM处理
llm_content = self._format_content_for_llm([file_info], [content])
self.history.extend([llm_content, "我已经接收到你上传的文件的内容,请提问"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
except Exception as e:
# 发生错误时,移除之前的提示信息
if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示":
self.chatbot.pop()
self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
def __del__(self):
"""析构函数 - 确保资源被正确释放"""
if hasattr(self, 'executor'):
self.executor.shutdown(wait=False)
if hasattr(self, 'file_cache'):
self.file_cache.clear()

查看文件

@@ -1,4 +1,4 @@
from toolbox import CatchException, update_ui, update_ui_lastest_msg
from toolbox import CatchException, update_ui, update_ui_latest_msg
from crazy_functions.multi_stage.multi_stage_utils import GptAcademicGameBaseState
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -13,7 +13,7 @@ class MiniGame_ASCII_Art(GptAcademicGameBaseState):
else:
if prompt.strip() == 'exit':
self.delete_game = True
yield from update_ui_lastest_msg(lastmsg=f"谜底是{self.obj},游戏结束。", chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=f"谜底是{self.obj},游戏结束。", chatbot=chatbot, history=history, delay=0.)
return
chatbot.append([prompt, ""])
yield from update_ui(chatbot=chatbot, history=history)
@@ -31,12 +31,12 @@ class MiniGame_ASCII_Art(GptAcademicGameBaseState):
self.cur_task = 'identify user guess'
res = get_code_block(raw_res)
history += ['', f'the answer is {self.obj}', inputs, res]
yield from update_ui_lastest_msg(lastmsg=res, chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=res, chatbot=chatbot, history=history, delay=0.)
elif self.cur_task == 'identify user guess':
if is_same_thing(self.obj, prompt, self.llm_kwargs):
self.delete_game = True
yield from update_ui_lastest_msg(lastmsg="你猜对了!", chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg="你猜对了!", chatbot=chatbot, history=history, delay=0.)
else:
self.cur_task = 'identify user guess'
yield from update_ui_lastest_msg(lastmsg="猜错了,再试试,输入“exit”获取答案。", chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg="猜错了,再试试,输入“exit”获取答案。", chatbot=chatbot, history=history, delay=0.)

查看文件

@@ -63,7 +63,7 @@ prompts_terminate = """小说的前文回顾:
"""
from toolbox import CatchException, update_ui, update_ui_lastest_msg
from toolbox import CatchException, update_ui, update_ui_latest_msg
from crazy_functions.multi_stage.multi_stage_utils import GptAcademicGameBaseState
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -112,7 +112,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
if prompt.strip() == 'exit' or prompt.strip() == '结束剧情':
# should we terminate game here?
self.delete_game = True
yield from update_ui_lastest_msg(lastmsg=f"游戏结束。", chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=f"游戏结束。", chatbot=chatbot, history=history, delay=0.)
return
if '剧情收尾' in prompt:
self.cur_task = 'story_terminate'
@@ -137,8 +137,8 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
)
self.story.append(story_paragraph)
# # 配图
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
# # 构建后续剧情引导
previously_on_story = ""
@@ -171,8 +171,8 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
)
self.story.append(story_paragraph)
# # 配图
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
# # 构建后续剧情引导
previously_on_story = ""
@@ -204,8 +204,8 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
chatbot, history_, self.sys_prompt_
)
# # 配图
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_lastest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>正在生成插图中 ...', chatbot=chatbot, history=history, delay=0.)
yield from update_ui_latest_msg(lastmsg=story_paragraph + '<br/>'+ self.generate_story_image(story_paragraph), chatbot=chatbot, history=history, delay=0.)
# terminate game
self.delete_game = True

查看文件

@@ -2,7 +2,7 @@ import time
import importlib
from toolbox import trimmed_format_exc, gen_time_str, get_log_folder
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, is_the_upload_folder
from toolbox import promote_file_to_downloadzone, get_log_folder, update_ui_lastest_msg
from toolbox import promote_file_to_downloadzone, get_log_folder, update_ui_latest_msg
import multiprocessing
def get_class_name(class_string):

查看文件

@@ -102,10 +102,10 @@ class GptJsonIO():
logging.info(f'Repairing json{response}')
repair_prompt = self.generate_repair_prompt(broken_json = response, error=repr(e))
result = self.generate_output(gpt_gen_fn(repair_prompt, self.format_instructions))
logging.info('Repaire json success.')
logging.info('Repair json success.')
except Exception as e:
# 没辙了,放弃治疗
logging.info('Repaire json fail.')
logging.info('Repair json fail.')
raise JsonStringError('Cannot repair json.', str(e))
return result

查看文件

@@ -3,7 +3,7 @@ import re
import shutil
import numpy as np
from loguru import logger
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder, gen_time_str
from toolbox import update_ui, update_ui_latest_msg, get_log_folder, gen_time_str
from toolbox import get_conf, promote_file_to_downloadzone
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
@@ -20,7 +20,7 @@ def split_subprocess(txt, project_folder, return_dict, opts):
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
be processed by GPT.
"""
text = txt
mask = np.zeros(len(txt), dtype=np.uint8) + TRANSFORM
@@ -85,14 +85,14 @@ class LatexPaperSplit():
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
be processed by GPT.
"""
def __init__(self) -> None:
self.nodes = None
self.msg = "*{\\scriptsize\\textbf{警告该PDF由GPT-Academic开源项目调用大语言模型+Latex翻译插件一键生成," + \
"版权归原文作者所有。翻译内容可靠性无保障,请仔细鉴别并以原文为准。" + \
"项目Github地址 \\url{https://github.com/binary-husky/gpt_academic/}。"
# 请您不要删除或修改这行警告,除非您是论文的原作者如果您是论文原作者,欢迎加REAME中的QQ联系开发者
# 请您不要删除或修改这行警告,除非您是论文的原作者如果您是论文原作者,欢迎加README中的QQ联系开发者
self.msg_declare = "为了防止大语言模型的意外谬误产生扩散影响,禁止移除或修改此警告。}}\\\\"
self.title = "unknown"
self.abstract = "unknown"
@@ -151,7 +151,7 @@ class LatexPaperSplit():
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
be processed by GPT.
P.S. use multiprocessing to avoid timeout error
"""
import multiprocessing
@@ -351,7 +351,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
max_try = 32
chatbot.append([f"正在编译PDF文档", f'编译已经开始。当前工作路径为{work_folder},如果程序停顿5分钟以上,请直接去该路径下取回翻译结果,或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history)
chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面
yield from update_ui_lastest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面
# 检查是否需要使用xelatex
def check_if_need_xelatex(tex_path):
try:
@@ -373,7 +373,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
# 根据编译器类型返回编译命令
def get_compile_command(compiler, filename):
compile_command = f'{compiler} -interaction=batchmode -file-line-error {filename}.tex'
logger.info('Latex 编译指令: ', compile_command)
logger.info('Latex 编译指令: ' + compile_command)
return compile_command
# 确定使用的编译器
@@ -396,32 +396,32 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
shutil.copyfile(may_exist_bbl, target_bbl)
# https://stackoverflow.com/questions/738755/dont-make-me-manually-abort-a-latex-compile-when-theres-an-error
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译原始PDF ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译原始PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译转化后的PDF ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译转化后的PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
if ok and os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf')):
# 只有第二步成功,才能继续下面的步骤
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译BibTex ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译BibTex ...', chatbot, history) # 刷新Gradio前端界面
if not os.path.exists(pj(work_folder_original, f'{main_file_original}.bbl')):
ok = compile_latex_with_timeout(f'bibtex {main_file_original}.aux', work_folder_original)
if not os.path.exists(pj(work_folder_modified, f'{main_file_modified}.bbl')):
ok = compile_latex_with_timeout(f'bibtex {main_file_modified}.aux', work_folder_modified)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译文献交叉引用 ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译文献交叉引用 ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
if mode!='translate_zh':
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
logger.info( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, 'merge_diff'), work_folder)
ok = compile_latex_with_timeout(f'bibtex merge_diff.aux', work_folder)
ok = compile_latex_with_timeout(get_compile_command(compiler, 'merge_diff'), work_folder)
@@ -435,13 +435,13 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
results_ += f"原始PDF编译是否成功: {original_pdf_success};"
results_ += f"转化PDF编译是否成功: {modified_pdf_success};"
results_ += f"对比PDF编译是否成功: {diff_pdf_success};"
yield from update_ui_lastest_msg(f'{n_fix}编译结束:<br/>{results_}...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'{n_fix}编译结束:<br/>{results_}...', chatbot, history) # 刷新Gradio前端界面
if diff_pdf_success:
result_pdf = pj(work_folder_modified, f'merge_diff.pdf') # get pdf path
promote_file_to_downloadzone(result_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
if modified_pdf_success:
yield from update_ui_lastest_msg(f'转化PDF编译已经成功, 正在尝试生成对比PDF, 请稍候 ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'转化PDF编译已经成功, 正在尝试生成对比PDF, 请稍候 ...', chatbot, history) # 刷新Gradio前端界面
result_pdf = pj(work_folder_modified, f'{main_file_modified}.pdf') # get pdf path
origin_pdf = pj(work_folder_original, f'{main_file_original}.pdf') # get pdf path
if os.path.exists(pj(work_folder, '..', 'translation')):
@@ -472,7 +472,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
work_folder_modified=work_folder_modified,
fixed_line=fixed_line
)
yield from update_ui_lastest_msg(f'由于最为关键的转化PDF编译失败, 将根据报错信息修正tex源文件并重试, 当前报错的latex代码处于第{buggy_lines}行 ...', chatbot, history) # 刷新Gradio前端界面
yield from update_ui_latest_msg(f'由于最为关键的转化PDF编译失败, 将根据报错信息修正tex源文件并重试, 当前报错的latex代码处于第{buggy_lines}行 ...', chatbot, history) # 刷新Gradio前端界面
if not can_retry: break
return False # 失败啦

查看文件

@@ -168,7 +168,7 @@ def set_forbidden_text(text, mask, pattern, flags=0):
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
"""
Move area out of preserve area (make text editable for GPT)
count the number of the braces so as to catch compelete text area.
count the number of the braces so as to catch complete text area.
e.g.
\begin{abstract} blablablablablabla. \end{abstract}
"""
@@ -188,7 +188,7 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
"""
Add a preserve text area in this paper (text become untouchable for GPT).
count the number of the braces so as to catch compelete text area.
count the number of the braces so as to catch complete text area.
e.g.
\caption{blablablablabla\texbf{blablabla}blablabla.}
"""
@@ -214,7 +214,7 @@ def reverse_forbidden_text_careful_brace(
):
"""
Move area out of preserve area (make text editable for GPT)
count the number of the braces so as to catch compelete text area.
count the number of the braces so as to catch complete text area.
e.g.
\caption{blablablablabla\texbf{blablabla}blablabla.}
"""
@@ -287,23 +287,23 @@ def find_main_tex_file(file_manifest, mode):
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
P.S. 但愿没人把latex模板放在里面传进来 (6.25 加入判定latex模板的代码)
"""
canidates = []
candidates = []
for texf in file_manifest:
if os.path.basename(texf).startswith("merge"):
continue
with open(texf, "r", encoding="utf8", errors="ignore") as f:
file_content = f.read()
if r"\documentclass" in file_content:
canidates.append(texf)
candidates.append(texf)
else:
continue
if len(canidates) == 0:
if len(candidates) == 0:
raise RuntimeError("无法找到一个主Tex文件包含documentclass关键字")
elif len(canidates) == 1:
return canidates[0]
else: # if len(canidates) >= 2 通过一些Latex模板中常见但通常不会出现在正文的单词,对不同latex源文件扣分,取评分最高者返回
canidates_score = []
elif len(candidates) == 1:
return candidates[0]
else: # if len(candidates) >= 2 通过一些Latex模板中常见但通常不会出现在正文的单词,对不同latex源文件扣分,取评分最高者返回
candidates_score = []
# 给出一些判定模板文档的词作为扣分项
unexpected_words = [
"\\LaTeX",
@@ -316,19 +316,19 @@ def find_main_tex_file(file_manifest, mode):
"reviewers",
]
expected_words = ["\\input", "\\ref", "\\cite"]
for texf in canidates:
canidates_score.append(0)
for texf in candidates:
candidates_score.append(0)
with open(texf, "r", encoding="utf8", errors="ignore") as f:
file_content = f.read()
file_content = rm_comments(file_content)
for uw in unexpected_words:
if uw in file_content:
canidates_score[-1] -= 1
candidates_score[-1] -= 1
for uw in expected_words:
if uw in file_content:
canidates_score[-1] += 1
select = np.argmax(canidates_score) # 取评分最高者返回
return canidates[select]
candidates_score[-1] += 1
select = np.argmax(candidates_score) # 取评分最高者返回
return candidates[select]
def rm_comments(main_file):
@@ -374,7 +374,7 @@ def find_tex_file_ignore_case(fp):
def merge_tex_files_(project_foler, main_file, mode):
"""
Merge Tex project recrusively
Merge Tex project recursively
"""
main_file = rm_comments(main_file)
for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]):
@@ -429,7 +429,7 @@ def find_title_and_abs(main_file):
def merge_tex_files(project_foler, main_file, mode):
"""
Merge Tex project recrusively
Merge Tex project recursively
P.S. 顺便把CTEX塞进去以支持中文
P.S. 顺便把Latex的注释去除
"""

查看文件

@@ -1,4 +1,4 @@
from toolbox import update_ui, get_conf, promote_file_to_downloadzone, update_ui_lastest_msg, generate_file_link
from toolbox import update_ui, get_conf, promote_file_to_downloadzone, update_ui_latest_msg, generate_file_link
from shared_utils.docker_as_service_api import stream_daas
from shared_utils.docker_as_service_api import DockerServiceApiComModel
import random
@@ -25,7 +25,7 @@ def download_video(video_id, only_audio, user_name, chatbot, history):
status_buf += "\n\n"
status_buf += "DaaS file attach: \n\n"
status_buf += str(output_manifest['server_file_attach'])
yield from update_ui_lastest_msg(status_buf, chatbot, history)
yield from update_ui_latest_msg(status_buf, chatbot, history)
return output_manifest['server_file_attach']

查看文件

@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from typing import List
from toolbox import update_ui_lastest_msg, disable_auto_promotion
from toolbox import update_ui_latest_msg, disable_auto_promotion
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder
from request_llms.bridge_all import predict_no_ui_long_connection
from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError

查看文件

查看文件

@@ -0,0 +1,386 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from ..query_analyzer import SearchCriteria
from ..sources.github_source import GitHubSource
import asyncio
import re
from datetime import datetime
class BaseHandler(ABC):
"""处理器基类"""
def __init__(self, github: GitHubSource, llm_kwargs: Dict = None):
self.github = github
self.llm_kwargs = llm_kwargs or {}
self.ranked_repos = [] # 存储排序后的仓库列表
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
"""获取搜索参数"""
return {
'max_repos': plugin_kwargs.get('max_repos', 150), # 最大仓库数量,从30改为150
'max_details': plugin_kwargs.get('max_details', 80), # 最多展示详情的仓库数量,新增参数
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
'min_stars': plugin_kwargs.get('min_stars', 0), # 最少星标数
}
@abstractmethod
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理查询"""
pass
async def _search_repositories(self, query: str, language: str = None, min_stars: int = 0,
sort: str = "stars", per_page: int = 30) -> List[Dict]:
"""搜索仓库"""
try:
# 构建查询字符串
if min_stars > 0 and "stars:>" not in query:
query += f" stars:>{min_stars}"
if language and "language:" not in query:
query += f" language:{language}"
# 执行搜索
result = await self.github.search_repositories(
query=query,
sort=sort,
per_page=per_page
)
if result and "items" in result:
return result["items"]
return []
except Exception as e:
print(f"仓库搜索出错: {str(e)}")
return []
async def _search_bilingual_repositories(self, english_query: str, chinese_query: str, language: str = None, min_stars: int = 0,
sort: str = "stars", per_page: int = 30) -> List[Dict]:
"""同时搜索中英文仓库并合并结果"""
try:
# 搜索英文仓库
english_results = await self._search_repositories(
query=english_query,
language=language,
min_stars=min_stars,
sort=sort,
per_page=per_page
)
# 搜索中文仓库
chinese_results = await self._search_repositories(
query=chinese_query,
language=language,
min_stars=min_stars,
sort=sort,
per_page=per_page
)
# 合并结果,去除重复项
merged_results = []
seen_repos = set()
# 优先添加英文结果
for repo in english_results:
repo_id = repo.get('id')
if repo_id and repo_id not in seen_repos:
seen_repos.add(repo_id)
merged_results.append(repo)
# 添加中文结果(排除重复)
for repo in chinese_results:
repo_id = repo.get('id')
if repo_id and repo_id not in seen_repos:
seen_repos.add(repo_id)
merged_results.append(repo)
# 按星标数重新排序
merged_results.sort(key=lambda x: x.get('stargazers_count', 0), reverse=True)
return merged_results[:per_page] # 返回合并后的前per_page个结果
except Exception as e:
print(f"双语仓库搜索出错: {str(e)}")
return []
async def _search_code(self, query: str, language: str = None, per_page: int = 30) -> List[Dict]:
"""搜索代码"""
try:
# 构建查询字符串
if language and "language:" not in query:
query += f" language:{language}"
# 执行搜索
result = await self.github.search_code(
query=query,
per_page=per_page
)
if result and "items" in result:
return result["items"]
return []
except Exception as e:
print(f"代码搜索出错: {str(e)}")
return []
async def _search_bilingual_code(self, english_query: str, chinese_query: str, language: str = None, per_page: int = 30) -> List[Dict]:
"""同时搜索中英文代码并合并结果"""
try:
# 搜索英文代码
english_results = await self._search_code(
query=english_query,
language=language,
per_page=per_page
)
# 搜索中文代码
chinese_results = await self._search_code(
query=chinese_query,
language=language,
per_page=per_page
)
# 合并结果,去除重复项
merged_results = []
seen_files = set()
# 优先添加英文结果
for item in english_results:
# 使用文件URL作为唯一标识
file_url = item.get('html_url', '')
if file_url and file_url not in seen_files:
seen_files.add(file_url)
merged_results.append(item)
# 添加中文结果(排除重复)
for item in chinese_results:
file_url = item.get('html_url', '')
if file_url and file_url not in seen_files:
seen_files.add(file_url)
merged_results.append(item)
# 对结果进行排序,优先显示匹配度高的结果
# 由于无法直接获取匹配度,这里使用仓库的星标数作为替代指标
merged_results.sort(key=lambda x: x.get('repository', {}).get('stargazers_count', 0), reverse=True)
return merged_results[:per_page] # 返回合并后的前per_page个结果
except Exception as e:
print(f"双语代码搜索出错: {str(e)}")
return []
async def _search_users(self, query: str, per_page: int = 30) -> List[Dict]:
"""搜索用户"""
try:
result = await self.github.search_users(
query=query,
per_page=per_page
)
if result and "items" in result:
return result["items"]
return []
except Exception as e:
print(f"用户搜索出错: {str(e)}")
return []
async def _search_bilingual_users(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]:
"""同时搜索中英文用户并合并结果"""
try:
# 搜索英文用户
english_results = await self._search_users(
query=english_query,
per_page=per_page
)
# 搜索中文用户
chinese_results = await self._search_users(
query=chinese_query,
per_page=per_page
)
# 合并结果,去除重复项
merged_results = []
seen_users = set()
# 优先添加英文结果
for user in english_results:
user_id = user.get('id')
if user_id and user_id not in seen_users:
seen_users.add(user_id)
merged_results.append(user)
# 添加中文结果(排除重复)
for user in chinese_results:
user_id = user.get('id')
if user_id and user_id not in seen_users:
seen_users.add(user_id)
merged_results.append(user)
# 按关注者数量进行排序
merged_results.sort(key=lambda x: x.get('followers', 0), reverse=True)
return merged_results[:per_page] # 返回合并后的前per_page个结果
except Exception as e:
print(f"双语用户搜索出错: {str(e)}")
return []
async def _search_topics(self, query: str, per_page: int = 30) -> List[Dict]:
"""搜索主题"""
try:
result = await self.github.search_topics(
query=query,
per_page=per_page
)
if result and "items" in result:
return result["items"]
return []
except Exception as e:
print(f"主题搜索出错: {str(e)}")
return []
async def _search_bilingual_topics(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]:
"""同时搜索中英文主题并合并结果"""
try:
# 搜索英文主题
english_results = await self._search_topics(
query=english_query,
per_page=per_page
)
# 搜索中文主题
chinese_results = await self._search_topics(
query=chinese_query,
per_page=per_page
)
# 合并结果,去除重复项
merged_results = []
seen_topics = set()
# 优先添加英文结果
for topic in english_results:
topic_name = topic.get('name')
if topic_name and topic_name not in seen_topics:
seen_topics.add(topic_name)
merged_results.append(topic)
# 添加中文结果(排除重复)
for topic in chinese_results:
topic_name = topic.get('name')
if topic_name and topic_name not in seen_topics:
seen_topics.add(topic_name)
merged_results.append(topic)
# 可以按流行度进行排序(如果有)
if merged_results and 'featured' in merged_results[0]:
merged_results.sort(key=lambda x: x.get('featured', False), reverse=True)
return merged_results[:per_page] # 返回合并后的前per_page个结果
except Exception as e:
print(f"双语主题搜索出错: {str(e)}")
return []
async def _get_repo_details(self, repos: List[Dict]) -> List[Dict]:
"""获取仓库详细信息"""
enhanced_repos = []
for repo in repos:
try:
# 获取README信息
owner = repo.get('owner', {}).get('login') if repo.get('owner') is not None else None
repo_name = repo.get('name')
if owner and repo_name:
readme = await self.github.get_repo_readme(owner, repo_name)
if readme and "decoded_content" in readme:
# 提取README的前1000个字符作为摘要
repo['readme_excerpt'] = readme["decoded_content"][:1000] + "..."
# 获取语言使用情况
languages = await self.github.get_repository_languages(owner, repo_name)
if languages:
repo['languages_detail'] = languages
# 获取最新发布版本
releases = await self.github.get_repo_releases(owner, repo_name, per_page=1)
if releases and len(releases) > 0:
repo['latest_release'] = releases[0]
# 获取主题标签
topics = await self.github.get_repo_topics(owner, repo_name)
if topics and "names" in topics:
repo['topics'] = topics["names"]
enhanced_repos.append(repo)
except Exception as e:
print(f"获取仓库 {repo.get('full_name')} 详情时出错: {str(e)}")
enhanced_repos.append(repo) # 添加原始仓库信息
return enhanced_repos
def _format_repos(self, repos: List[Dict]) -> str:
"""格式化仓库列表"""
formatted = []
for i, repo in enumerate(repos, 1):
# 构建仓库URL
repo_url = repo.get('html_url', '')
# 构建完整的引用
reference = (
f"{i}. **{repo.get('full_name', '')}**\n"
f" - 描述: {repo.get('description', 'N/A')}\n"
f" - 语言: {repo.get('language', 'N/A')}\n"
f" - 星标: {repo.get('stargazers_count', 0)}\n"
f" - Fork数: {repo.get('forks_count', 0)}\n"
f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n"
f" - 创建时间: {repo.get('created_at', 'N/A')[:10]}\n"
f" - URL: <a href='{repo_url}' target='_blank'>{repo_url}</a>\n"
)
# 添加主题标签(如果有)
if repo.get('topics'):
topics_str = ", ".join(repo.get('topics'))
reference += f" - 主题标签: {topics_str}\n"
# 添加最新发布版本(如果有)
if repo.get('latest_release'):
release = repo.get('latest_release')
reference += f" - 最新版本: {release.get('tag_name', 'N/A')} ({release.get('published_at', 'N/A')[:10]})\n"
# 添加README摘要(如果有)
if repo.get('readme_excerpt'):
# 截断README,只取前300个字符
readme_short = repo.get('readme_excerpt')[:300].replace('\n', ' ')
reference += f" - README摘要: {readme_short}...\n"
formatted.append(reference)
return "\n".join(formatted)
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
"""生成道歉提示"""
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的GitHub项目。
可能的原因:
1. 搜索词过于具体或冷门
2. 星标数要求过高
3. 编程语言限制过于严格
建议解决方案:
1. 尝试使用更通用的关键词
2. 降低最低星标数要求
3. 移除或更改编程语言限制
请根据以上建议调整后重试。"""
def _get_current_time(self) -> str:
"""获取当前时间信息"""
now = datetime.now()
return now.strftime("%Y年%m月%d")

查看文件

@@ -0,0 +1,156 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from ..query_analyzer import SearchCriteria
import asyncio
class CodeSearchHandler(BaseHandler):
"""代码搜索处理器"""
def __init__(self, github, llm_kwargs=None):
super().__init__(github, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理代码搜索请求,返回最终的prompt"""
search_params = self._get_search_params(plugin_kwargs)
# 搜索代码
code_results = await self._search_bilingual_code(
english_query=criteria.github_params["query"],
chinese_query=criteria.github_params["chinese_query"],
language=criteria.language,
per_page=search_params['max_repos']
)
if not code_results:
return self._generate_apology_prompt(criteria)
# 获取代码文件内容
enhanced_code_results = await self._get_code_details(code_results[:search_params['max_details']])
self.ranked_repos = [item["repository"] for item in enhanced_code_results if "repository" in item]
if not enhanced_code_results:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = f"""当前时间: {current_time}
基于用户对{criteria.main_topic}的查询,我找到了以下代码示例。
代码搜索结果:
{self._format_code_results(enhanced_code_results)}
请提供:
1. 对于搜索的"{criteria.main_topic}"主题的综合解释:
- 概念和原理介绍
- 常见实现方法和技术
- 最佳实践和注意事项
2. 对每个代码示例:
- 解释代码的主要功能和实现方式
- 分析代码质量、可读性和效率
- 指出代码中的亮点和潜在改进空间
- 说明代码的适用场景
3. 代码实现比较:
- 不同实现方法的优缺点
- 性能和可维护性分析
- 适用不同场景的实现建议
4. 学习建议:
- 理解和使用这些代码需要的背景知识
- 如何扩展或改进所展示的代码
- 进一步学习相关技术的资源
重要提示:
- 深入解释代码的核心逻辑和实现思路
- 提供专业、技术性的分析
- 优先关注代码的实现质量和技术价值
- 当代码实现有问题时,指出并提供改进建议
- 对于复杂代码,分解解释其组成部分
- 根据用户查询的具体问题提供针对性答案
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
使用markdown格式提供清晰的分节回复。
"""
return final_prompt
async def _get_code_details(self, code_results: List[Dict]) -> List[Dict]:
"""获取代码详情"""
enhanced_results = []
for item in code_results:
try:
repo = item.get('repository', {})
file_path = item.get('path', '')
repo_name = repo.get('full_name', '')
if repo_name and file_path:
owner, repo_name = repo_name.split('/')
# 获取文件内容
file_content = await self.github.get_file_content(owner, repo_name, file_path)
if file_content and "decoded_content" in file_content:
item['code_content'] = file_content["decoded_content"]
# 获取仓库基本信息
repo_details = await self.github.get_repo(owner, repo_name)
if repo_details:
item['repository'] = repo_details
enhanced_results.append(item)
except Exception as e:
print(f"获取代码详情时出错: {str(e)}")
enhanced_results.append(item) # 添加原始信息
return enhanced_results
def _format_code_results(self, code_results: List[Dict]) -> str:
"""格式化代码搜索结果"""
formatted = []
for i, item in enumerate(code_results, 1):
# 构建仓库信息
repo = item.get('repository', {})
repo_name = repo.get('full_name', 'N/A')
repo_url = repo.get('html_url', '')
stars = repo.get('stargazers_count', 0)
language = repo.get('language', 'N/A')
# 构建文件信息
file_path = item.get('path', 'N/A')
file_url = item.get('html_url', '')
# 构建代码内容
code_content = item.get('code_content', '')
if code_content:
# 只显示前30行代码
code_lines = code_content.split("\n")
if len(code_lines) > 30:
displayed_code = "\n".join(code_lines[:30]) + "\n... (代码太长已截断) ..."
else:
displayed_code = code_content
else:
displayed_code = "(代码内容获取失败)"
reference = (
f"### {i}. {file_path} (在 {repo_name} 中)\n\n"
f"- **仓库**: <a href='{repo_url}' target='_blank'>{repo_name}</a> (⭐ {stars}, 语言: {language})\n"
f"- **文件路径**: <a href='{file_url}' target='_blank'>{file_path}</a>\n\n"
f"```{language.lower()}\n{displayed_code}\n```\n\n"
)
formatted.append(reference)
return "\n".join(formatted)

查看文件

@@ -0,0 +1,192 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from ..query_analyzer import SearchCriteria
import asyncio
class RepositoryHandler(BaseHandler):
"""仓库搜索处理器"""
def __init__(self, github, llm_kwargs=None):
super().__init__(github, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理仓库搜索请求,返回最终的prompt"""
search_params = self._get_search_params(plugin_kwargs)
# 如果是特定仓库查询
if criteria.repo_id:
try:
owner, repo = criteria.repo_id.split('/')
repo_details = await self.github.get_repo(owner, repo)
if repo_details:
# 获取推荐的相似仓库
similar_repos = await self.github.get_repo_recommendations(criteria.repo_id, limit=5)
# 添加详细信息
all_repos = [repo_details] + similar_repos
enhanced_repos = await self._get_repo_details(all_repos)
self.ranked_repos = enhanced_repos
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = self._build_repo_detail_prompt(enhanced_repos[0], enhanced_repos[1:], current_time)
return final_prompt
else:
return self._generate_apology_prompt(criteria)
except Exception as e:
print(f"处理特定仓库时出错: {str(e)}")
return self._generate_apology_prompt(criteria)
# 一般仓库搜索
repos = await self._search_bilingual_repositories(
english_query=criteria.github_params["query"],
chinese_query=criteria.github_params["chinese_query"],
language=criteria.language,
min_stars=criteria.min_stars,
per_page=search_params['max_repos']
)
if not repos:
return self._generate_apology_prompt(criteria)
# 获取仓库详情
enhanced_repos = await self._get_repo_details(repos[:search_params['max_details']]) # 使用max_details参数
self.ranked_repos = enhanced_repos
if not enhanced_repos:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = f"""当前时间: {current_time}
基于用户对{criteria.main_topic}的兴趣,以下是相关的GitHub仓库。
可供推荐的GitHub仓库:
{self._format_repos(enhanced_repos)}
请提供:
1. 按功能、用途或成熟度对仓库进行分组
2. 对每个仓库:
- 简要描述其主要功能和用途
- 分析其技术特点和优势
- 说明其适用场景和使用难度
- 指出其与同类产品相比的独特优势
- 解释其星标数量和活跃度代表的意义
3. 使用建议:
- 新手最适合入门的仓库
- 生产环境中最稳定可靠的选择
- 最新技术栈或创新方案的代表
- 学习特定技术的最佳资源
4. 相关资源:
- 学习这些项目需要的前置知识
- 项目间的关联和技术栈兼容性
- 可能的使用组合方案
重要提示:
- 重点解释为什么每个仓库值得关注
- 突出项目间的关联性和差异性
- 考虑用户不同水平的需求(初学者vs专业人士)
- 在介绍项目时,使用<a href='链接' target='_blank'>文本</a>格式,确保链接在新窗口打开
- 根据仓库的活跃度、更新频率、维护状态提供使用建议
- 仅基于提供的信息,不要做无根据的猜测
- 在信息缺失或不明确时,坦诚说明
使用markdown格式提供清晰的分节回复。
"""
return final_prompt
def _build_repo_detail_prompt(self, main_repo: Dict, similar_repos: List[Dict], current_time: str) -> str:
"""构建仓库详情prompt"""
# 提取README摘要
readme_content = "未提供"
if main_repo.get('readme_excerpt'):
readme_content = main_repo.get('readme_excerpt')
# 构建语言分布
languages = main_repo.get('languages_detail', {})
lang_distribution = []
if languages:
total = sum(languages.values())
for lang, bytes_val in languages.items():
percentage = (bytes_val / total) * 100
lang_distribution.append(f"{lang}: {percentage:.1f}%")
lang_str = "未知"
if lang_distribution:
lang_str = ", ".join(lang_distribution)
# 构建最终prompt
prompt = f"""当前时间: {current_time}
## 主要仓库信息
### {main_repo.get('full_name')}
- **描述**: {main_repo.get('description', '未提供')}
- **星标数**: {main_repo.get('stargazers_count', 0)}
- **Fork数**: {main_repo.get('forks_count', 0)}
- **Watch数**: {main_repo.get('watchers_count', 0)}
- **Issues数**: {main_repo.get('open_issues_count', 0)}
- **语言分布**: {lang_str}
- **许可证**: {main_repo.get('license', {}).get('name', '未指定') if main_repo.get('license') is not None else '未指定'}
- **创建时间**: {main_repo.get('created_at', '')[:10]}
- **最近更新**: {main_repo.get('updated_at', '')[:10]}
- **主题标签**: {', '.join(main_repo.get('topics', ['']))}
- **GitHub链接**: <a href='{main_repo.get('html_url')}' target='_blank'>链接</a>
### README摘要:
{readme_content}
## 类似仓库:
{self._format_repos(similar_repos)}
请提供以下内容:
1. **项目概述**
- 详细解释{main_repo.get('name', '')}项目的主要功能和用途
- 分析其技术特点、架构和实现原理
- 讨论其在所属领域的地位和影响力
- 评估项目成熟度和稳定性
2. **优势与特点**
- 与同类项目相比的独特优势
- 显著的技术创新或设计模式
- 值得学习或借鉴的代码实践
3. **使用场景**
- 最适合的应用场景
- 潜在的使用限制和注意事项
- 入门门槛和学习曲线评估
- 产品级应用的可行性分析
4. **资源与生态**
- 相关学习资源推荐
- 配套工具和库的建议
- 社区支持和活跃度评估
5. **类似项目对比**
- 与列出的类似项目的详细对比
- 不同场景下的最佳选择建议
- 潜在的互补使用方案
提示:所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开。
请以专业、客观的技术分析角度回答,使用markdown格式提供结构化信息。
"""
return prompt

查看文件

@@ -0,0 +1,217 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from ..query_analyzer import SearchCriteria
import asyncio
class TopicHandler(BaseHandler):
"""主题搜索处理器"""
def __init__(self, github, llm_kwargs=None):
super().__init__(github, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理主题搜索请求,返回最终的prompt"""
search_params = self._get_search_params(plugin_kwargs)
# 搜索主题
topics = await self._search_bilingual_topics(
english_query=criteria.github_params["query"],
chinese_query=criteria.github_params["chinese_query"],
per_page=search_params['max_repos']
)
if not topics:
# 尝试用主题搜索仓库
search_query = criteria.github_params["query"]
chinese_search_query = criteria.github_params["chinese_query"]
if "topic:" not in search_query:
search_query += " topic:" + criteria.main_topic.replace(" ", "-")
if "topic:" not in chinese_search_query:
chinese_search_query += " topic:" + criteria.main_topic.replace(" ", "-")
repos = await self._search_bilingual_repositories(
english_query=search_query,
chinese_query=chinese_search_query,
language=criteria.language,
min_stars=criteria.min_stars,
per_page=search_params['max_repos']
)
if not repos:
return self._generate_apology_prompt(criteria)
# 获取仓库详情
enhanced_repos = await self._get_repo_details(repos[:10])
self.ranked_repos = enhanced_repos
if not enhanced_repos:
return self._generate_apology_prompt(criteria)
# 构建基于主题的仓库列表prompt
current_time = self._get_current_time()
final_prompt = f"""当前时间: {current_time}
基于用户对主题"{criteria.main_topic}"的查询,我找到了以下相关GitHub仓库。
主题相关仓库:
{self._format_repos(enhanced_repos)}
请提供:
1. 主题综述:
- "{criteria.main_topic}"主题的概述和重要性
- 该主题在技术领域中的应用和发展趋势
- 主题相关的主要技术栈和知识体系
2. 仓库分析:
- 按功能、技术栈或应用场景对仓库进行分类
- 每个仓库在该主题领域的定位和贡献
- 不同仓库间的技术路线对比
3. 学习路径建议:
- 初学者入门该主题的推荐仓库和学习顺序
- 进阶学习的关键仓库和技术要点
- 实际应用中的最佳实践选择
4. 技术生态分析:
- 该主题下的主流工具和库
- 社区活跃度和维护状况
- 与其他相关技术的集成方案
重要提示:
- 主题"{criteria.main_topic}"是用户查询的核心,请围绕此主题展开分析
- 注重仓库质量评估和使用建议
- 提供基于事实的客观技术分析
- 在介绍仓库时使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
- 考虑不同技术水平用户的需求
使用markdown格式提供清晰的分节回复。
"""
return final_prompt
# 如果找到了主题,则获取主题下的热门仓库
topic_repos = []
for topic in topics[:5]: # 增加到5个主题
topic_name = topic.get('name', '')
if topic_name:
# 搜索该主题下的仓库
repos = await self._search_repositories(
query=f"topic:{topic_name}",
language=criteria.language,
min_stars=criteria.min_stars,
per_page=20 # 每个主题最多20个仓库
)
if repos:
for repo in repos:
repo['topic_source'] = topic_name
topic_repos.append(repo)
if not topic_repos:
return self._generate_apology_prompt(criteria)
# 获取前N个仓库的详情
enhanced_repos = await self._get_repo_details(topic_repos[:search_params['max_details']])
self.ranked_repos = enhanced_repos
if not enhanced_repos:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = f"""当前时间: {current_time}
基于用户对"{criteria.main_topic}"主题的查询,我找到了以下相关GitHub主题和仓库。
主题相关仓库:
{self._format_topic_repos(enhanced_repos)}
请提供:
1. 主题概述:
- 对"{criteria.main_topic}"相关主题的介绍和技术背景
- 这些主题在软件开发中的重要性和应用范围
- 主题间的关联性和技术演进路径
2. 精选仓库分析:
- 每个主题下最具代表性的仓库详解
- 仓库的技术亮点和创新点
- 使用场景和技术成熟度评估
3. 技术趋势分析:
- 基于主题和仓库活跃度的技术发展趋势
- 新兴解决方案和传统方案的对比
- 未来可能的技术方向预测
4. 实践建议:
- 不同应用场景下的最佳仓库选择
- 学习路径和资源推荐
- 实际项目中的应用策略
重要提示:
- 将分析重点放在主题的技术内涵和价值上
- 突出主题间的关联性和技术演进脉络
- 提供基于数据(星标数、更新频率等)的客观分析
- 考虑不同技术背景用户的需求
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
使用markdown格式提供清晰的分节回复。
"""
return final_prompt
def _format_topic_repos(self, repos: List[Dict]) -> str:
"""按主题格式化仓库列表"""
# 按主题分组
topics_dict = {}
for repo in repos:
topic = repo.get('topic_source', '其他')
if topic not in topics_dict:
topics_dict[topic] = []
topics_dict[topic].append(repo)
# 格式化输出
formatted = []
for topic, topic_repos in topics_dict.items():
formatted.append(f"## 主题: {topic}\n")
for i, repo in enumerate(topic_repos, 1):
# 构建仓库URL
repo_url = repo.get('html_url', '')
# 构建引用
reference = (
f"{i}. **{repo.get('full_name', '')}**\n"
f" - 描述: {repo.get('description', 'N/A')}\n"
f" - 语言: {repo.get('language', 'N/A')}\n"
f" - 星标: {repo.get('stargazers_count', 0)}\n"
f" - Fork数: {repo.get('forks_count', 0)}\n"
f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n"
f" - URL: <a href='{repo_url}' target='_blank'>{repo_url}</a>\n"
)
# 添加主题标签(如果有)
if repo.get('topics'):
topics_str = ", ".join(repo.get('topics'))
reference += f" - 主题标签: {topics_str}\n"
# 添加README摘要(如果有)
if repo.get('readme_excerpt'):
# 截断README,只取前200个字符
readme_short = repo.get('readme_excerpt')[:200].replace('\n', ' ')
reference += f" - README摘要: {readme_short}...\n"
formatted.append(reference)
formatted.append("\n") # 主题之间添加空行
return "\n".join(formatted)

查看文件

@@ -0,0 +1,164 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from ..query_analyzer import SearchCriteria
import asyncio
class UserSearchHandler(BaseHandler):
"""用户搜索处理器"""
def __init__(self, github, llm_kwargs=None):
super().__init__(github, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理用户搜索请求,返回最终的prompt"""
search_params = self._get_search_params(plugin_kwargs)
# 搜索用户
users = await self._search_bilingual_users(
english_query=criteria.github_params["query"],
chinese_query=criteria.github_params["chinese_query"],
per_page=search_params['max_repos']
)
if not users:
return self._generate_apology_prompt(criteria)
# 获取用户详情和仓库
enhanced_users = await self._get_user_details(users[:search_params['max_details']])
self.ranked_repos = [] # 添加用户top仓库进行展示
for user in enhanced_users:
if user.get('top_repos'):
self.ranked_repos.extend(user.get('top_repos'))
if not enhanced_users:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = f"""当前时间: {current_time}
基于用户对{criteria.main_topic}的查询,我找到了以下GitHub用户。
GitHub用户搜索结果:
{self._format_users(enhanced_users)}
请提供:
1. 用户综合分析:
- 各开发者的专业领域和技术专长
- 他们在GitHub开源社区的影响力
- 技术实力和项目质量评估
2. 对每位开发者:
- 其主要贡献领域和技术栈
- 代表性项目及其价值
- 编程风格和技术特点
- 在相关领域的影响力
3. 项目推荐:
- 针对用户查询的最有价值项目
- 值得学习和借鉴的代码实践
- 不同用户项目的相互补充关系
4. 如何学习和使用:
- 如何从这些开发者项目中学习
- 最适合入门学习的项目
- 进阶学习的路径建议
重要提示:
- 关注开发者的技术专长和核心贡献
- 分析其开源项目的技术价值
- 根据用户的原始查询提供相关建议
- 避免过度赞美或主观评价
- 基于事实数据(项目数、星标数等)进行客观分析
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
使用markdown格式提供清晰的分节回复。
"""
return final_prompt
async def _get_user_details(self, users: List[Dict]) -> List[Dict]:
"""获取用户详情和仓库"""
enhanced_users = []
for user in users:
try:
username = user.get('login')
if username:
# 获取用户详情
user_details = await self.github.get_user(username)
if user_details:
user.update(user_details)
# 获取用户仓库
repos = await self.github.get_user_repos(
username,
sort="stars",
per_page=10 # 增加到10个仓库
)
if repos:
user['top_repos'] = repos
enhanced_users.append(user)
except Exception as e:
print(f"获取用户 {user.get('login')} 详情时出错: {str(e)}")
enhanced_users.append(user) # 添加原始信息
return enhanced_users
def _format_users(self, users: List[Dict]) -> str:
"""格式化用户列表"""
formatted = []
for i, user in enumerate(users, 1):
# 构建用户信息
username = user.get('login', 'N/A')
name = user.get('name', username)
profile_url = user.get('html_url', '')
bio = user.get('bio', '无简介')
followers = user.get('followers', 0)
public_repos = user.get('public_repos', 0)
company = user.get('company', '未指定')
location = user.get('location', '未指定')
blog = user.get('blog', '')
user_info = (
f"### {i}. {name} (@{username})\n\n"
f"- **简介**: {bio}\n"
f"- **关注者**: {followers} | **公开仓库**: {public_repos}\n"
f"- **公司**: {company} | **地点**: {location}\n"
f"- **个人网站**: {blog}\n"
f"- **GitHub**: <a href='{profile_url}' target='_blank'>{username}</a>\n\n"
)
# 添加用户的热门仓库
top_repos = user.get('top_repos', [])
if top_repos:
user_info += "**热门仓库**:\n\n"
for repo in top_repos:
repo_name = repo.get('name', '')
repo_url = repo.get('html_url', '')
repo_desc = repo.get('description', '无描述')
repo_stars = repo.get('stargazers_count', 0)
repo_language = repo.get('language', '未指定')
user_info += (
f"- <a href='{repo_url}' target='_blank'>{repo_name}</a> - ⭐ {repo_stars}, {repo_language}\n"
f" {repo_desc}\n\n"
)
formatted.append(user_info)
return "\n".join(formatted)

查看文件

@@ -0,0 +1,356 @@
from typing import Dict, List
from dataclasses import dataclass
import re
@dataclass
class SearchCriteria:
"""搜索条件"""
query_type: str # 查询类型: repo/code/user/topic
main_topic: str # 主题
sub_topics: List[str] # 子主题列表
language: str # 编程语言
min_stars: int # 最少星标数
github_params: Dict # GitHub搜索参数
original_query: str = "" # 原始查询字符串
repo_id: str = "" # 特定仓库ID或名称
class QueryAnalyzer:
"""查询分析器"""
# 响应索引常量
BASIC_QUERY_INDEX = 0
GITHUB_QUERY_INDEX = 1
def __init__(self):
self.valid_types = {
"repo": ["repository", "project", "library", "framework", "tool"],
"code": ["code", "snippet", "implementation", "function", "class", "algorithm"],
"user": ["user", "developer", "organization", "contributor", "maintainer"],
"topic": ["topic", "category", "tag", "field", "area", "domain"]
}
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
"""分析查询意图"""
from crazy_functions.crazy_utils import \
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
# 1. 基本查询分析
type_prompt = f"""请分析这个与GitHub相关的查询,并严格按照以下XML格式回答
查询: {query}
说明:
1. 你的回答必须使用下面显示的XML标签,不要有任何标签外的文本
2. 从以下选项中选择查询类型: repo/code/user/topic
- repo: 用于查找仓库、项目、框架或库
- code: 用于查找代码片段、函数实现或算法
- user: 用于查找用户、开发者或组织
- topic: 用于查找主题、类别或领域相关项目
3. 识别主题和子主题
4. 识别首选编程语言(如果有)
5. 确定最低星标数(如果适用)
必需格式:
<query_type>此处回答</query_type>
<main_topic>此处回答</main_topic>
<sub_topics>子主题1, 子主题2, ...</sub_topics>
<language>此处回答</language>
<min_stars>此处回答</min_stars>
示例回答:
1. 仓库查询:
查询: "查找有至少1000颗星的Python web框架"
<query_type>repo</query_type>
<main_topic>web框架</main_topic>
<sub_topics>后端开发, HTTP服务器, ORM</sub_topics>
<language>Python</language>
<min_stars>1000</min_stars>
2. 代码查询:
查询: "如何用JavaScript实现防抖函数"
<query_type>code</query_type>
<main_topic>防抖函数</main_topic>
<sub_topics>事件处理, 性能优化, 函数节流</sub_topics>
<language>JavaScript</language>
<min_stars>0</min_stars>"""
# 2. 生成英文搜索条件
github_prompt = f"""Optimize the following GitHub search query:
Query: {query}
Task: Convert the natural language query into an optimized GitHub search query.
Please use English, regardless of the language of the input query.
Available search fields and filters:
1. Basic fields:
- in:name - Search in repository names
- in:description - Search in repository descriptions
- in:readme - Search in README files
- in:topic - Search in topics
- language:X - Filter by programming language
- user:X - Repositories from a specific user
- org:X - Repositories from a specific organization
2. Code search fields:
- extension:X - Filter by file extension
- path:X - Filter by path
- filename:X - Filter by filename
3. Metric filters:
- stars:>X - Has more than X stars
- forks:>X - Has more than X forks
- size:>X - Size greater than X KB
- created:>YYYY-MM-DD - Created after a specific date
- pushed:>YYYY-MM-DD - Updated after a specific date
4. Other filters:
- is:public/private - Public or private repositories
- archived:true/false - Archived or not archived
- license:X - Specific license
- topic:X - Contains specific topic tag
Examples:
1. Query: "Find Python machine learning libraries with at least 1000 stars"
<query>machine learning in:description language:python stars:>1000</query>
2. Query: "Recently updated React UI component libraries"
<query>UI components library in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
3. Query: "Open source projects developed by Facebook"
<query>org:facebook is:public</query>
4. Query: "Depth-first search implementation in JavaScript"
<query>depth first search in:file language:javascript</query>
Please analyze the query and answer using only the XML tag:
<query>Provide the optimized GitHub search query, using appropriate fields and operators</query>"""
# 3. 生成中文搜索条件
chinese_github_prompt = f"""优化以下GitHub搜索查询:
查询: {query}
任务: 将自然语言查询转换为优化的GitHub搜索查询语句。
为了搜索中文内容,请提取原始查询的关键词并使用中文形式,同时保留GitHub特定的搜索语法为英文。
可用的搜索字段和过滤器:
1. 基本字段:
- in:name - 在仓库名称中搜索
- in:description - 在仓库描述中搜索
- in:readme - 在README文件中搜索
- in:topic - 在主题中搜索
- language:X - 按编程语言筛选
- user:X - 特定用户的仓库
- org:X - 特定组织的仓库
2. 代码搜索字段:
- extension:X - 按文件扩展名筛选
- path:X - 按路径筛选
- filename:X - 按文件名筛选
3. 指标过滤器:
- stars:>X - 有超过X颗星
- forks:>X - 有超过X个分支
- size:>X - 大小超过X KB
- created:>YYYY-MM-DD - 在特定日期后创建
- pushed:>YYYY-MM-DD - 在特定日期后更新
4. 其他过滤器:
- is:public/private - 公开或私有仓库
- archived:true/false - 已归档或未归档
- license:X - 特定许可证
- topic:X - 含特定主题标签
示例:
1. 查询: "找有关机器学习的Python库,至少1000颗星"
<query>机器学习 in:description language:python stars:>1000</query>
2. 查询: "最近更新的React UI组件库"
<query>UI 组件库 in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
3. 查询: "微信小程序开发框架"
<query>微信小程序 开发框架 in:name in:description in:readme</query>
请分析查询并仅使用XML标签回答:
<query>提供优化的GitHub搜索查询,使用适当的字段和运算符,保留中文关键词</query>"""
try:
# 构建提示数组
prompts = [
type_prompt,
github_prompt,
chinese_github_prompt,
]
show_messages = [
"分析查询类型...",
"优化英文GitHub搜索参数...",
"优化中文GitHub搜索参数...",
]
sys_prompts = [
"你是一个精通GitHub生态系统的专家,擅长分析与GitHub相关的查询。",
"You are a GitHub search expert, specialized in converting natural language queries into optimized GitHub search queries in English.",
"你是一个GitHub搜索专家,擅长处理查询并保留中文关键词进行搜索。",
]
# 使用同步方式调用LLM
responses = yield from request_gpt(
inputs_array=prompts,
inputs_show_user_array=show_messages,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history_array=[[] for _ in prompts],
sys_prompt_array=sys_prompts,
max_workers=3
)
# 从收集的响应中提取我们需要的内容
extracted_responses = []
for i in range(len(prompts)):
if (i * 2 + 1) < len(responses):
response = responses[i * 2 + 1]
if response is None:
raise Exception(f"Response {i} is None")
if not isinstance(response, str):
try:
response = str(response)
except:
raise Exception(f"Cannot convert response {i} to string")
extracted_responses.append(response)
else:
raise Exception(f"未收到第 {i + 1} 个响应")
# 解析基本信息
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
if not query_type:
print(
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
raise Exception("无法提取query_type标签内容")
query_type = query_type.lower()
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
if not main_topic:
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
main_topic = query
query_type = self._normalize_query_type(query_type, query)
# 提取子主题
sub_topics = []
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
if sub_topics_text:
sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
# 提取语言
language = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "language")
# 提取最低星标数
min_stars = 0
min_stars_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "min_stars")
if min_stars_text and min_stars_text.isdigit():
min_stars = int(min_stars_text)
# 解析GitHub搜索参数 - 英文
english_github_query = self._extract_tag(extracted_responses[self.GITHUB_QUERY_INDEX], "query")
# 解析GitHub搜索参数 - 中文
chinese_github_query = self._extract_tag(extracted_responses[2], "query")
# 构建GitHub参数
github_params = {
"query": english_github_query,
"chinese_query": chinese_github_query,
"sort": "stars", # 默认按星标排序
"order": "desc", # 默认降序
"per_page": 30, # 默认每页30条
"page": 1 # 默认第1页
}
# 检查是否为特定仓库查询
repo_id = ""
if "repo:" in english_github_query or "repository:" in english_github_query:
repo_match = re.search(r'(repo|repository):([a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+)', english_github_query)
if repo_match:
repo_id = repo_match.group(2)
print(f"Debug - 提取的信息:")
print(f"查询类型: {query_type}")
print(f"主题: {main_topic}")
print(f"子主题: {sub_topics}")
print(f"语言: {language}")
print(f"最低星标数: {min_stars}")
print(f"英文GitHub参数: {english_github_query}")
print(f"中文GitHub参数: {chinese_github_query}")
print(f"特定仓库: {repo_id}")
# 更新返回的 SearchCriteria,包含中英文查询
return SearchCriteria(
query_type=query_type,
main_topic=main_topic,
sub_topics=sub_topics,
language=language,
min_stars=min_stars,
github_params=github_params,
original_query=query,
repo_id=repo_id
)
except Exception as e:
raise Exception(f"分析查询失败: {str(e)}")
def _normalize_query_type(self, query_type: str, query: str) -> str:
"""规范化查询类型"""
if query_type in ["repo", "code", "user", "topic"]:
return query_type
query_lower = query.lower()
for type_name, keywords in self.valid_types.items():
for keyword in keywords:
if keyword in query_lower:
return type_name
query_type_lower = query_type.lower()
for type_name, keywords in self.valid_types.items():
for keyword in keywords:
if keyword in query_type_lower:
return type_name
return "repo" # 默认返回repo类型
def _extract_tag(self, text: str, tag: str) -> str:
"""提取标记内容"""
if not text:
return ""
# 标准XML格式处理多行和特殊字符
pattern = f"<{tag}>(.*?)</{tag}>"
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
content = match.group(1).strip()
if content:
return content
# 备用模式
patterns = [
rf"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
rf"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
rf"[{tag}]([\s\S]*?)[/{tag}]", # 方括号格式
rf"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
rf"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
]
# 尝试所有模式
for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
if match:
content = match.group(1).strip()
if content: # 确保提取的内容不为空
return content
# 如果所有模式都失败,返回空字符串
return ""

查看文件

@@ -0,0 +1,701 @@
import aiohttp
import asyncio
import base64
import json
import random
from datetime import datetime
from typing import List, Dict, Optional, Union, Any
class GitHubSource:
"""GitHub API实现"""
# 默认API密钥列表 - 可以放置多个GitHub令牌
API_KEYS = [
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
# "your_github_token_1",
# "your_github_token_2",
# "your_github_token_3"
]
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
"""初始化GitHub API客户端
Args:
api_key: GitHub个人访问令牌或令牌列表
"""
if api_key is None:
self.api_keys = self.API_KEYS
elif isinstance(api_key, str):
self.api_keys = [api_key]
else:
self.api_keys = api_key
self._initialize()
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.github.com"
self.headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
"User-Agent": "GitHub-API-Python-Client"
}
# 如果有可用的API密钥,随机选择一个
if self.api_keys:
selected_key = random.choice(self.api_keys)
self.headers["Authorization"] = f"Bearer {selected_key}"
print(f"已随机选择API密钥进行认证")
else:
print("警告: 未提供API密钥,将受到GitHub API请求限制")
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
"""发送API请求
Args:
method: HTTP方法 (GET, POST, PUT, DELETE等)
endpoint: API端点
params: URL参数
data: 请求体数据
Returns:
解析后的响应JSON
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
url = f"{self.base_url}{endpoint}"
# 为调试目的打印请求信息
print(f"请求: {method} {url}")
if params:
print(f"参数: {params}")
# 发送请求
request_kwargs = {}
if params:
request_kwargs["params"] = params
if data:
request_kwargs["json"] = data
async with session.request(method, url, **request_kwargs) as response:
response_text = await response.text()
# 检查HTTP状态码
if response.status >= 400:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {response_text}")
return None
# 解析JSON响应
try:
return json.loads(response_text)
except json.JSONDecodeError:
print(f"JSON解析错误: {response_text}")
return None
# ===== 用户相关方法 =====
async def get_user(self, username: Optional[str] = None) -> Dict:
"""获取用户信息
Args:
username: 指定用户名,不指定则获取当前授权用户
Returns:
用户信息字典
"""
endpoint = "/user" if username is None else f"/users/{username}"
return await self._request("GET", endpoint)
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户的仓库列表
Args:
username: 指定用户名,不指定则获取当前授权用户
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
params = {
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_user_starred(self, username: Optional[str] = None,
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户星标的仓库
Args:
username: 指定用户名,不指定则获取当前授权用户
per_page: 每页结果数量
page: 页码
Returns:
星标仓库列表
"""
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 仓库相关方法 =====
async def get_repo(self, owner: str, repo: str) -> Dict:
"""获取仓库信息
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
仓库信息
"""
endpoint = f"/repos/{owner}/{repo}"
return await self._request("GET", endpoint)
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的分支列表
Args:
owner: 仓库所有者
repo: 仓库名
per_page: 每页结果数量
page: 页码
Returns:
分支列表
"""
endpoint = f"/repos/{owner}/{repo}/branches"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的提交历史
Args:
owner: 仓库所有者
repo: 仓库名
sha: 特定提交SHA或分支名
path: 文件路径筛选
per_page: 每页结果数量
page: 页码
Returns:
提交列表
"""
endpoint = f"/repos/{owner}/{repo}/commits"
params = {
"per_page": per_page,
"page": page
}
if sha:
params["sha"] = sha
if path:
params["path"] = path
return await self._request("GET", endpoint, params=params)
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
"""获取特定提交的详情
Args:
owner: 仓库所有者
repo: 仓库名
commit_sha: 提交SHA
Returns:
提交详情
"""
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
return await self._request("GET", endpoint)
# ===== 内容相关方法 =====
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
"""获取文件内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 文件路径
ref: 分支名、标签名或提交SHA
Returns:
文件内容信息
"""
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
response = await self._request("GET", endpoint, params=params)
if response and isinstance(response, dict) and "content" in response:
try:
# 解码Base64编码的文件内容
content = base64.b64decode(response["content"].encode()).decode()
response["decoded_content"] = content
except Exception as e:
print(f"解码文件内容时出错: {str(e)}")
return response
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
"""获取目录内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 目录路径
ref: 分支名、标签名或提交SHA
Returns:
目录内容列表
"""
# 注意此方法与get_file_content使用相同的端点,但对于目录会返回列表
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
return await self._request("GET", endpoint, params=params)
# ===== Issues相关方法 =====
async def get_issues(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Issues列表
Args:
owner: 仓库所有者
repo: 仓库名
state: Issue状态 (open, closed, all)
sort: 排序方式 (created, updated, comments)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Issues列表
"""
endpoint = f"/repos/{owner}/{repo}/issues"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
"""获取特定Issue的详情
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
Issue详情
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
return await self._request("GET", endpoint)
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
"""获取Issue的评论
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
评论列表
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
return await self._request("GET", endpoint)
# ===== Pull Requests相关方法 =====
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Pull Request列表
Args:
owner: 仓库所有者
repo: 仓库名
state: PR状态 (open, closed, all)
sort: 排序方式 (created, updated, popularity, long-running)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Pull Request列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
"""获取特定Pull Request的详情
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
Pull Request详情
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
return await self._request("GET", endpoint)
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
"""获取Pull Request中修改的文件
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
修改文件列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
return await self._request("GET", endpoint)
# ===== 搜索相关方法 =====
async def search_repositories(self, query: str, sort: str = "stars",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索仓库
Args:
query: 搜索关键词
sort: 排序方式 (stars, forks, updated)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/repositories"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_code(self, query: str, sort: str = "indexed",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索代码
Args:
query: 搜索关键词
sort: 排序方式 (indexed)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/code"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_issues(self, query: str, sort: str = "created",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索Issues和Pull Requests
Args:
query: 搜索关键词
sort: 排序方式 (created, updated, comments)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/issues"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_users(self, query: str, sort: str = "followers",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索用户
Args:
query: 搜索关键词
sort: 排序方式 (followers, repositories, joined)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/users"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 组织相关方法 =====
async def get_organization(self, org: str) -> Dict:
"""获取组织信息
Args:
org: 组织名称
Returns:
组织信息
"""
endpoint = f"/orgs/{org}"
return await self._request("GET", endpoint)
async def get_organization_repos(self, org: str, type: str = "all",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织的仓库列表
Args:
org: 组织名称
type: 仓库类型 (all, public, private, forks, sources, member, internal)
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = f"/orgs/{org}/repos"
params = {
"type": type,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织成员列表
Args:
org: 组织名称
per_page: 每页结果数量
page: 页码
Returns:
成员列表
"""
endpoint = f"/orgs/{org}/members"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 更复杂的操作 =====
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
"""获取仓库使用的编程语言及其比例
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
语言使用情况
"""
endpoint = f"/repos/{owner}/{repo}/languages"
return await self._request("GET", endpoint)
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的贡献者统计
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
贡献者统计信息
"""
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
return await self._request("GET", endpoint)
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的提交活动
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
提交活动统计
"""
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
return await self._request("GET", endpoint)
async def example_usage():
"""GitHubSource使用示例"""
# 创建客户端实例可选传入API令牌
# github = GitHubSource(api_key="your_github_token")
github = GitHubSource()
try:
# 示例1搜索热门Python仓库
print("\n=== 示例1搜索热门Python仓库 ===")
repos = await github.search_repositories(
query="language:python stars:>1000",
sort="stars",
order="desc",
per_page=5
)
if repos and "items" in repos:
for i, repo in enumerate(repos["items"], 1):
print(f"\n--- 仓库 {i} ---")
print(f"名称: {repo['full_name']}")
print(f"描述: {repo['description']}")
print(f"星标数: {repo['stargazers_count']}")
print(f"Fork数: {repo['forks_count']}")
print(f"最近更新: {repo['updated_at']}")
print(f"URL: {repo['html_url']}")
# 示例2获取特定仓库的详情
print("\n=== 示例2获取特定仓库的详情 ===")
repo_details = await github.get_repo("microsoft", "vscode")
if repo_details:
print(f"名称: {repo_details['full_name']}")
print(f"描述: {repo_details['description']}")
print(f"星标数: {repo_details['stargazers_count']}")
print(f"Fork数: {repo_details['forks_count']}")
print(f"默认分支: {repo_details['default_branch']}")
print(f"开源许可: {repo_details.get('license', {}).get('name', '')}")
print(f"语言: {repo_details['language']}")
print(f"Open Issues数: {repo_details['open_issues_count']}")
# 示例3获取仓库的提交历史
print("\n=== 示例3获取仓库的最近提交 ===")
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
if commits:
for i, commit in enumerate(commits, 1):
print(f"\n--- 提交 {i} ---")
print(f"SHA: {commit['sha'][:7]}")
print(f"作者: {commit['commit']['author']['name']}")
print(f"日期: {commit['commit']['author']['date']}")
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
# 示例4搜索代码
print("\n=== 示例4搜索代码 ===")
code_results = await github.search_code(
query="filename:README.md language:markdown pytorch in:file",
per_page=3
)
if code_results and "items" in code_results:
print(f"共找到: {code_results['total_count']} 个结果")
for i, item in enumerate(code_results["items"], 1):
print(f"\n--- 代码 {i} ---")
print(f"仓库: {item['repository']['full_name']}")
print(f"文件: {item['path']}")
print(f"URL: {item['html_url']}")
# 示例5获取文件内容
print("\n=== 示例5获取文件内容 ===")
file_content = await github.get_file_content("python", "cpython", "README.rst")
if file_content and "decoded_content" in file_content:
content = file_content["decoded_content"]
print(f"文件名: {file_content['name']}")
print(f"大小: {file_content['size']} 字节")
print(f"内容预览: {content[:200]}...")
# 示例6获取仓库使用的编程语言
print("\n=== 示例6获取仓库使用的编程语言 ===")
languages = await github.get_repository_languages("facebook", "react")
if languages:
print(f"React仓库使用的编程语言:")
for lang, bytes_of_code in languages.items():
print(f"- {lang}: {bytes_of_code} 字节")
# 示例7获取组织信息
print("\n=== 示例7获取组织信息 ===")
org_info = await github.get_organization("google")
if org_info:
print(f"名称: {org_info['name']}")
print(f"描述: {org_info.get('description', '')}")
print(f"位置: {org_info.get('location', '未指定')}")
print(f"公共仓库数: {org_info['public_repos']}")
print(f"成员数: {org_info.get('public_members', 0)}")
print(f"URL: {org_info['html_url']}")
# 示例8获取用户信息
print("\n=== 示例8获取用户信息 ===")
user_info = await github.get_user("torvalds")
if user_info:
print(f"名称: {user_info['name']}")
print(f"公司: {user_info.get('company', '')}")
print(f"博客: {user_info.get('blog', '')}")
print(f"位置: {user_info.get('location', '未指定')}")
print(f"公共仓库数: {user_info['public_repos']}")
print(f"关注者数: {user_info['followers']}")
print(f"URL: {user_info['html_url']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,593 @@
from typing import List, Dict, Optional, Tuple, Union, Any
from dataclasses import dataclass, field
import os
import re
import logging
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import (
PaperStructureExtractor, PaperSection, StructuredPaper
)
from unstructured.partition.auto import partition
from unstructured.documents.elements import (
Text, Title, NarrativeText, ListItem, Table,
Footer, Header, PageBreak, Image, Address
)
@dataclass
class DocumentSection:
"""通用文档章节数据类"""
title: str # 章节标题,如果没有标题则为空字符串
content: str # 章节内容
level: int = 0 # 标题级别,0为主标题,1为一级标题,以此类推
section_type: str = "content" # 章节类型
is_heading_only: bool = False # 是否仅包含标题
subsections: List['DocumentSection'] = field(default_factory=list) # 子章节列表
@dataclass
class StructuredDocument:
"""结构化文档数据类"""
title: str = "" # 文档标题
metadata: Dict[str, Any] = field(default_factory=dict) # 元数据
sections: List[DocumentSection] = field(default_factory=list) # 章节列表
full_text: str = "" # 完整文本
is_paper: bool = False # 是否为学术论文
class GenericDocumentStructureExtractor:
"""通用文档结构提取器
可以从各种文档格式中提取结构信息,包括标题和内容。
支持论文、报告、文章和一般文本文档。
"""
# 支持的文件扩展名
SUPPORTED_EXTENSIONS = [
'.pdf', '.docx', '.doc', '.pptx', '.ppt',
'.txt', '.md', '.html', '.htm', '.xml',
'.rtf', '.odt', '.epub', '.msg', '.eml'
]
# 常见的标题前缀模式
HEADING_PATTERNS = [
# 数字标题 (1., 1.1., etc.)
r'^\s*(\d+\.)+\s+',
# 中文数字标题 (一、, 二、, etc.)
r'^\s*[一二三四五六七八九十]+[、::]\s+',
# 带括号的数字标题 ((1), (2), etc.)
r'^\s*\(\s*\d+\s*\)\s+',
# 特定标记的标题 (Chapter 1, Section 1, etc.)
r'^\s*(chapter|section|part|附录|章|节)\s+\d+[\.:]\s+',
]
# 常见的文档分段标记词
SECTION_MARKERS = {
'introduction': ['简介', '导言', '引言', 'introduction', '概述', 'overview'],
'background': ['背景', '现状', 'background', '理论基础', '相关工作'],
'main_content': ['主要内容', '正文', 'main content', '分析', '讨论'],
'conclusion': ['结论', '总结', 'conclusion', '结语', '小结', 'summary'],
'reference': ['参考', '参考文献', 'references', '文献', 'bibliography'],
'appendix': ['附录', 'appendix', '补充资料', 'supplementary']
}
def __init__(self):
"""初始化提取器"""
self.paper_extractor = PaperStructureExtractor() # 论文专用提取器
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def extract_document_structure(self, file_path: str, strategy: str = "fast") -> StructuredDocument:
"""提取文档结构
Args:
file_path: 文件路径
strategy: 提取策略 ("fast""accurate")
Returns:
StructuredDocument: 结构化文档对象
"""
try:
self.logger.info(f"正在处理文档结构: {file_path}")
# 1. 首先尝试使用论文提取器
try:
paper_result = self.paper_extractor.extract_paper_structure(file_path)
if paper_result and len(paper_result.sections) > 2: # 如果成功识别为论文结构
self.logger.info(f"成功识别为学术论文: {file_path}")
# 将论文结构转换为通用文档结构
return self._convert_paper_to_document(paper_result)
except Exception as e:
self.logger.debug(f"论文结构提取失败,将尝试通用提取: {str(e)}")
# 2. 使用通用方法提取文档结构
elements = partition(
str(file_path),
strategy=strategy,
include_metadata=True,
nlp=False
)
# 3. 使用通用提取器处理
doc = self._extract_generic_structure(elements)
return doc
except Exception as e:
self.logger.error(f"文档结构提取失败: {str(e)}")
# 返回一个空的结构化文档
return StructuredDocument(
title="未能提取文档标题",
sections=[DocumentSection(
title="",
content="",
level=0,
section_type="content"
)]
)
def _convert_paper_to_document(self, paper: StructuredPaper) -> StructuredDocument:
"""将论文结构转换为通用文档结构
Args:
paper: 结构化论文对象
Returns:
StructuredDocument: 转换后的通用文档结构
"""
doc = StructuredDocument(
title=paper.metadata.title,
is_paper=True,
full_text=paper.full_text
)
# 转换元数据
doc.metadata = {
'title': paper.metadata.title,
'authors': paper.metadata.authors,
'keywords': paper.keywords,
'abstract': paper.metadata.abstract if hasattr(paper.metadata, 'abstract') else "",
'is_paper': True
}
# 转换章节结构
doc.sections = self._convert_paper_sections(paper.sections)
return doc
def _convert_paper_sections(self, paper_sections: List[PaperSection], level: int = 0) -> List[DocumentSection]:
"""递归转换论文章节为通用文档章节
Args:
paper_sections: 论文章节列表
level: 当前章节级别
Returns:
List[DocumentSection]: 通用文档章节列表
"""
doc_sections = []
for section in paper_sections:
doc_section = DocumentSection(
title=section.title,
content=section.content,
level=section.level,
section_type=section.section_type,
is_heading_only=False if section.content else True
)
# 递归处理子章节
if section.subsections:
doc_section.subsections = self._convert_paper_sections(
section.subsections, level + 1
)
doc_sections.append(doc_section)
return doc_sections
def _extract_generic_structure(self, elements) -> StructuredDocument:
"""从元素列表中提取通用文档结构
Args:
elements: 文档元素列表
Returns:
StructuredDocument: 结构化文档对象
"""
# 创建结构化文档对象
doc = StructuredDocument(full_text="")
# 1. 提取文档标题
title_candidates = []
for i, element in enumerate(elements[:5]): # 只检查前5个元素
if isinstance(element, Title):
title_text = str(element).strip()
title_candidates.append((i, title_text))
if title_candidates:
# 使用第一个标题作为文档标题
doc.title = title_candidates[0][1]
# 2. 识别所有标题元素和内容
title_elements = []
# 2.1 首先识别所有标题
for i, element in enumerate(elements):
is_heading = False
title_text = ""
level = 0
# 检查元素类型
if isinstance(element, Title):
is_heading = True
title_text = str(element).strip()
# 进一步检查是否为真正的标题
if self._is_likely_heading(title_text, element, i, elements):
level = self._estimate_heading_level(title_text, element)
else:
is_heading = False
# 也检查格式像标题的普通文本
elif isinstance(element, (Text, NarrativeText)) and i > 0:
text = str(element).strip()
# 检查是否匹配标题模式
if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
# 检查长度和后续内容以确认是否为标题
if len(text) < 100 and self._has_sufficient_following_content(i, elements):
is_heading = True
title_text = text
level = self._estimate_heading_level(title_text, element)
if is_heading:
section_type = self._identify_section_type(title_text)
title_elements.append((i, title_text, level, section_type))
# 2.2 为每个标题提取内容
sections = []
for i, (index, title_text, level, section_type) in enumerate(title_elements):
# 确定内容范围
content_start = index + 1
content_end = elements[-1] # 默认到文档结束
# 如果有下一个标题,内容到下一个标题开始
if i < len(title_elements) - 1:
content_end = title_elements[i+1][0]
else:
content_end = len(elements)
# 提取内容
content = self._extract_content_between(elements, content_start, content_end)
# 创建章节
section = DocumentSection(
title=title_text,
content=content,
level=level,
section_type=section_type,
is_heading_only=False if content.strip() else True
)
sections.append(section)
# 3. 如果没有识别到任何章节,创建一个默认章节
if not sections:
all_content = self._extract_content_between(elements, 0, len(elements))
# 尝试从内容中提取标题
first_line = all_content.split('\n')[0] if all_content else ""
if first_line and len(first_line) < 100:
doc.title = first_line
all_content = '\n'.join(all_content.split('\n')[1:])
default_section = DocumentSection(
title="",
content=all_content,
level=0,
section_type="content"
)
sections.append(default_section)
# 4. 构建层次结构
doc.sections = self._build_section_hierarchy(sections)
# 5. 提取完整文本
doc.full_text = "\n\n".join([str(element) for element in elements if isinstance(element, (Text, NarrativeText, Title, ListItem))])
return doc
def _build_section_hierarchy(self, sections: List[DocumentSection]) -> List[DocumentSection]:
"""构建章节层次结构
Args:
sections: 章节列表
Returns:
List[DocumentSection]: 具有层次结构的章节列表
"""
if not sections:
return []
# 按层级排序
top_level_sections = []
current_parents = {0: None} # 每个层级的当前父节点
for section in sections:
# 找到当前节点的父节点
parent_level = None
for level in sorted([k for k in current_parents.keys() if k < section.level], reverse=True):
parent_level = level
break
if parent_level is None:
# 顶级章节
top_level_sections.append(section)
else:
# 子章节
parent = current_parents[parent_level]
if parent:
parent.subsections.append(section)
else:
top_level_sections.append(section)
# 更新当前层级的父节点
current_parents[section.level] = section
# 清除所有更深层级的父节点缓存
deeper_levels = [k for k in current_parents.keys() if k > section.level]
for level in deeper_levels:
current_parents.pop(level, None)
return top_level_sections
def _is_likely_heading(self, text: str, element, index: int, elements) -> bool:
"""判断文本是否可能是标题
Args:
text: 文本内容
element: 元素对象
index: 元素索引
elements: 所有元素列表
Returns:
bool: 是否可能是标题
"""
# 1. 检查文本长度 - 标题通常不会太长
if len(text) > 150: # 标题通常不超过150个字符
return False
# 2. 检查是否匹配标题的数字编号模式
if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
return True
# 3. 检查是否包含常见章节标记词
lower_text = text.lower()
for markers in self.SECTION_MARKERS.values():
if any(marker.lower() in lower_text for marker in markers):
return True
# 4. 检查后续内容数量 - 标题后通常有足够多的内容
if not self._has_sufficient_following_content(index, elements, min_chars=100):
# 但如果文本很短且以特定格式开头,仍可能是标题
if len(text) < 50 and (text.endswith(':') or text.endswith('')):
return True
return False
# 5. 检查格式特征
# 标题通常是元素的开头,不在段落中间
if len(text.split('\n')) > 1:
# 多行文本不太可能是标题
return False
# 如果有元数据,检查字体特征(字体大小等)
if hasattr(element, 'metadata') and element.metadata:
try:
font_size = getattr(element.metadata, 'font_size', None)
is_bold = getattr(element.metadata, 'is_bold', False)
# 字体较大或加粗的文本更可能是标题
if font_size and font_size > 12:
return True
if is_bold:
return True
except (AttributeError, TypeError):
pass
# 默认返回True,因为元素已被识别为Title类型
return True
def _estimate_heading_level(self, text: str, element) -> int:
"""估计标题的层级
Args:
text: 标题文本
element: 元素对象
Returns:
int: 标题层级 (0为主标题,1为一级标题, 等等)
"""
# 1. 通过编号模式判断层级
for pattern, level in [
(r'^\s*\d+\.\s+', 1), # 1. 开头 (一级标题)
(r'^\s*\d+\.\d+\.\s+', 2), # 1.1. 开头 (二级标题)
(r'^\s*\d+\.\d+\.\d+\.\s+', 3), # 1.1.1. 开头 (三级标题)
(r'^\s*\d+\.\d+\.\d+\.\d+\.\s+', 4), # 1.1.1.1. 开头 (四级标题)
]:
if re.match(pattern, text):
return level
# 2. 检查是否是常见的主要章节标题
lower_text = text.lower()
main_sections = [
'abstract', 'introduction', 'background', 'methodology',
'results', 'discussion', 'conclusion', 'references'
]
for section in main_sections:
if section in lower_text:
return 1 # 主要章节为一级标题
# 3. 根据文本特征判断
if text.isupper(): # 全大写文本可能是章标题
return 1
# 4. 通过元数据判断层级
if hasattr(element, 'metadata') and element.metadata:
try:
# 根据字体大小判断层级
font_size = getattr(element.metadata, 'font_size', None)
if font_size is not None:
if font_size > 18: # 假设主标题字体最大
return 0
elif font_size > 16:
return 1
elif font_size > 14:
return 2
else:
return 3
except (AttributeError, TypeError):
pass
# 默认为二级标题
return 2
def _identify_section_type(self, title_text: str) -> str:
"""识别章节类型,包括参考文献部分"""
lower_text = title_text.lower()
# 特别检查是否为参考文献部分
references_patterns = [
r'references', r'参考文献', r'bibliography', r'引用文献',
r'literature cited', r'^cited\s+literature', r'^文献$', r'^引用$'
]
for pattern in references_patterns:
if re.search(pattern, lower_text, re.IGNORECASE):
return "references"
# 检查是否匹配其他常见章节类型
for section_type, markers in self.SECTION_MARKERS.items():
if any(marker.lower() in lower_text for marker in markers):
return section_type
# 检查带编号的章节
if re.match(r'^\d+\.', lower_text):
return "content"
# 默认为内容章节
return "content"
def _has_sufficient_following_content(self, index: int, elements, min_chars: int = 150) -> bool:
"""检查元素后是否有足够的内容
Args:
index: 当前元素索引
elements: 所有元素列表
min_chars: 最小字符数要求
Returns:
bool: 是否有足够的内容
"""
total_chars = 0
for i in range(index + 1, min(index + 5, len(elements))):
if isinstance(elements[i], Title):
# 如果紧接着是标题,就停止检查
break
if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)):
total_chars += len(str(elements[i]))
if total_chars >= min_chars:
return True
return total_chars >= min_chars
def _extract_content_between(self, elements, start_index: int, end_index: int) -> str:
"""提取指定范围内的内容文本
Args:
elements: 元素列表
start_index: 开始索引
end_index: 结束索引
Returns:
str: 提取的内容文本
"""
content_parts = []
for i in range(start_index, end_index):
if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)):
content_parts.append(str(elements[i]).strip())
return "\n\n".join([part for part in content_parts if part])
def generate_markdown(self, doc: StructuredDocument) -> str:
"""将结构化文档转换为Markdown格式
Args:
doc: 结构化文档对象
Returns:
str: Markdown格式文本
"""
md_parts = []
# 添加标题
if doc.title:
md_parts.append(f"# {doc.title}\n")
# 添加元数据
if doc.is_paper:
# 作者信息
if 'authors' in doc.metadata and doc.metadata['authors']:
authors_str = ", ".join(doc.metadata['authors'])
md_parts.append(f"**作者:** {authors_str}\n")
# 关键词
if 'keywords' in doc.metadata and doc.metadata['keywords']:
keywords_str = ", ".join(doc.metadata['keywords'])
md_parts.append(f"**关键词:** {keywords_str}\n")
# 摘要
if 'abstract' in doc.metadata and doc.metadata['abstract']:
md_parts.append(f"## 摘要\n\n{doc.metadata['abstract']}\n")
# 添加章节内容
md_parts.append(self._format_sections_markdown(doc.sections))
return "\n".join(md_parts)
def _format_sections_markdown(self, sections: List[DocumentSection], base_level: int = 0) -> str:
"""递归格式化章节为Markdown
Args:
sections: 章节列表
base_level: 基础层级
Returns:
str: Markdown格式文本
"""
md_parts = []
for section in sections:
# 计算标题级别 (确保不超过6级)
header_level = min(section.level + base_level + 1, 6)
# 添加标题和内容
if section.title:
md_parts.append(f"{'#' * header_level} {section.title}\n")
if section.content:
md_parts.append(f"{section.content}\n")
# 递归处理子章节
if section.subsections:
md_parts.append(self._format_sections_markdown(
section.subsections, base_level
))
return "\n".join(md_parts)

查看文件

@@ -0,0 +1,4 @@
from .txt_doc import TxtFormatter
from .markdown_doc import MarkdownFormatter
from .html_doc import HtmlFormatter
from .word_doc import WordFormatter

查看文件

@@ -0,0 +1,300 @@
class HtmlFormatter:
"""HTML格式文档生成器 - 保留原始文档结构"""
def __init__(self, processing_type="文本处理"):
self.processing_type = processing_type
self.css_styles = """
:root {
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--border-color: #e2e8f0;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
line-height: 1.8;
margin: 0;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 2rem;
border-radius: 16px;
box-shadow: var(--card-shadow);
}
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
.container {
animation: fadeIn 0.6s ease-out;
}
.document-title {
color: var(--primary-color);
font-size: 2em;
text-align: center;
margin: 1rem 0 2rem;
padding-bottom: 1rem;
border-bottom: 2px solid var(--primary-color);
}
.document-body {
display: flex;
flex-direction: column;
gap: 1.5rem;
margin: 2rem 0;
}
.document-header {
display: flex;
flex-direction: column;
align-items: center;
margin-bottom: 2rem;
}
.processing-type {
color: var(--secondary-color);
font-size: 1.2em;
margin: 0.5rem 0;
}
.processing-date {
color: var(--text-color);
font-size: 0.9em;
opacity: 0.8;
}
.document-content {
background: white;
padding: 1.5rem;
border-radius: 8px;
border-left: 4px solid var(--primary-color);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
/* 保留文档结构的样式 */
h1, h2, h3, h4, h5, h6 {
color: var(--secondary-color);
margin-top: 1.5em;
margin-bottom: 0.5em;
}
h1 { font-size: 1.8em; }
h2 { font-size: 1.5em; }
h3 { font-size: 1.3em; }
h4 { font-size: 1.1em; }
p {
margin: 0.8em 0;
}
ul, ol {
margin: 1em 0;
padding-left: 2em;
}
li {
margin: 0.5em 0;
}
blockquote {
margin: 1em 0;
padding: 0.5em 1em;
border-left: 4px solid var(--primary-light);
background: rgba(0,0,0,0.02);
}
code {
font-family: monospace;
background: rgba(0,0,0,0.05);
padding: 0.2em 0.4em;
border-radius: 3px;
}
pre {
background: rgba(0,0,0,0.05);
padding: 1em;
border-radius: 5px;
overflow-x: auto;
}
pre code {
background: transparent;
padding: 0;
}
@media (prefers-color-scheme: dark) {
:root {
--background-color: #0f172a;
--text-color: #e2e8f0;
--border-color: #1e293b;
}
.container, .document-content {
background: #1e293b;
}
blockquote {
background: rgba(255,255,255,0.05);
}
code, pre {
background: rgba(255,255,255,0.05);
}
}
"""
def _escape_html(self, text):
"""转义HTML特殊字符"""
import html
return html.escape(text)
def _markdown_to_html(self, text):
"""将Markdown格式转换为HTML格式,保留文档结构"""
try:
import markdown
# 使用Python-Markdown库将markdown转换为HTML,启用更多扩展以支持嵌套列表
return markdown.markdown(text, extensions=['tables', 'fenced_code', 'codehilite', 'nl2br', 'sane_lists', 'smarty', 'extra'])
except ImportError:
# 如果没有markdown库,使用更复杂的替换来处理嵌套列表
import re
# 替换标题
text = re.sub(r'^# (.+)$', r'<h1>\1</h1>', text, flags=re.MULTILINE)
text = re.sub(r'^## (.+)$', r'<h2>\1</h2>', text, flags=re.MULTILINE)
text = re.sub(r'^### (.+)$', r'<h3>\1</h3>', text, flags=re.MULTILINE)
# 预处理列表 - 在列表项之间添加空行以正确分隔
# 处理编号列表
text = re.sub(r'(\n\d+\.\s.+)(\n\d+\.\s)', r'\1\n\2', text)
# 处理项目符号列表
text = re.sub(r'(\n•\s.+)(\n•\s)', r'\1\n\2', text)
text = re.sub(r'(\n\*\s.+)(\n\*\s)', r'\1\n\2', text)
text = re.sub(r'(\n-\s.+)(\n-\s)', r'\1\n\2', text)
# 处理嵌套列表 - 确保正确的缩进和结构
lines = text.split('\n')
in_list = False
list_type = None # 'ol' 或 'ul'
list_html = []
normal_lines = []
i = 0
while i < len(lines):
line = lines[i]
# 匹配编号列表项
numbered_match = re.match(r'^(\d+)\.\s+(.+)$', line)
# 匹配项目符号列表项
bullet_match = re.match(r'^[•\*-]\s+(.+)$', line)
if numbered_match:
if not in_list or list_type != 'ol':
# 开始新的编号列表
if in_list:
# 关闭前一个列表
list_html.append(f'</{list_type}>')
list_html.append('<ol>')
in_list = True
list_type = 'ol'
num, content = numbered_match.groups()
list_html.append(f'<li>{content}</li>')
elif bullet_match:
if not in_list or list_type != 'ul':
# 开始新的项目符号列表
if in_list:
# 关闭前一个列表
list_html.append(f'</{list_type}>')
list_html.append('<ul>')
in_list = True
list_type = 'ul'
content = bullet_match.group(1)
list_html.append(f'<li>{content}</li>')
else:
if in_list:
# 结束当前列表
list_html.append(f'</{list_type}>')
in_list = False
# 将完成的列表添加到正常行中
normal_lines.append(''.join(list_html))
list_html = []
normal_lines.append(line)
i += 1
# 如果最后还在列表中,确保关闭列表
if in_list:
list_html.append(f'</{list_type}>')
normal_lines.append(''.join(list_html))
# 重建文本
text = '\n'.join(normal_lines)
# 替换段落,但避免处理已经是HTML标签的部分
paragraphs = text.split('\n\n')
for i, p in enumerate(paragraphs):
# 如果不是以HTML标签开始且不为空
if not (p.strip().startswith('<') and p.strip().endswith('>')) and p.strip() != '':
paragraphs[i] = f'<p>{p}</p>'
return '\n'.join(paragraphs)
def create_document(self, content: str) -> str:
"""生成完整的HTML文档,保留原始文档结构
Args:
content: 处理后的文档内容
Returns:
str: 完整的HTML文档字符串
"""
from datetime import datetime
# 将markdown内容转换为HTML
html_content = self._markdown_to_html(content)
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>文档处理结果</title>
<style>{self.css_styles}</style>
</head>
<body>
<div class="container">
<h1 class="document-title">文档处理结果</h1>
<div class="document-header">
<div class="processing-type">处理方式: {self._escape_html(self.processing_type)}</div>
<div class="processing-date">处理时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>
</div>
<div class="document-content">
{html_content}
</div>
</div>
</body>
</html>
"""

查看文件

@@ -0,0 +1,40 @@
class MarkdownFormatter:
"""Markdown格式文档生成器 - 保留原始文档结构"""
def __init__(self):
self.content = []
def _add_content(self, text: str):
"""添加正文内容"""
if text:
self.content.append(f"\n{text}\n")
def create_document(self, content: str, processing_type: str = "文本处理") -> str:
"""
创建完整的Markdown文档,保留原始文档结构
Args:
content: 处理后的文档内容
processing_type: 处理类型(润色、翻译等)
Returns:
str: 生成的Markdown文本
"""
self.content = []
# 添加标题和说明
self.content.append(f"# 文档处理结果\n")
self.content.append(f"## 处理方式: {processing_type}\n")
# 添加处理时间
from datetime import datetime
self.content.append(f"*处理时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n")
# 添加分隔线
self.content.append("---\n")
# 添加原始内容,保留结构
self.content.append(content)
# 添加结尾分隔线
self.content.append("\n---\n")
return "\n".join(self.content)

查看文件

@@ -0,0 +1,69 @@
import re
def convert_markdown_to_txt(markdown_text):
"""Convert markdown text to plain text while preserving formatting"""
# Standardize line endings
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
# 1. Handle headers but keep their formatting instead of removing them
markdown_text = re.sub(r'^#\s+(.+)$', r'# \1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'^##\s+(.+)$', r'## \1', markdown_text, flags=re.MULTILINE)
markdown_text = re.sub(r'^###\s+(.+)$', r'### \1', markdown_text, flags=re.MULTILINE)
# 2. Handle bold and italic - simply remove markers
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text)
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text)
# 3. Handle lists but preserve formatting
markdown_text = re.sub(r'^\s*[-*+]\s+(.+?)(?=\n|$)', r'\1', markdown_text, flags=re.MULTILINE)
# 4. Handle links - keep only the text
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1 (\2)', markdown_text)
# 5. Handle HTML links - convert to user-friendly format
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)', markdown_text)
# 6. Preserve paragraph breaks
markdown_text = re.sub(r'\n{3,}', '\n\n', markdown_text) # normalize multiple newlines to double newlines
# 7. Clean up extra spaces but maintain indentation
markdown_text = re.sub(r' +', ' ', markdown_text)
return markdown_text.strip()
class TxtFormatter:
"""文本格式化器 - 保留原始文档结构"""
def __init__(self):
self.content = []
self._setup_document()
def _setup_document(self):
"""初始化文档标题"""
self.content.append("=" * 50)
self.content.append("处理后文档".center(48))
self.content.append("=" * 50)
def _format_header(self):
"""创建文档头部信息"""
from datetime import datetime
date_str = datetime.now().strftime('%Y年%m月%d')
return [
date_str.center(48),
"\n" # 添加空行
]
def create_document(self, content):
"""生成保留原始结构的文档"""
# 添加头部信息
self.content.extend(self._format_header())
# 处理内容,保留原始结构
processed_content = convert_markdown_to_txt(content)
# 添加处理后的内容
self.content.append(processed_content)
# 合并所有内容
return "\n".join(self.content)

查看文件

@@ -0,0 +1,125 @@
from docx2pdf import convert
import os
import platform
from typing import Union
from pathlib import Path
from datetime import datetime
class WordToPdfConverter:
"""Word文档转PDF转换器"""
@staticmethod
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
"""
将Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径
异常:
如果转换失败,将抛出相应异常
"""
try:
# 确保输入路径是Path对象
word_path = Path(word_path)
# 如果未指定pdf_path,则使用与word文档相同的名称
if pdf_path is None:
pdf_path = word_path.with_suffix('.pdf')
else:
pdf_path = Path(pdf_path)
# 检查操作系统
if platform.system() == 'Linux':
# Linux系统需要安装libreoffice
if not os.system('which libreoffice') == 0:
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
# 使用libreoffice进行转换
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
# 如果输出路径与默认生成的不同,则重命名
default_pdf = word_path.with_suffix('.pdf')
if default_pdf != pdf_path:
os.rename(default_pdf, pdf_path)
else:
# Windows和MacOS使用docx2pdf
convert(word_path, pdf_path)
return str(pdf_path)
except Exception as e:
raise Exception(f"转换PDF失败: {str(e)}")
@staticmethod
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
"""
批量转换目录下的所有Word文档
参数:
word_dir: 包含Word文档的目录路径
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
返回:
生成的PDF文件路径列表
"""
word_dir = Path(word_dir)
if pdf_dir:
pdf_dir = Path(pdf_dir)
pdf_dir.mkdir(parents=True, exist_ok=True)
converted_files = []
for word_file in word_dir.glob("*.docx"):
try:
if pdf_dir:
pdf_path = pdf_dir / word_file.with_suffix('.pdf').name
else:
pdf_path = word_file.with_suffix('.pdf')
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
converted_files.append(pdf_file)
except Exception as e:
print(f"转换 {word_file} 失败: {str(e)}")
return converted_files
@staticmethod
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
"""
将docx对象直接转换为PDF
参数:
doc: python-docx的Document对象
output_dir: 可选,输出目录。如果未指定,将使用当前目录
返回:
生成的PDF文件路径
"""
try:
# 设置临时文件路径和输出路径
output_dir = Path(output_dir) if output_dir else Path.cwd()
output_dir.mkdir(parents=True, exist_ok=True)
# 生成临时word文件
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
doc.save(temp_docx)
# 转换为PDF
pdf_path = temp_docx.with_suffix('.pdf')
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
# 删除临时word文件
temp_docx.unlink()
return str(pdf_path)
except Exception as e:
if temp_docx.exists():
temp_docx.unlink()
raise Exception(f"转换PDF失败: {str(e)}")

查看文件

@@ -0,0 +1,236 @@
import re
from docx import Document
from docx.shared import Cm, Pt
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.enum.style import WD_STYLE_TYPE
from docx.oxml.ns import qn
from datetime import datetime
def convert_markdown_to_word(markdown_text):
# 0. 首先标准化所有换行符为\n
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
# 1. 处理标题 - 支持更多级别的标题,使用更精确的正则
# 保留标题标记,以便后续处理时还能识别出标题级别
markdown_text = re.sub(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', r'\1 \2', markdown_text, flags=re.MULTILINE)
# 2. 处理粗体、斜体和加粗斜体
markdown_text = re.sub(r'\*\*\*(.+?)\*\*\*', r'\1', markdown_text) # 加粗斜体
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text) # 加粗
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text) # 斜体
markdown_text = re.sub(r'_(.+?)_', r'\1', markdown_text) # 下划线斜体
markdown_text = re.sub(r'__(.+?)__', r'\1', markdown_text) # 下划线加粗
# 3. 处理代码块 - 不移除,而是简化格式
# 多行代码块
markdown_text = re.sub(r'```(?:\w+)?\n([\s\S]*?)```', r'[代码块]\n\1[/代码块]', markdown_text)
# 单行代码
markdown_text = re.sub(r'`([^`]+)`', r'[代码]\1[/代码]', markdown_text)
# 4. 处理列表 - 保留列表结构
# 匹配无序列表
markdown_text = re.sub(r'^(\s*)[-*+]\s+(.+?)$', r'\1• \2', markdown_text, flags=re.MULTILINE)
# 5. 处理Markdown链接
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+?)\s*(?:"[^"]*")?\)', r'\1 (\2)', markdown_text)
# 6. 处理HTML链接
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)', markdown_text)
# 7. 处理图片
markdown_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'[图片:\1]', markdown_text)
return markdown_text
class WordFormatter:
"""文档Word格式化器 - 保留原始文档结构"""
def __init__(self):
self.doc = Document()
self._setup_document()
self._create_styles()
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("文档处理结果")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(12) # 调整为12磅
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
# 创建标题样式
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
title_style.font.name = '黑体'
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
title_style.font.size = Pt(22) # 调整为22磅
title_style.font.bold = True
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
title_style.paragraph_format.space_before = Pt(0)
title_style.paragraph_format.space_after = Pt(24)
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
# 创建标题1样式
h1_style = self.doc.styles.add_style('Heading1_Custom', WD_STYLE_TYPE.PARAGRAPH)
h1_style.font.name = '黑体'
h1_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
h1_style.font.size = Pt(18)
h1_style.font.bold = True
h1_style.paragraph_format.space_before = Pt(12)
h1_style.paragraph_format.space_after = Pt(6)
# 创建标题2样式
h2_style = self.doc.styles.add_style('Heading2_Custom', WD_STYLE_TYPE.PARAGRAPH)
h2_style.font.name = '黑体'
h2_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
h2_style.font.size = Pt(16)
h2_style.font.bold = True
h2_style.paragraph_format.space_before = Pt(10)
h2_style.paragraph_format.space_after = Pt(6)
# 创建标题3样式
h3_style = self.doc.styles.add_style('Heading3_Custom', WD_STYLE_TYPE.PARAGRAPH)
h3_style.font.name = '黑体'
h3_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
h3_style.font.size = Pt(14)
h3_style.font.bold = True
h3_style.paragraph_format.space_before = Pt(8)
h3_style.paragraph_format.space_after = Pt(4)
# 创建代码块样式
code_style = self.doc.styles.add_style('Code_Custom', WD_STYLE_TYPE.PARAGRAPH)
code_style.font.name = 'Courier New'
code_style.font.size = Pt(11)
code_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
code_style.paragraph_format.space_before = Pt(6)
code_style.paragraph_format.space_after = Pt(6)
code_style.paragraph_format.left_indent = Pt(36)
code_style.paragraph_format.right_indent = Pt(36)
# 创建列表样式
list_style = self.doc.styles.add_style('List_Custom', WD_STYLE_TYPE.PARAGRAPH)
list_style.font.name = '仿宋'
list_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
list_style.font.size = Pt(12)
list_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
list_style.paragraph_format.left_indent = Pt(21)
list_style.paragraph_format.first_line_indent = Pt(-21)
def create_document(self, content: str, processing_type: str = "文本处理"):
"""创建文档,保留原始结构"""
# 添加标题
title_para = self.doc.add_paragraph(style='Title_Custom')
title_run = title_para.add_run('文档处理结果')
# 添加处理类型
processing_para = self.doc.add_paragraph()
processing_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
processing_run = processing_para.add_run(f"处理方式: {processing_type}")
processing_run.font.name = '仿宋'
processing_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
processing_run.font.size = Pt(14)
# 添加日期
date_para = self.doc.add_paragraph()
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_para.add_run(f"处理时间: {datetime.now().strftime('%Y年%m月%d')}")
date_run.font.name = '仿宋'
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
date_run.font.size = Pt(14)
self.doc.add_paragraph() # 添加空行
# 预处理内容,将Markdown格式转换为适合Word的格式
processed_content = convert_markdown_to_word(content)
# 按行处理文本,保留结构
lines = processed_content.split('\n')
in_code_block = False
current_paragraph = None
for line in lines:
# 检查是否为标题
header_match = re.match(r'^(#{1,6})\s+(.+)$', line)
if header_match:
# 根据#的数量确定标题级别
level = len(header_match.group(1))
title_text = header_match.group(2)
if level == 1:
style = 'Heading1_Custom'
elif level == 2:
style = 'Heading2_Custom'
else:
style = 'Heading3_Custom'
self.doc.add_paragraph(title_text, style=style)
current_paragraph = None
# 检查代码块标记
elif '[代码块]' in line:
in_code_block = True
current_paragraph = self.doc.add_paragraph(style='Code_Custom')
code_line = line.replace('[代码块]', '').strip()
if code_line:
current_paragraph.add_run(code_line)
elif '[/代码块]' in line:
in_code_block = False
code_line = line.replace('[/代码块]', '').strip()
if code_line and current_paragraph:
current_paragraph.add_run(code_line)
current_paragraph = None
# 检查列表项
elif line.strip().startswith(''):
p = self.doc.add_paragraph(style='List_Custom')
p.add_run(line.strip())
current_paragraph = None
# 处理普通文本行
elif line.strip():
if in_code_block:
if current_paragraph:
current_paragraph.add_run('\n' + line)
else:
current_paragraph = self.doc.add_paragraph(line, style='Code_Custom')
else:
if current_paragraph is None or not current_paragraph.text:
current_paragraph = self.doc.add_paragraph(line, style='Normal_Custom')
else:
current_paragraph.add_run('\n' + line)
# 处理空行,创建新段落
elif not in_code_block:
current_paragraph = None
return self.doc

查看文件

@@ -0,0 +1,278 @@
from typing import List, Dict, Tuple
import asyncio
from dataclasses import dataclass
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
from crazy_functions.paper_fns.auto_git.query_analyzer import QueryAnalyzer, SearchCriteria
from crazy_functions.paper_fns.auto_git.handlers.repo_handler import RepositoryHandler
from crazy_functions.paper_fns.auto_git.handlers.code_handler import CodeSearchHandler
from crazy_functions.paper_fns.auto_git.handlers.user_handler import UserSearchHandler
from crazy_functions.paper_fns.auto_git.handlers.topic_handler import TopicHandler
from crazy_functions.paper_fns.auto_git.sources.github_source import GitHubSource
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
import re
from datetime import datetime
import os
import json
from pathlib import Path
import time
# 导入格式化器
from crazy_functions.paper_fns.file2file_doc import (
TxtFormatter,
MarkdownFormatter,
HtmlFormatter,
WordFormatter
)
from crazy_functions.paper_fns.file2file_doc.word2pdf import WordToPdfConverter
@CatchException
def GitHub项目智能检索(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""GitHub项目智能检索主函数"""
# 初始化GitHub API调用源
github_source = GitHubSource(api_key=plugin_kwargs.get("github_api_key"))
# 初始化处理器
handlers = {
"repo": RepositoryHandler(github_source, llm_kwargs),
"code": CodeSearchHandler(github_source, llm_kwargs),
"user": UserSearchHandler(github_source, llm_kwargs),
"topic": TopicHandler(github_source, llm_kwargs),
}
# 分析查询意图
chatbot.append(["分析查询意图", "正在分析您的查询需求..."])
yield from update_ui(chatbot=chatbot, history=history)
query_analyzer = QueryAnalyzer()
search_criteria = yield from query_analyzer.analyze_query(
txt, chatbot, llm_kwargs
)
# 根据查询类型选择处理器
handler = handlers.get(search_criteria.query_type)
if not handler:
handler = handlers["repo"] # 默认使用仓库处理器
# 处理查询
chatbot.append(["开始搜索", f"使用{handler.__class__.__name__}处理您的请求,正在搜索GitHub..."])
yield from update_ui(chatbot=chatbot, history=history)
final_prompt = asyncio.run(handler.handle(
criteria=search_criteria,
chatbot=chatbot,
history=history,
system_prompt=system_prompt,
llm_kwargs=llm_kwargs,
plugin_kwargs=plugin_kwargs
))
if final_prompt:
# 检查是否是道歉提示
if "很抱歉,我们未能找到" in final_prompt:
chatbot.append([txt, final_prompt])
yield from update_ui(chatbot=chatbot, history=history)
return
# 在 final_prompt 末尾添加用户原始查询要求
final_prompt += f"""
原始用户查询: "{txt}"
重要提示:
- 你的回答必须直接满足用户的原始查询要求
- 在遵循之前指南的同时,优先回答用户明确提出的问题
- 确保回答格式和内容与用户期望一致
- 对于GitHub仓库需要提供链接地址, 回复中请采用以下格式的HTML链接:
* 对于GitHub仓库: <a href='Github_URL' target='_blank'>仓库名</a>
- 不要生成参考列表,引用信息将另行处理
"""
# 使用最终的prompt生成回答
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=final_prompt,
inputs_show_user=txt,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=[],
sys_prompt=f"你是一个熟悉GitHub生态系统的专业助手,能帮助用户找到合适的项目、代码和开发者。除非用户指定,否则请使用中文回复。"
)
# 1. 获取项目列表
repos_list = handler.ranked_repos # 直接使用原始仓库数据
# 在新的对话中添加格式化的仓库参考列表
if repos_list:
references = ""
for idx, repo in enumerate(repos_list, 1):
# 构建仓库引用
stars_str = f"{repo.get('stargazers_count', 'N/A')}" if repo.get('stargazers_count') else ""
forks_str = f"🍴 {repo.get('forks_count', 'N/A')}" if repo.get('forks_count') else ""
stats = f"{stars_str} {forks_str}".strip()
stats = f" ({stats})" if stats else ""
language = f" [{repo.get('language', '')}]" if repo.get('language') else ""
reference = f"[{idx}] **{repo.get('name', '')}**{language}{stats} \n"
reference += f"👤 {repo.get('owner', {}).get('login', 'N/A') if repo.get('owner') is not None else 'N/A'} | "
reference += f"📅 {repo.get('updated_at', 'N/A')[:10]} | "
reference += f"<a href='{repo.get('html_url', '')}' target='_blank'>GitHub</a> \n"
if repo.get('description'):
reference += f"{repo.get('description')} \n"
reference += " \n"
references += reference
# 添加新的对话显示参考仓库
chatbot.append(["推荐项目如下:", references])
yield from update_ui(chatbot=chatbot, history=history)
# 2. 保存结果到文件
# 创建保存目录
save_dir = get_log_folder(get_user(chatbot), plugin_name='github_search')
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 生成文件名
def get_safe_filename(txt, max_length=10):
# 获取文本前max_length个字符作为文件名
filename = txt[:max_length].strip()
# 移除不安全的文件名字符
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
# 如果文件名为空,使用时间戳
if not filename:
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
return filename
base_filename = get_safe_filename(txt)
# 准备保存的内容 - 优化文档结构
md_content = f"# GitHub搜索结果: {txt}\n\n"
md_content += f"搜索时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
# 添加模型回复
md_content += "## 搜索分析与总结\n\n"
md_content += response + "\n\n"
# 添加所有搜索到的仓库详细信息
md_content += "## 推荐项目详情\n\n"
if not repos_list:
md_content += "未找到匹配的项目\n\n"
else:
md_content += f"共找到 {len(repos_list)} 个相关项目\n\n"
# 添加项目简表
md_content += "### 项目一览表\n\n"
md_content += "| 序号 | 项目名称 | 作者 | 语言 | 星标数 | 更新时间 |\n"
md_content += "| ---- | -------- | ---- | ---- | ------ | -------- |\n"
for idx, repo in enumerate(repos_list, 1):
md_content += f"| {idx} | [{repo.get('name', '')}]({repo.get('html_url', '')}) | {repo.get('owner', {}).get('login', 'N/A') if repo.get('owner') is not None else 'N/A'} | {repo.get('language', 'N/A')} | {repo.get('stargazers_count', 'N/A')} | {repo.get('updated_at', 'N/A')[:10]} |\n"
md_content += "\n"
# 添加详细项目信息
md_content += "### 项目详细信息\n\n"
for idx, repo in enumerate(repos_list, 1):
md_content += f"#### {idx}. {repo.get('name', '')}\n\n"
md_content += f"- **仓库**: [{repo.get('full_name', '')}]({repo.get('html_url', '')})\n"
md_content += f"- **作者**: [{repo.get('owner', {}).get('login', '') if repo.get('owner') is not None else 'N/A'}]({repo.get('owner', {}).get('html_url', '') if repo.get('owner') is not None else '#'})\n"
md_content += f"- **描述**: {repo.get('description', 'N/A')}\n"
md_content += f"- **语言**: {repo.get('language', 'N/A')}\n"
md_content += f"- **星标**: {repo.get('stargazers_count', 'N/A')}\n"
md_content += f"- **Fork数**: {repo.get('forks_count', 'N/A')}\n"
md_content += f"- **最近更新**: {repo.get('updated_at', 'N/A')[:10]}\n"
md_content += f"- **创建时间**: {repo.get('created_at', 'N/A')[:10]}\n"
md_content += f"- **开源许可**: {repo.get('license', {}).get('name', 'N/A') if repo.get('license') is not None else 'N/A'}\n"
if repo.get('topics'):
md_content += f"- **主题标签**: {', '.join(repo.get('topics', []))}\n"
if repo.get('homepage'):
md_content += f"- **项目主页**: [{repo.get('homepage')}]({repo.get('homepage')})\n"
md_content += "\n"
# 添加查询信息和元数据
md_content += "## 查询元数据\n\n"
md_content += f"- **原始查询**: {txt}\n"
md_content += f"- **查询类型**: {search_criteria.query_type}\n"
md_content += f"- **关键词**: {', '.join(search_criteria.keywords) if hasattr(search_criteria, 'keywords') and search_criteria.keywords else 'N/A'}\n"
md_content += f"- **搜索日期**: {datetime.now().strftime('%Y-%m-%d')}\n\n"
# 保存为多种格式
saved_files = []
failed_files = []
# 1. 保存为TXT
try:
txt_formatter = TxtFormatter()
txt_content = txt_formatter.create_document(md_content)
txt_file = os.path.join(save_dir, f"github_results_{base_filename}.txt")
with open(txt_file, 'w', encoding='utf-8') as f:
f.write(txt_content)
promote_file_to_downloadzone(txt_file, chatbot=chatbot)
saved_files.append("TXT")
except Exception as e:
failed_files.append(f"TXT (错误: {str(e)})")
# 2. 保存为Markdown
try:
md_formatter = MarkdownFormatter()
formatted_md_content = md_formatter.create_document(md_content, "GitHub项目搜索")
md_file = os.path.join(save_dir, f"github_results_{base_filename}.md")
with open(md_file, 'w', encoding='utf-8') as f:
f.write(formatted_md_content)
promote_file_to_downloadzone(md_file, chatbot=chatbot)
saved_files.append("Markdown")
except Exception as e:
failed_files.append(f"Markdown (错误: {str(e)})")
# 3. 保存为HTML
try:
html_formatter = HtmlFormatter(processing_type="GitHub项目搜索")
html_content = html_formatter.create_document(md_content)
html_file = os.path.join(save_dir, f"github_results_{base_filename}.html")
with open(html_file, 'w', encoding='utf-8') as f:
f.write(html_content)
promote_file_to_downloadzone(html_file, chatbot=chatbot)
saved_files.append("HTML")
except Exception as e:
failed_files.append(f"HTML (错误: {str(e)})")
# 4. 保存为Word
word_file = None
try:
word_formatter = WordFormatter()
doc = word_formatter.create_document(md_content, "GitHub项目搜索")
word_file = os.path.join(save_dir, f"github_results_{base_filename}.docx")
doc.save(word_file)
promote_file_to_downloadzone(word_file, chatbot=chatbot)
saved_files.append("Word")
except Exception as e:
failed_files.append(f"Word (错误: {str(e)})")
word_file = None
# 5. 保存为PDF (仅当Word保存成功时)
if word_file and os.path.exists(word_file):
try:
pdf_file = WordToPdfConverter.convert_to_pdf(word_file)
promote_file_to_downloadzone(pdf_file, chatbot=chatbot)
saved_files.append("PDF")
except Exception as e:
failed_files.append(f"PDF (错误: {str(e)})")
# 报告保存结果
if saved_files:
success_message = f"成功保存以下格式: {', '.join(saved_files)}"
if failed_files:
failure_message = f"以下格式保存失败: {', '.join(failed_files)}"
chatbot.append(["部分格式保存成功", f"{success_message}{failure_message}"])
else:
chatbot.append(["所有格式保存成功", success_message])
else:
chatbot.append(["保存失败", f"所有格式均保存失败: {', '.join(failed_files)}"])
else:
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -0,0 +1,635 @@
import os
import time
import glob
from typing import Dict, List, Generator, Tuple
from dataclasses import dataclass
from crazy_functions.pdf_fns.text_content_loader import TextContentLoader
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from toolbox import update_ui, promote_file_to_downloadzone, write_history_to_file, CatchException, report_exception
from shared_utils.fastapi_server import validate_path_safety
# 导入论文下载相关函数
from crazy_functions.论文下载 import extract_paper_id, extract_paper_ids, get_arxiv_paper, format_arxiv_id, SciHub
from pathlib import Path
from datetime import datetime, timedelta
import calendar
@dataclass
class RecommendationQuestion:
"""期刊会议推荐分析问题类"""
id: str # 问题ID
question: str # 问题内容
importance: int # 重要性 (1-5,5最高)
description: str # 问题描述
class JournalConferenceRecommender:
"""论文期刊会议推荐器"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化推荐器"""
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.paper_content = ""
self.analysis_results = {}
# 定义论文分析问题库(针对期刊会议推荐)
self.questions = [
RecommendationQuestion(
id="research_field_and_topic",
question="请分析这篇论文的研究领域、主题和关键词。具体包括1)论文属于哪个主要学科领域如自然科学、工程技术、医学、社会科学、人文学科等;2)具体的研究子领域或方向;3)论文的核心主题和关键概念;4)重要的学术关键词和专业术语;5)研究的跨学科特征如果有;6)研究的地域性特征(国际性研究还是特定地区研究)。",
importance=5,
description="研究领域与主题分析"
),
RecommendationQuestion(
id="methodology_and_approach",
question="请分析论文的研究方法和技术路线。包括1)采用的主要研究方法定量研究、定性研究、理论分析、实验研究、田野调查、文献综述、案例研究等;2)使用的技术手段、工具或分析方法;3)研究设计的严谨性和创新性;4)数据收集和分析方法的适当性;5)研究方法在该学科中的先进性或传统性;6)方法学上的贡献或局限性。",
importance=4,
description="研究方法与技术路线"
),
RecommendationQuestion(
id="novelty_and_contribution",
question="请评估论文的创新性和学术贡献。包括1)研究的新颖性程度理论创新、方法创新、应用创新等;2)对现有知识体系的贡献或突破;3)解决问题的重要性和学术价值;4)研究成果的理论意义和实践价值;5)在该学科领域的地位和影响潜力;6)与国际前沿研究的关系;7)对后续研究的启发意义。",
importance=4,
description="创新性与学术贡献"
),
RecommendationQuestion(
id="target_audience_and_scope",
question="请分析论文的目标受众和应用范围。包括1)主要面向的学术群体研究者、从业者、政策制定者等;2)研究成果的潜在应用领域和受益群体;3)对学术界和实践界的价值;4)研究的国际化程度和跨文化适用性;5)是否适合国际期刊还是区域性期刊;6)语言发表偏好英文、中文或其他语言;7)开放获取的必要性和可行性。",
importance=3,
description="目标受众与应用范围"
),
]
# 按重要性排序
self.questions.sort(key=lambda q: q.importance, reverse=True)
def _load_paper(self, paper_path: str) -> Generator:
"""加载论文内容"""
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用TextContentLoader读取文件
loader = TextContentLoader(self.chatbot, self.history)
yield from loader.execute_single_file(paper_path)
# 获取加载的内容
if len(self.history) >= 2 and self.history[-2]:
self.paper_content = self.history[-2]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return True
else:
self.chatbot.append(["错误", "无法读取论文内容,请检查文件是否有效"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return False
def _analyze_question(self, question: RecommendationQuestion) -> Generator:
"""分析单个问题"""
try:
# 创建分析提示
prompt = f"请基于以下论文内容回答问题:\n\n{self.paper_content}\n\n问题:{question.question}"
# 使用单线程版本的请求函数
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt,
inputs_show_user=question.question, # 显示问题本身
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[], # 空历史,确保每个问题独立分析
sys_prompt="你是一个专业的学术期刊会议推荐专家,需要仔细分析论文内容并提供准确的分析。请保持客观、专业,并基于论文内容提供深入分析。"
)
if response:
self.analysis_results[question.id] = response
return True
return False
except Exception as e:
self.chatbot.append(["错误", f"分析问题时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return False
def _generate_journal_recommendations(self) -> Generator:
"""生成期刊推荐"""
self.chatbot.append(["生成期刊推荐", "正在基于论文分析结果生成期刊推荐..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 构建期刊推荐提示
journal_prompt = """请基于以下论文分析结果,为这篇论文推荐合适的学术期刊。
推荐要求:
1. 根据论文的创新性和工作质量,分别推荐不同级别的期刊:
- 顶级期刊(影响因子>8或该领域顶级期刊2-3个
- 高质量期刊影响因子4-8或该领域知名期刊3-4个
- 中等期刊影响因子1.5-4或该领域认可期刊3-4个
- 入门期刊(影响因子<1.5但声誉良好的期刊2-3个
注意:不同学科的影响因子标准差异很大,请根据论文所属学科的实际情况调整标准。
特别是医学领域,需要考虑:
- 临床医学期刊通常影响因子较高顶级期刊IF>20,高质量期刊IF>10
- 基础医学期刊影响因子相对较低但学术价值很高
- 专科医学期刊在各自领域内具有权威性
- 医学期刊的临床实用性和循证医学价值
2. 对每个期刊提供详细信息:
- 期刊全名和缩写
- 最新影响因子(如果知道)
- 期刊级别分类Q1/Q2/Q3/Q4或该学科的分类标准
- 主要研究领域和范围
- 与论文内容的匹配度评分1-10分
- 发表难度评估(容易/中等/困难/极难)
- 平均审稿周期
- 开放获取政策
- 期刊的学科分类如SCI、SSCI、A&HCI等
- 医学期刊特殊信息(如适用):
* PubMed收录情况
* 是否为核心临床期刊
* 专科领域权威性
* 循证医学等级要求
* 临床试验注册要求
* 伦理委员会批准要求
3. 按推荐优先级排序,并说明推荐理由
4. 提供针对性的投稿建议,考虑该学科的特点
论文分析结果:"""
for q in self.questions:
if q.id in self.analysis_results:
journal_prompt += f"\n\n{q.description}:\n{self.analysis_results[q.id]}"
journal_prompt += "\n\n请提供详细的期刊推荐报告,重点关注期刊的层次性和适配性。请根据论文的具体学科领域,采用该领域通用的期刊评价标准和分类体系。"
try:
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=journal_prompt,
inputs_show_user="生成期刊推荐报告",
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[],
sys_prompt="你是一个资深的跨学科学术期刊推荐专家,熟悉各个学科领域不同层次的期刊。请根据论文的具体学科和创新性,推荐从顶级到入门级的各层次期刊。不同学科有不同的期刊评价标准理工科重视影响因子和SCI收录,社会科学重视SSCI和学科声誉,人文学科重视A&HCI和同行评议,医学领域重视PubMed收录、临床实用性、循证医学价值和伦理规范。请根据论文所属学科采用相应的评价标准。"
)
if response:
return response
return "期刊推荐生成失败"
except Exception as e:
self.chatbot.append(["错误", f"生成期刊推荐时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return "期刊推荐生成失败: " + str(e)
def _generate_conference_recommendations(self) -> Generator:
"""生成会议推荐"""
self.chatbot.append(["生成会议推荐", "正在基于论文分析结果生成会议推荐..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 获取当前时间信息
current_time = datetime.now()
current_date_str = current_time.strftime("%Y年%m月%d")
current_year = current_time.year
current_month = current_time.month
# 构建会议推荐提示
conference_prompt = f"""请基于以下论文分析结果,为这篇论文推荐合适的学术会议。
**重要提示:当前时间是{current_date_str}{current_year}{current_month}月),请基于这个时间点推断会议的举办时间和投稿截止时间。**
推荐要求:
1. 根据论文的创新性和工作质量,分别推荐不同级别的会议:
- 顶级会议该领域最权威的国际会议2-3个
- 高质量会议该领域知名的国际或区域会议3-4个
- 中等会议该领域认可的专业会议3-4个
- 专业会议该领域细分方向的专门会议2-3个
注意:不同学科的会议评价标准不同:
- 计算机科学可参考CCF分类A/B/C类
- 工程学可参考EI收录和影响力
- 医学:可参考会议的临床影响和同行认可度
- 社会科学:可参考会议的学术声誉和参与度
- 人文学科:可参考会议的历史和学术传统
- 自然科学:可参考会议的国际影响力和发表质量
特别是医学会议,需要考虑:
- 临床医学会议重视实用性和临床指导价值
- 基础医学会议重视科学创新和机制研究
- 专科医学会议在各自领域内具有权威性
- 国际医学会议的CME学分认证情况
2. 对每个会议提供详细信息:
- 会议全名和缩写
- 会议级别分类(根据该学科的评价标准)
- 主要研究领域和主题
- 与论文内容的匹配度评分1-10分
- 录用难度评估(容易/中等/困难/极难)
- 会议举办周期(年会/双年会/不定期等)
- **基于当前时间{current_date_str},推断{current_year}年和{current_year+1}年的举办时间和地点**(请根据往年的举办时间规律进行推断)
- **基于推断的会议时间,估算论文提交截止时间**通常在会议前3-6个月
- 会议的国际化程度和影响范围
- 医学会议特殊信息(如适用):
* 是否提供CME学分
* 临床实践指导价值
* 专科认证机构认可情况
* 会议论文集的PubMed收录情况
* 伦理和临床试验相关要求
3. 按推荐优先级排序,并说明推荐理由
4. **基于当前时间{current_date_str},提供会议投稿的时间规划建议**
- 哪些会议可以赶上{current_year}年的投稿截止时间
- 哪些会议需要准备{current_year+1}年的投稿
- 具体的时间安排建议
论文分析结果:"""
for q in self.questions:
if q.id in self.analysis_results:
conference_prompt += f"\n\n{q.description}:\n{self.analysis_results[q.id]}"
conference_prompt += f"\n\n请提供详细的会议推荐报告,重点关注会议的层次性和时效性。请根据论文的具体学科领域,采用该领域通用的会议评价标准。\n\n**特别注意:请根据当前时间{current_date_str}和各会议的历史举办时间规律,准确推断{current_year}年和{current_year+1}年的会议时间安排,不要使用虚构的时间。**"
try:
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=conference_prompt,
inputs_show_user="生成会议推荐报告",
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[],
sys_prompt="你是一个资深的跨学科学术会议推荐专家,熟悉各个学科领域不同层次的学术会议。请根据论文的具体学科和创新性,推荐从顶级到专业级的各层次会议。不同学科有不同的会议评价标准和文化理工科重视技术创新和国际影响力,社会科学重视理论贡献和社会意义,人文学科重视学术深度和文化价值,医学领域重视临床实用性、CME学分认证、专科权威性和伦理规范。请根据论文所属学科采用相应的评价标准和推荐策略。"
)
if response:
return response
return "会议推荐生成失败"
except Exception as e:
self.chatbot.append(["错误", f"生成会议推荐时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return "会议推荐生成失败: " + str(e)
def _generate_priority_summary(self, journal_recommendations: str, conference_recommendations: str) -> Generator:
"""生成优先级总结"""
self.chatbot.append(["生成优先级总结", "正在生成投稿优先级总结..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 获取当前时间信息
current_time = datetime.now()
current_date_str = current_time.strftime("%Y年%m月%d")
current_month = current_time.strftime("%Y年%m月")
# 计算未来时间点
def add_months(date, months):
"""安全地添加月份"""
month = date.month - 1 + months
year = date.year + month // 12
month = month % 12 + 1
day = min(date.day, calendar.monthrange(year, month)[1])
return date.replace(year=year, month=month, day=day)
future_6_months = add_months(current_time, 6).strftime('%Y年%m月')
future_12_months = add_months(current_time, 12).strftime('%Y年%m月')
future_year = (current_time.year + 1)
priority_prompt = f"""请基于以下期刊和会议推荐结果,生成一个综合的投稿优先级总结。
**重要提示:当前时间是{current_date_str}{current_month}),请基于这个时间点制定投稿计划。**
期刊推荐结果:
{journal_recommendations}
会议推荐结果:
{conference_recommendations}
请提供:
1. 综合投稿策略建议(考虑该学科的发表文化和惯例)
- 期刊优先还是会议优先(不同学科有不同偏好)
- 国际期刊/会议 vs 国内期刊/会议的选择策略
- 英文发表 vs 中文发表的考虑
2. 按时间线排列的投稿计划(**基于当前时间{current_date_str},考虑截止时间和审稿周期**
- 短期目标({current_month}起3-6个月内,即到{future_6_months}
- 中期目标6-12个月内,即到{future_12_months}
- 长期目标1年以上,即{future_year}年以后)
3. 风险分散策略
- 同时投稿多个不同级别的目标
- 考虑该学科的一稿多投政策
- 备选方案和应急策略
4. 针对论文可能需要的改进建议
- 根据目标期刊/会议的要求调整内容
- 语言和格式的优化建议
- 补充实验或分析的建议
5. 预期的发表时间线和成功概率评估(基于当前时间{current_date_str}
6. 该学科特有的发表注意事项
- 伦理审查要求(如医学、心理学等)
- 数据开放要求(如某些自然科学领域)
- 利益冲突声明(如医学、工程等)
- 医学领域特殊要求:
* 临床试验注册要求ClinicalTrials.gov、中国临床试验注册中心等
* 患者知情同意和隐私保护
* 医学伦理委员会批准证明
* CONSORT、STROBE、PRISMA等报告规范遵循
* 药物/器械安全性数据要求
* CME学分认证相关要求
* 临床指南和循证医学等级要求
- 其他学科特殊要求
请以表格形式总结前10个最推荐的投稿目标期刊+会议),包括优先级排序、预期时间线和成功概率。
**注意:所有时间规划都应基于当前时间{current_date_str}进行计算,不要使用虚构的时间。**"""
try:
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=priority_prompt,
inputs_show_user="生成投稿优先级总结",
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history=[],
sys_prompt="你是一个资深的跨学科学术发表策略专家,熟悉各个学科的发表文化、惯例和要求。请综合考虑不同学科的特点:理工科通常重视期刊发表和影响因子,社会科学平衡期刊和专著,人文学科重视同行评议和学术声誉,医学重视临床意义和伦理规范。请为作者制定最适合其学科背景的投稿策略和时间规划。"
)
if response:
return response
return "优先级总结生成失败"
except Exception as e:
self.chatbot.append(["错误", f"生成优先级总结时出错: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return "优先级总结生成失败: " + str(e)
def save_recommendations(self, journal_recommendations: str, conference_recommendations: str, priority_summary: str) -> Generator:
"""保存推荐报告"""
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 保存为Markdown文件
try:
md_content = f"""# 论文期刊会议推荐报告
## 投稿优先级总结
{priority_summary}
## 期刊推荐
{journal_recommendations}
## 会议推荐
{conference_recommendations}
---
# 详细分析结果
"""
# 添加详细分析结果
for q in self.questions:
if q.id in self.analysis_results:
md_content += f"\n\n## {q.description}\n\n{self.analysis_results[q.id]}"
result_file = write_history_to_file(
history=[md_content],
file_basename=f"期刊会议推荐_{timestamp}.md"
)
if result_file and os.path.exists(result_file):
promote_file_to_downloadzone(result_file, chatbot=self.chatbot)
self.chatbot.append(["保存成功", f"推荐报告已保存至: {os.path.basename(result_file)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
else:
self.chatbot.append(["警告", "保存报告成功但找不到文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
except Exception as e:
self.chatbot.append(["警告", f"保存报告失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
def recommend_venues(self, paper_path: str) -> Generator:
"""推荐期刊会议主流程"""
# 加载论文
success = yield from self._load_paper(paper_path)
if not success:
return
# 分析关键问题
for question in self.questions:
yield from self._analyze_question(question)
# 分别生成期刊和会议推荐
journal_recommendations = yield from self._generate_journal_recommendations()
conference_recommendations = yield from self._generate_conference_recommendations()
# 生成优先级总结
priority_summary = yield from self._generate_priority_summary(journal_recommendations, conference_recommendations)
# 显示结果
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 保存报告
yield from self.save_recommendations(journal_recommendations, conference_recommendations, priority_summary)
# 将完整的分析结果和推荐内容添加到历史记录中,方便用户继续提问
self._add_to_history(journal_recommendations, conference_recommendations, priority_summary)
def _add_to_history(self, journal_recommendations: str, conference_recommendations: str, priority_summary: str):
"""将分析结果和推荐内容添加到历史记录中"""
try:
# 构建完整的内容摘要
history_content = f"""# 论文期刊会议推荐分析完成
## 📊 投稿优先级总结
{priority_summary}
## 📚 期刊推荐
{journal_recommendations}
## 🏛️ 会议推荐
{conference_recommendations}
## 📋 详细分析结果
"""
# 添加详细分析结果
for q in self.questions:
if q.id in self.analysis_results:
history_content += f"\n### {q.description}\n{self.analysis_results[q.id]}\n"
history_content += "\n---\n💡 您现在可以基于以上分析结果继续提问,比如询问特定期刊的详细信息、投稿策略建议、或者对推荐结果的进一步解释。"
# 添加到历史记录中
self.history.append("论文期刊会议推荐分析")
self.history.append(history_content)
self.chatbot.append(["✅ 分析完成", "所有分析结果和推荐内容已添加到对话历史中,您可以继续基于这些内容提问。"])
except Exception as e:
self.chatbot.append(["警告", f"添加到历史记录时出错: {str(e)},但推荐报告已正常生成"])
# 即使添加历史失败,也不影响主要功能
def _find_paper_file(path: str) -> str:
"""查找路径中的论文文件(简化版)"""
if os.path.isfile(path):
return path
# 支持的文件扩展名(按优先级排序)
extensions = ["pdf", "docx", "doc", "txt", "md", "tex"]
# 简单地遍历目录
if os.path.isdir(path):
try:
for ext in extensions:
# 手动检查每个可能的文件,而不使用glob
potential_file = os.path.join(path, f"paper.{ext}")
if os.path.exists(potential_file) and os.path.isfile(potential_file):
return potential_file
# 如果没找到特定命名的文件,检查目录中的所有文件
for file in os.listdir(path):
file_path = os.path.join(path, file)
if os.path.isfile(file_path):
file_ext = file.split('.')[-1].lower() if '.' in file else ""
if file_ext in extensions:
return file_path
except Exception:
pass # 忽略任何错误
return None
def download_paper_by_id(paper_info, chatbot, history) -> str:
"""下载论文并返回保存路径
Args:
paper_info: 元组,包含论文ID类型arxiv或doi和ID值
chatbot: 聊天机器人对象
history: 历史记录
Returns:
str: 下载的论文路径或None
"""
id_type, paper_id = paper_info
# 创建保存目录 - 使用时间戳创建唯一文件夹
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
user_name = chatbot.get_user() if hasattr(chatbot, 'get_user') else "default"
from toolbox import get_log_folder, get_user
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = Path(save_dir)
chatbot.append([f"下载论文", f"正在下载{'arXiv' if id_type == 'arxiv' else 'DOI'} {paper_id} 的论文..."])
update_ui(chatbot=chatbot, history=history)
pdf_path = None
try:
if id_type == 'arxiv':
# 使用改进的arxiv查询方法
formatted_id = format_arxiv_id(paper_id)
paper_result = get_arxiv_paper(formatted_id)
if not paper_result:
chatbot.append([f"下载失败", f"未找到arXiv论文: {paper_id}"])
update_ui(chatbot=chatbot, history=history)
return None
# 下载PDF
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
pdf_path = str(save_path / filename)
paper_result.download_pdf(filename=pdf_path)
else: # doi
# 下载DOI
sci_hub = SciHub(
doi=paper_id,
path=save_path
)
pdf_path = sci_hub.fetch()
# 检查下载结果
if pdf_path and os.path.exists(pdf_path):
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
chatbot.append([f"下载成功", f"已成功下载论文: {os.path.basename(pdf_path)}"])
update_ui(chatbot=chatbot, history=history)
return pdf_path
else:
chatbot.append([f"下载失败", f"论文下载失败: {paper_id}"])
update_ui(chatbot=chatbot, history=history)
return None
except Exception as e:
chatbot.append([f"下载错误", f"下载论文时出错: {str(e)}"])
update_ui(chatbot=chatbot, history=history)
return None
@CatchException
def 论文期刊会议推荐(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 论文期刊会议推荐"""
# 初始化推荐器
chatbot.append(["函数插件功能及使用方式", "论文期刊会议推荐:基于论文内容分析,为您推荐合适的学术期刊和会议投稿目标。适用于各个学科专业(自然科学、工程技术、医学、社会科学、人文学科等),根据不同学科的评价标准和发表文化,提供分层次的期刊会议推荐、影响因子分析、发表难度评估、投稿策略建议等。<br><br>📋 使用方式:<br>1、直接上传PDF文件<br>2、输入DOI号或arXiv ID<br>3、点击插件开始分析"])
yield from update_ui(chatbot=chatbot, history=history)
paper_file = None
# 检查输入是否为论文IDarxiv或DOI
paper_info = extract_paper_id(txt)
if paper_info:
# 如果是论文ID,下载论文
chatbot.append(["检测到论文ID", f"检测到{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'} ID: {paper_info[1]},准备下载论文..."])
yield from update_ui(chatbot=chatbot, history=history)
# 下载论文
paper_file = download_paper_by_id(paper_info, chatbot, history)
if not paper_file:
report_exception(chatbot, history, a=f"下载论文失败", b=f"无法下载{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'}论文: {paper_info[1]}")
yield from update_ui(chatbot=chatbot, history=history)
return
else:
# 检查输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析论文: {txt}", b=f"找不到文件或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 验证路径安全性
user_name = chatbot.get_user()
validate_path_safety(txt, user_name)
# 查找论文文件
paper_file = _find_paper_file(txt)
if not paper_file:
report_exception(chatbot, history, a=f"解析论文", b=f"在路径 {txt} 中未找到支持的论文文件")
yield from update_ui(chatbot=chatbot, history=history)
return
yield from update_ui(chatbot=chatbot, history=history)
# 确保paper_file是字符串
if paper_file is not None and not isinstance(paper_file, str):
# 尝试转换为字符串
try:
paper_file = str(paper_file)
except:
report_exception(chatbot, history, a=f"类型错误", b=f"论文路径不是有效的字符串: {type(paper_file)}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 开始推荐
chatbot.append(["开始分析", f"正在分析论文并生成期刊会议推荐: {os.path.basename(paper_file)}"])
yield from update_ui(chatbot=chatbot, history=history)
recommender = JournalConferenceRecommender(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
yield from recommender.recommend_venues(paper_file)

查看文件

@@ -0,0 +1,295 @@
import re
import os
import zipfile
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
from pathlib import Path
from datetime import datetime
def extract_paper_id(txt):
"""从输入文本中提取论文ID"""
# 尝试匹配DOI将DOI匹配提前,因为其格式更加明确
doi_patterns = [
r'doi.org/([\w\./-]+)', # doi.org/10.1234/xxx
r'doi:\s*([\w\./-]+)', # doi: 10.1234/xxx
r'(10\.\d{4,}/[\w\.-]+)', # 直接输入DOI: 10.1234/xxx
]
for pattern in doi_patterns:
match = re.search(pattern, txt, re.IGNORECASE)
if match:
return ('doi', match.group(1))
# 尝试匹配arXiv ID
arxiv_patterns = [
r'arxiv.org/abs/(\d+\.\d+)', # arxiv.org/abs/2103.14030
r'arxiv.org/pdf/(\d+\.\d+)', # arxiv.org/pdf/2103.14030
r'arxiv/(\d+\.\d+)', # arxiv/2103.14030
r'^(\d{4}\.\d{4,5})$', # 直接输入ID: 2103.14030
# 添加对早期arXiv ID的支持
r'arxiv.org/abs/([\w-]+/\d{7})', # arxiv.org/abs/math/0211159
r'arxiv.org/pdf/([\w-]+/\d{7})', # arxiv.org/pdf/hep-th/9901001
r'^([\w-]+/\d{7})$', # 直接输入: math/0211159
]
for pattern in arxiv_patterns:
match = re.search(pattern, txt, re.IGNORECASE)
if match:
paper_id = match.group(1)
# 如果是新格式YYMM.NNNNN或旧格式category/NNNNNNN,都直接返回
if re.match(r'^\d{4}\.\d{4,5}$', paper_id) or re.match(r'^[\w-]+/\d{7}$', paper_id):
return ('arxiv', paper_id)
return None
def extract_paper_ids(txt):
"""从输入文本中提取多个论文ID"""
paper_ids = []
# 首先按换行符分割
for line in txt.strip().split('\n'):
line = line.strip()
if not line: # 跳过空行
continue
# 对每一行再按空格分割
for item in line.split():
item = item.strip()
if not item: # 跳过空项
continue
paper_info = extract_paper_id(item)
if paper_info:
paper_ids.append(paper_info)
# 去除重复项,保持顺序
unique_paper_ids = []
seen = set()
for paper_info in paper_ids:
if paper_info not in seen:
seen.add(paper_info)
unique_paper_ids.append(paper_info)
return unique_paper_ids
def format_arxiv_id(paper_id):
"""格式化arXiv ID,处理新旧两种格式"""
# 如果是旧格式 (e.g. astro-ph/0404140),需要去掉arxiv:前缀
if '/' in paper_id:
return paper_id.replace('arxiv:', '') # 确保移除可能存在的arxiv:前缀
return paper_id
def get_arxiv_paper(paper_id):
"""获取arXiv论文,处理新旧两种格式"""
import arxiv
# 尝试不同的查询方式
query_formats = [
paper_id, # 原始ID
paper_id.replace('/', ''), # 移除斜杠
f"id:{paper_id}", # 添加id:前缀
]
for query in query_formats:
try:
# 使用Search查询
search = arxiv.Search(
query=query,
max_results=1
)
result = next(arxiv.Client().results(search))
if result:
return result
except:
continue
try:
# 使用id_list查询
search = arxiv.Search(id_list=[query])
result = next(arxiv.Client().results(search))
if result:
return result
except:
continue
return None
def create_zip_archive(files, save_path):
"""将多个PDF文件打包成zip"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"papers_{timestamp}.zip"
zip_path = str(save_path / zip_filename)
with zipfile.ZipFile(zip_path, 'w') as zipf:
for file in files:
if os.path.exists(file):
# 只添加文件名,不包含路径
zipf.write(file, os.path.basename(file))
return zip_path
@CatchException
def 论文下载(txt: str, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
txt: 用户输入,可以是DOI、arxiv ID或相关链接,支持多行输入进行批量下载
"""
from crazy_functions.doc_fns.text_content_loader import TextContentLoader
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.data_sources.scihub_source import SciHub
# 解析输入
paper_infos = extract_paper_ids(txt)
if not paper_infos:
chatbot.append(["输入解析", "未能识别任何论文ID或DOI,请检查输入格式。支持以下格式\n- arXiv ID (例如2103.14030)\n- arXiv链接\n- DOI (例如10.1234/xxx)\n- DOI链接\n\n多个论文ID请用换行分隔。"])
yield from update_ui(chatbot=chatbot, history=history)
return
# 创建保存目录 - 使用时间戳创建唯一文件夹
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = Path(save_dir)
# 记录下载结果
success_count = 0
failed_papers = []
downloaded_files = [] # 记录成功下载的文件路径
chatbot.append([f"开始下载", f"支持多行输入下载多篇论文,共检测到 {len(paper_infos)} 篇论文,开始下载..."])
yield from update_ui(chatbot=chatbot, history=history)
for id_type, paper_id in paper_infos:
try:
if id_type == 'arxiv':
chatbot.append([f"正在下载", f"从arXiv下载论文 {paper_id}..."])
yield from update_ui(chatbot=chatbot, history=history)
# 使用改进的arxiv查询方法
formatted_id = format_arxiv_id(paper_id)
paper_result = get_arxiv_paper(formatted_id)
if not paper_result:
failed_papers.append((paper_id, "未找到论文"))
continue
# 下载PDF
try:
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
pdf_path = str(save_path / filename)
paper_result.download_pdf(filename=pdf_path)
if os.path.exists(pdf_path):
downloaded_files.append(pdf_path)
except Exception as e:
failed_papers.append((paper_id, f"PDF下载失败: {str(e)}"))
continue
else: # doi
chatbot.append([f"正在下载", f"从Sci-Hub下载论文 {paper_id}..."])
yield from update_ui(chatbot=chatbot, history=history)
sci_hub = SciHub(
doi=paper_id,
path=save_path
)
pdf_path = sci_hub.fetch()
if pdf_path and os.path.exists(pdf_path):
downloaded_files.append(pdf_path)
# 检查下载结果
if pdf_path and os.path.exists(pdf_path):
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
success_count += 1
else:
failed_papers.append((paper_id, "下载失败"))
except Exception as e:
failed_papers.append((paper_id, str(e)))
yield from update_ui(chatbot=chatbot, history=history)
# 创建ZIP压缩包
if downloaded_files:
try:
zip_path = create_zip_archive(downloaded_files, Path(base_save_dir))
promote_file_to_downloadzone(zip_path, chatbot=chatbot)
chatbot.append([
f"创建压缩包",
f"已将所有下载的论文打包为: {os.path.basename(zip_path)}"
])
yield from update_ui(chatbot=chatbot, history=history)
except Exception as e:
chatbot.append([
f"创建压缩包失败",
f"打包文件时出现错误: {str(e)}"
])
yield from update_ui(chatbot=chatbot, history=history)
# 生成最终报告
summary = f"下载完成!成功下载 {success_count} 篇论文。\n"
if failed_papers:
summary += "\n以下论文下载失败:\n"
for paper_id, reason in failed_papers:
summary += f"- {paper_id}: {reason}\n"
if downloaded_files:
summary += f"\n所有论文已存放在文件夹 '{save_dir}' 中,并打包到压缩文件中。您可以在下载区找到单个PDF文件和压缩包。"
chatbot.append([
f"下载完成",
summary
])
yield from update_ui(chatbot=chatbot, history=history)
# 如果下载成功且用户想要直接阅读内容
if downloaded_files:
chatbot.append([
"提示",
"正在读取论文内容进行分析,请稍候..."
])
yield from update_ui(chatbot=chatbot, history=history)
# 使用TextContentLoader加载整个文件夹的PDF文件内容
loader = TextContentLoader(chatbot, history)
# 删除提示信息
chatbot.pop()
# 加载PDF内容 - 传入文件夹路径而不是单个文件路径
yield from loader.execute(save_dir)
# 添加提示信息
chatbot.append([
"提示",
"论文内容已加载完毕,您可以直接向AI提问有关该论文的问题。"
])
yield from update_ui(chatbot=chatbot, history=history)
if __name__ == "__main__":
# 测试代码
import asyncio
async def test():
# 测试批量输入
batch_inputs = [
# 换行分隔的测试
"""https://arxiv.org/abs/2103.14030
math/0211159
10.1038/s41586-021-03819-2""",
# 空格分隔的测试
"https://arxiv.org/abs/2103.14030 math/0211159 10.1038/s41586-021-03819-2",
# 混合分隔的测试
"""https://arxiv.org/abs/2103.14030 math/0211159
10.1038/s41586-021-03819-2 https://doi.org/10.1038/s41586-021-03819-2
2103.14030""",
]
for i, test_input in enumerate(batch_inputs, 1):
print(f"\n测试用例 {i}:")
print(f"输入: {test_input}")
results = extract_paper_ids(test_input)
print(f"解析结果:")
for result in results:
print(f" {result}")
asyncio.run(test())

查看文件

@@ -0,0 +1,867 @@
import os
import time
import glob
import re
import threading
from typing import Dict, List, Generator, Tuple
from dataclasses import dataclass
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.rag_fns.rag_file_support import extract_text, convert_to_markdown
from request_llms.bridge_all import model_info
from toolbox import update_ui, CatchException, report_exception, promote_file_to_downloadzone, write_history_to_file
from shared_utils.fastapi_server import validate_path_safety
# 新增:导入结构化论文提取器
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor, ExtractorConfig, StructuredPaper
# 导入格式化器
from crazy_functions.paper_fns.file2file_doc import (
TxtFormatter,
MarkdownFormatter,
HtmlFormatter,
WordFormatter
)
@dataclass
class TextFragment:
"""文本片段数据类,用于组织处理单元"""
content: str
fragment_index: int
total_fragments: int
class DocumentProcessor:
"""文档处理器 - 处理单个文档并输出结果"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化处理器"""
self.llm_kwargs = llm_kwargs.copy() # 创建原始llm_kwargs的副本
# 固定使用deepseek-reasoner模型
self.llm_kwargs['llm_model'] = "deepseek-reasoner"
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.processed_results = []
self.failed_fragments = []
# 新增:初始化论文结构提取器
self.paper_extractor = PaperStructureExtractor()
# 修改设置降重次数,默认为1次原为3次
self.reduction_times = int(self.plugin_kwargs.get("reduction_times", 1))
def _get_token_limit(self) -> int:
"""获取模型token限制,返回更小的值以确保更细粒度的分割"""
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
# 降低token限制,使每个片段更小
return max_token // 4 # 从3/4降低到1/4
def _create_batch_inputs(self, fragments: List[TextFragment], current_iteration: int = 1) -> Tuple[List, List, List]:
"""创建批处理输入"""
inputs_array = []
inputs_show_user_array = []
history_array = []
# 基础指令 - 增强单次调用的效果
user_instruction = self.plugin_kwargs.get("advanced_arg", """你是一个人类科研工作者,请对以下学术文本进行深度改写,以彻底消除AI生成特征。具体要求如下
1. 保持学术写作的严谨性和专业性,但加入一些人类写作的不完美元素
2. 维持原文的核心论述和逻辑框架,但打破过于完美的结构
3. 彻底改变句式结构:
- 灵活运用主动句与被动句,偏好使用主动语态
- 适当使用不同长度的句子,包括一些较长的复杂句
- 引入变化多样的句式,打破规律性
- 完全避免AI常用的模板化句式和套路表达
- 增加一些学术写作中常见的转折和连接方式
4. 全面改善用词:
- 使用更多学术语境下的专业词汇和同义词替换
- 避免过于机械和规律性的连接词,使用更自然的过渡
- 重构专业术语的表达方式,但保持准确性
- 增加词汇多样性,减少重复用词
- 偶尔使用一些不太常见但恰当的词汇
5. 模拟真实学者的写作风格:
- 注重论证的严密性,但允许存在一些微小的不对称性
- 保持表达的客观性,同时适度体现个人学术见解
- 在适当位置表达观点时更加自信和坚定
- 避免过于完美和机械均衡的论述结构
- 允许段落长度有所变化,不要过于均匀
6. 引入人类学者常见的写作特点:
- 段落之间的过渡更加自然流畅
- 适当使用一些学术界常见的修辞手法,但不过度使用
- 偶尔使用一些强调和限定性表达
- 适当使用一些学术界认可的个人化表达
7. 彻底消除AI痕迹
- 避免过于规整和均衡的段落结构
- 避免机械性的句式变化和词汇替换模式
- 避免过于完美的逻辑推导,适当增加一些转折
- 减少公式化的表达方式""")
# 对于单次调用的场景,不需要迭代前缀,直接使用更强力的改写指令
for frag in fragments:
# 在单次调用时使用更强力的指令
if self.reduction_times == 1:
i_say = (f'请对以下学术文本进行彻底改写,完全消除AI特征,使其像真实人类学者撰写的内容。\n\n{user_instruction}\n\n'
f'请记住以下几点:\n'
f'1. 避免过于规整和均衡的结构\n'
f'2. 引入一些人类写作的微小不完美之处\n'
f'3. 使用多样化的句式和词汇\n'
f'4. 打破可能的AI规律性表达模式\n'
f'5. 适当使用一些专业领域内的表达习惯\n\n'
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
f'文本内容:\n```\n{frag.content}\n```')
else:
# 原有的迭代前缀逻辑
iteration_prefix = ""
if current_iteration > 1:
iteration_prefix = f"这是第{current_iteration}次改写,请在保持学术性的基础上,采用更加人性化、不同的表达方式。"
if current_iteration == 2:
iteration_prefix += "在保持专业性的同时,进一步优化句式结构和用词,显著降低AI痕迹。"
elif current_iteration >= 3:
iteration_prefix += "请在确保不损失任何学术内容的前提下,彻底重构表达方式,并适当引入少量人类学者常用的表达技巧,避免过度使用比喻和类比。"
i_say = (f'请按照以下要求处理文本内容:{iteration_prefix}{user_instruction}\n\n'
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
f'文本内容:\n```\n{frag.content}\n```')
i_say_show_user = f'正在处理文本片段 {frag.fragment_index + 1}/{frag.total_fragments}'
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
return inputs_array, inputs_show_user_array, history_array
def _extract_decision(self, text: str) -> str:
"""从LLM响应中提取<decision>标签内的内容"""
import re
pattern = r'<decision>(.*?)</decision>'
matches = re.findall(pattern, text, re.DOTALL)
if matches:
return matches[0].strip()
else:
# 如果没有找到标签,返回原始文本
return text.strip()
def process_file(self, file_path: str) -> Generator:
"""处理单个文件"""
self.chatbot.append(["开始处理文件", f"文件路径: {file_path}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
# 首先尝试转换为Markdown
file_path = convert_to_markdown(file_path)
# 1. 检查文件是否为支持的论文格式
is_paper_format = any(file_path.lower().endswith(ext) for ext in self.paper_extractor.SUPPORTED_EXTENSIONS)
if is_paper_format:
# 使用结构化提取器处理论文
return (yield from self._process_structured_paper(file_path))
else:
# 使用原有方式处理普通文档
return (yield from self._process_regular_file(file_path))
except Exception as e:
self.chatbot.append(["处理错误", f"文件处理失败: {str(e)}"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
def _process_structured_paper(self, file_path: str) -> Generator:
"""处理结构化论文文件"""
# 1. 提取论文结构
self.chatbot[-1] = ["正在分析论文结构", f"文件路径: {file_path}"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
paper = self.paper_extractor.extract_paper_structure(file_path)
if not paper or not paper.sections:
self.chatbot.append(["无法提取论文结构", "将使用全文内容进行处理"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用全文内容进行段落切分
if paper and paper.full_text:
# 使用增强的分割函数进行更细致的分割
fragments = self._breakdown_section_content(paper.full_text)
# 创建文本片段对象
text_fragments = []
for i, frag in enumerate(fragments):
if frag.strip():
text_fragments.append(TextFragment(
content=frag,
fragment_index=i,
total_fragments=len(fragments)
))
# 多次降重处理
if text_fragments:
current_fragments = text_fragments
# 进行多轮降重处理
for iteration in range(1, self.reduction_times + 1):
# 处理当前片段
processed_content = yield from self._process_text_fragments(current_fragments, iteration)
# 如果这是最后一次迭代,保存结果
if iteration == self.reduction_times:
final_content = processed_content
break
# 否则,准备下一轮迭代的片段
# 从处理结果中提取处理后的内容
next_fragments = []
for idx, item in enumerate(self.processed_results):
next_fragments.append(TextFragment(
content=item['content'],
fragment_index=idx,
total_fragments=len(self.processed_results)
))
current_fragments = next_fragments
# 更新UI显示最终结果
self.chatbot[-1] = ["处理完成", f"共完成 {self.reduction_times} 轮降重"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return final_content
else:
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
else:
self.chatbot.append(["处理失败", "未能提取到论文内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 2. 准备处理章节内容(不处理标题)
self.chatbot[-1] = ["已提取论文结构", f"{len(paper.sections)} 个主要章节"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 3. 收集所有需要处理的章节内容并分割为合适大小
sections_to_process = []
section_map = {} # 用于映射处理前后的内容
def collect_section_contents(sections, parent_path=""):
"""递归收集章节内容,跳过参考文献部分"""
for i, section in enumerate(sections):
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
# 检查是否为参考文献部分,如果是则跳过
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
continue # 跳过参考文献部分
# 只处理内容非空的章节
if section.content and section.content.strip():
# 使用增强的分割函数进行更细致的分割
fragments = self._breakdown_section_content(section.content)
for fragment_idx, fragment_content in enumerate(fragments):
if fragment_content.strip():
fragment_index = len(sections_to_process)
sections_to_process.append(TextFragment(
content=fragment_content,
fragment_index=fragment_index,
total_fragments=0 # 临时值,稍后更新
))
# 保存映射关系,用于稍后更新章节内容
# 为每个片段存储原始章节和片段索引信息
section_map[fragment_index] = (current_path, section, fragment_idx, len(fragments))
# 递归处理子章节
if section.subsections:
collect_section_contents(section.subsections, current_path)
# 收集所有章节内容
collect_section_contents(paper.sections)
# 更新总片段数
total_fragments = len(sections_to_process)
for frag in sections_to_process:
frag.total_fragments = total_fragments
# 4. 如果没有内容需要处理,直接返回
if not sections_to_process:
self.chatbot.append(["处理完成", "未找到需要处理的内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 5. 批量处理章节内容
self.chatbot[-1] = ["开始处理论文内容", f"{len(sections_to_process)} 个内容片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 一次性准备所有输入
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(sections_to_process)
# 使用系统提示
instruction = self.plugin_kwargs.get("advanced_arg", """请对以下学术文本进行彻底改写,以显著降低AI生成特征。具体要求如下
1. 保持学术写作的严谨性和专业性
2. 维持原文的核心论述和逻辑框架
3. 优化句式结构:
- 灵活运用主动句与被动句
- 适当拆分复杂句式,提高可读性
- 注意句式的多样性,避免重复模式
- 打破AI常用的句式模板
4. 改善用词:
- 使用更多学术语境下的同义词替换
- 避免过于机械和规律性的连接词
- 适当调整专业术语的表达方式
- 增加词汇多样性,减少重复用词
5. 增强文本的学术特征:
- 注重论证的严密性
- 保持表达的客观性
- 适度体现作者的学术见解
- 避免过于完美和均衡的论述结构
6. 确保语言风格的一致性
7. 减少AI生成文本常见的套路和模式""")
sys_prompt_array = [f"""作为一位专业的学术写作顾问,请按照以下要求改写文本:
1. 严格保持学术写作规范
2. 维持原文的核心论述和逻辑框架
3. 通过优化句式结构和用词降低AI生成特征
4. 确保语言风格的一致性和专业性
5. 保持内容的客观性和准确性
6. 避免AI常见的套路化表达和过于完美的结构"""] * len(sections_to_process)
# 调用LLM一次性处理所有片段
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应,重组章节内容
section_contents = {} # 用于重组各章节的处理后内容
for j, frag in enumerate(sections_to_process):
try:
llm_response = response_collection[j * 2 + 1]
processed_text = self._extract_decision(llm_response)
if processed_text and processed_text.strip():
# 保存处理结果
self.processed_results.append({
'index': frag.fragment_index,
'content': processed_text
})
# 存储处理后的文本片段,用于后续重组
fragment_index = frag.fragment_index
if fragment_index in section_map:
path, section, fragment_idx, total_fragments = section_map[fragment_index]
# 初始化此章节的内容容器(如果尚未创建)
if path not in section_contents:
section_contents[path] = [""] * total_fragments
# 将处理后的片段放入正确位置
section_contents[path][fragment_idx] = processed_text
else:
self.failed_fragments.append(frag)
except Exception as e:
self.failed_fragments.append(frag)
# 重组每个章节的内容
for path, fragments in section_contents.items():
section = None
for idx in section_map:
if section_map[idx][0] == path:
section = section_map[idx][1]
break
if section:
# 合并该章节的所有处理后片段
section.content = "\n".join(fragments)
# 6. 更新UI
success_count = total_fragments - len(self.failed_fragments)
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 收集参考文献部分(不进行处理)
references_sections = []
def collect_references(sections, parent_path=""):
"""递归收集参考文献部分"""
for i, section in enumerate(sections):
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
# 检查是否为参考文献部分
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
references_sections.append((current_path, section))
# 递归检查子章节
if section.subsections:
collect_references(section.subsections, current_path)
# 收集参考文献
collect_references(paper.sections)
# 7. 将处理后的结构化论文转换为Markdown
markdown_content = self.paper_extractor.generate_markdown(paper)
# 8. 返回处理后的内容
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段,参考文献部分未处理"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return markdown_content
except Exception as e:
self.chatbot.append(["结构化处理失败", f"错误: {str(e)},将尝试作为普通文件处理"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return (yield from self._process_regular_file(file_path))
def _process_regular_file(self, file_path: str) -> Generator:
"""使用原有方式处理普通文件"""
# 原有的文件处理逻辑
self.chatbot[-1] = ["正在读取文件", f"文件路径: {file_path}"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
content = extract_text(file_path)
if not content or not content.strip():
self.chatbot.append(["处理失败", "文件内容为空或无法提取内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 2. 分割文本
self.chatbot[-1] = ["正在分析文件", "将文件内容分割为适当大小的片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 使用增强的分割函数
fragments = self._breakdown_section_content(content)
# 3. 创建文本片段对象
text_fragments = []
for i, frag in enumerate(fragments):
if frag.strip():
text_fragments.append(TextFragment(
content=frag,
fragment_index=i,
total_fragments=len(fragments)
))
# 4. 多轮降重处理
if not text_fragments:
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return None
# 批处理大小
batch_size = 8 # 每批处理的片段数
# 第一次迭代
current_batches = []
for i in range(0, len(text_fragments), batch_size):
current_batches.append(text_fragments[i:i + batch_size])
all_processed_fragments = []
# 进行多轮降重处理
for iteration in range(1, self.reduction_times + 1):
self.chatbot[-1] = ["开始处理文本", f"{iteration}/{self.reduction_times} 次降重"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
next_batches = []
all_processed_fragments = []
# 分批处理当前迭代的片段
for batch in current_batches:
# 处理当前批次
_ = yield from self._process_text_fragments(batch, iteration)
# 收集处理结果
processed_batch = []
for item in self.processed_results:
processed_batch.append(TextFragment(
content=item['content'],
fragment_index=len(all_processed_fragments) + len(processed_batch),
total_fragments=0 # 临时值,稍后更新
))
all_processed_fragments.extend(processed_batch)
# 如果不是最后一轮迭代,准备下一批次
if iteration < self.reduction_times:
for i in range(0, len(processed_batch), batch_size):
next_batches.append(processed_batch[i:i + batch_size])
# 更新总片段数
for frag in all_processed_fragments:
frag.total_fragments = len(all_processed_fragments)
# 为下一轮迭代准备批次
current_batches = next_batches
# 合并最终结果
final_content = "\n\n".join([frag.content for frag in all_processed_fragments])
# 5. 更新UI显示最终结果
self.chatbot[-1] = ["处理完成", f"共完成 {self.reduction_times} 轮降重"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return final_content
def save_results(self, content: str, original_file_path: str) -> List[str]:
"""保存处理结果为TXT格式"""
if not content:
return []
timestamp = time.strftime("%Y%m%d_%H%M%S")
original_filename = os.path.basename(original_file_path)
filename_without_ext = os.path.splitext(original_filename)[0]
base_filename = f"{filename_without_ext}_processed_{timestamp}"
result_files = []
# 只保存为TXT
try:
txt_formatter = TxtFormatter()
txt_content = txt_formatter.create_document(content)
txt_file = write_history_to_file(
history=[txt_content],
file_basename=f"{base_filename}.txt"
)
result_files.append(txt_file)
except Exception as e:
self.chatbot.append(["警告", f"TXT格式保存失败: {str(e)}"])
# 添加到下载区
for file in result_files:
promote_file_to_downloadzone(file, chatbot=self.chatbot)
return result_files
def _breakdown_section_content(self, content: str) -> List[str]:
"""对文本内容进行分割与合并
主要按段落进行组织,只合并较小的段落以减少片段数量
保留原始段落结构,不对长段落进行强制分割
针对中英文设置不同的阈值,因为字符密度不同
"""
# 先按段落分割文本
paragraphs = content.split('\n\n')
# 检测语言类型
chinese_char_count = sum(1 for char in content if '\u4e00' <= char <= '\u9fff')
is_chinese_text = chinese_char_count / max(1, len(content)) > 0.3
# 根据语言类型设置不同的阈值(只用于合并小段落)
if is_chinese_text:
# 中文文本:一个汉字就是一个字符,信息密度高
min_chunk_size = 300 # 段落合并的最小阈值
target_size = 800 # 理想的段落大小
else:
# 英文文本:一个单词由多个字符组成,信息密度低
min_chunk_size = 600 # 段落合并的最小阈值
target_size = 1600 # 理想的段落大小
# 1. 只合并小段落,不对长段落进行分割
result_fragments = []
current_chunk = []
current_length = 0
for para in paragraphs:
# 如果段落太小且不会超过目标大小,则合并
if len(para) < min_chunk_size and current_length + len(para) <= target_size:
current_chunk.append(para)
current_length += len(para)
# 否则,创建新段落
else:
# 如果当前块非空且与当前段落无关,先保存它
if current_chunk and current_length > 0:
result_fragments.append('\n\n'.join(current_chunk))
# 当前段落作为新块
current_chunk = [para]
current_length = len(para)
# 如果当前块大小已接近目标大小,保存并开始新块
if current_length >= target_size:
result_fragments.append('\n\n'.join(current_chunk))
current_chunk = []
current_length = 0
# 保存最后一个块
if current_chunk:
result_fragments.append('\n\n'.join(current_chunk))
# 2. 处理可能过大的片段确保不超过token限制
final_fragments = []
max_token = self._get_token_limit()
for fragment in result_fragments:
# 检查fragment是否可能超出token限制
# 根据语言类型调整token估算
if is_chinese_text:
estimated_tokens = len(fragment) / 1.5 # 中文每个token约1-2个字符
else:
estimated_tokens = len(fragment) / 4 # 英文每个token约4个字符
if estimated_tokens > max_token:
# 即使可能超出限制,也尽量保持段落的完整性
# 使用breakdown_text但设置更大的限制来减少分割
larger_limit = max_token * 0.95 # 使用95%的限制
sub_fragments = breakdown_text_to_satisfy_token_limit(
txt=fragment,
limit=larger_limit,
llm_model=self.llm_kwargs['llm_model']
)
final_fragments.extend(sub_fragments)
else:
final_fragments.append(fragment)
return final_fragments
def _process_text_fragments(self, text_fragments: List[TextFragment], current_iteration: int = 1) -> str:
"""处理文本片段,支持多次降重
Args:
text_fragments: 要处理的文本片段列表
current_iteration: 当前是第几次降重迭代
Returns:
处理后的文本内容
"""
self.chatbot[-1] = ["开始处理文本", f"{current_iteration}/{self.reduction_times} 次降重,共 {len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 重置处理结果,为当前迭代做准备
self.processed_results = []
self.failed_fragments = []
# 一次性准备所有输入
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(text_fragments, current_iteration)
# 对于单次调用的特殊处理 - 使用更强力的系统提示词
if self.reduction_times == 1:
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。
你的任务是对学术文本进行深度改写,以彻底消除AI特征,使其具有真实人类学者的写作特点
1. 打破AI常见的规整句式结构和词汇选择模式
2. 引入学术写作中的自然变化和不规则性
3. 使用多样化的句式和段落结构,避免明显的平衡性和规律性
4. 加入人类写作中常见的微小不完美之处,如段落长度变化、论述强调点不均等
5. 在维持核心论点的同时,完全重构表达方式
6. 适当使用学科领域内专家常用的表达习惯和术语
7. 确保文本读起来自然流畅,避免机械感
8. 打破AI生成内容常见的逻辑结构和表达套路
9. 根据文本内容,在保持专业性的基础上增加写作的特色和个性
对于每个文本片段,尽可能彻底地重构,使其无法被识别为AI生成。"""
sys_prompt_array = [sys_prompt] * len(text_fragments)
else:
# 原有的多次迭代处理逻辑
# 根据迭代次数调整系统提示词强度
base_instruction = self.plugin_kwargs.get("advanced_arg", """你是一个人类科研工作者,请对以下学术文本进行彻底改写,以显著降低AI生成特征。具体要求如下
1. 保持学术写作的严谨性和专业性
2. 维持原文的核心论述和逻辑框架
3. 优化句式结构:
- 灵活运用主动句与被动句
- 适当拆分复杂句式,提高可读性
- 注意句式的多样性,避免重复模式
- 打破AI常用的句式模板
4. 改善用词:
- 使用更多学术语境下的同义词替换
- 避免过于机械和规律性的连接词
- 适当调整专业术语的表达方式
- 增加词汇多样性,减少重复用词
5. 增强文本的学术特征:
- 注重论证的严密性
- 保持表达的客观性
- 适度体现作者的学术见解
- 避免过于完美和均衡的论述结构
6. 确保语言风格的一致性
7. 减少AI生成文本常见的套路和模式""")
# 根据迭代次数增强强度
if current_iteration == 1:
# 第一次迭代使用基础强度
instruction = base_instruction
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第{current_iteration}轮深度改写:
1. 严格保持学术写作规范
2. 维持原文的核心论述和逻辑框架
3. 通过优化句式结构和用词降低AI生成特征
4. 确保语言风格的一致性和专业性
5. 保持内容的客观性和准确性
6. 避免AI常见的套路化表达和过于完美的结构"""
elif current_iteration == 2:
# 第二次迭代使用中等强度
instruction = base_instruction + """这是第二次改写,请进一步优化:
1. 更深入地调整句式结构,增加表达多样性
2. 在保持准确性的前提下,替换更多专业术语的表达方式
3. 优化段落结构,增强逻辑连贯性
4. 适当调整学术论证的表达方式
5. 打破AI生成内容的常见模式,如过于均衡的结构和规律性的表达方式"""
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第二轮改写
1. 在保持严谨性的基础上,优化表达方式
2. 改善文本结构,提升学术论述的清晰度
3. 增强专业术语使用的灵活性
4. 确保论证过程的完整性和逻辑性
5. 进一步减少AI生成内容的特征和模式
6. 使文本表现出更自然、更人性化的写作风格"""
else:
# 第三次或更多次迭代使用高强度
instruction = base_instruction + f"""这是第{current_iteration}次改写,需要在保持学术性的前提下彻底重构文本:
1. 全面重构句式,采用不同的表达策略
2. 系统性替换学术用语,保持专业准确性
3. 重组段落结构,优化论证层次
4. 深化学术观点的表达方式
5. 增强文本的逻辑严密性
6. 提升论述的专业深度
7. 确保不遗漏任何学术观点和论证要素
8. 适当使用学术表达技巧(如精准举例、恰当转折等),但不过分依赖比喻和类比
9. 彻底消除AI生成内容的特征,使文本具有真实学者的写作风格"""
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第{current_iteration}轮深度改写:
1. 在保持专业水准的前提下,彻底重构表达方式,引入长难句
2. 确保学术论证的严密性和完整性
3. 优化专业术语的运用
4. 提升文本的学术价值
5. 保持论述的逻辑性和连贯性
6. 适当使用少量学术表达技巧,提高文本说服力,但避免过度使用比喻和类比
7. 消除所有明显的AI生成痕迹,使文本更接近真实学者的写作风格"""
sys_prompt_array = [sys_prompt] * len(text_fragments)
# 调用LLM一次性处理所有片段
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应
for j, frag in enumerate(text_fragments):
try:
llm_response = response_collection[j * 2 + 1]
processed_text = self._extract_decision(llm_response)
if processed_text and processed_text.strip():
self.processed_results.append({
'index': frag.fragment_index,
'content': processed_text
})
else:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content
})
except Exception as e:
self.failed_fragments.append(frag)
self.processed_results.append({
'index': frag.fragment_index,
'content': frag.content
})
# 按原始顺序合并结果
self.processed_results.sort(key=lambda x: x['index'])
final_content = "\n".join([item['content'] for item in self.processed_results])
# 更新UI
success_count = len(text_fragments) - len(self.failed_fragments)
self.chatbot[-1] = ["当前阶段处理完成", f"{current_iteration}/{self.reduction_times} 次降重,成功处理 {success_count}/{len(text_fragments)} 个片段"]
yield from update_ui(chatbot=self.chatbot, history=self.history)
return final_content
@CatchException
def 学术降重(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 文件到文件处理"""
# 初始化
# 从高级参数中提取降重次数
if "advanced_arg" in plugin_kwargs and plugin_kwargs["advanced_arg"]:
# 检查是否包含降重次数的设置
match = re.search(r'reduction_times\s*=\s*(\d+)', plugin_kwargs["advanced_arg"])
if match:
reduction_times = int(match.group(1))
# 替换掉高级参数中的reduction_times设置,但保留其他内容
plugin_kwargs["advanced_arg"] = re.sub(r'reduction_times\s*=\s*\d+', '', plugin_kwargs["advanced_arg"]).strip()
# 添加到plugin_kwargs中作为单独的参数
plugin_kwargs["reduction_times"] = reduction_times
processor = DocumentProcessor(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
chatbot.append(["函数插件功能", f"文件内容处理:将文档内容进行{processor.reduction_times}次降重处理"])
# 更新用户提示,提供关于降重策略的详细说明
if processor.reduction_times == 1:
chatbot.append(["降重策略", "将使用单次深度降重,这种方式能更有效地降低AI特征,减少查重率。我们采用特殊优化的提示词,通过一次性强力改写来实现降重效果。"])
elif processor.reduction_times > 1:
chatbot.append(["降重策略", f"将进行{processor.reduction_times}轮迭代降重,每轮降重都会基于上一轮的结果,并逐渐增加降重强度。请注意,多轮迭代可能会引入新的AI特征,单次强力降重通常效果更好。"])
yield from update_ui(chatbot=chatbot, history=history)
# 验证输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析路径: {txt}", b=f"找不到路径或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 验证路径安全性
user_name = chatbot.get_user()
validate_path_safety(txt, user_name)
# 获取文件列表
if os.path.isfile(txt):
# 单个文件处理
file_paths = [txt]
else:
# 目录处理 - 类似批量文件询问插件
project_folder = txt
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
# 排除压缩文件
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
file_paths = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
# 过滤支持的文件格式
file_paths = [f for f in file_paths if any(f.lower().endswith(ext) for ext in
list(processor.paper_extractor.SUPPORTED_EXTENSIONS) + ['.json', '.csv', '.xlsx', '.xls'])]
if not file_paths:
report_exception(chatbot, history, a=f"解析路径: {txt}", b="未找到支持的文件类型")
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理文件
if len(file_paths) > 1:
chatbot.append(["发现多个文件", f"共找到 {len(file_paths)} 个文件,将处理第一个文件"])
yield from update_ui(chatbot=chatbot, history=history)
# 只处理第一个文件
file_to_process = file_paths[0]
processed_content = yield from processor.process_file(file_to_process)
if processed_content:
# 保存结果
result_files = processor.save_results(processed_content, file_to_process)
if result_files:
chatbot.append(["处理完成", f"已生成 {len(result_files)} 个结果文件"])
else:
chatbot.append(["处理完成", "但未能保存任何结果文件"])
else:
chatbot.append(["处理失败", "未能生成有效的处理结果"])
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -0,0 +1,387 @@
import aiohttp
import asyncio
from typing import List, Dict, Optional
import re
import random
import time
class WikipediaAPI:
"""维基百科API调用实现"""
def __init__(self, language: str = "zh", user_agent: str = None,
max_concurrent: int = 5, request_delay: float = 0.5):
"""
初始化维基百科API客户端
Args:
language: 语言代码 (zh: 中文, en: 英文, ja: 日文等)
user_agent: 用户代理信息,如果为None将使用默认值
max_concurrent: 最大并发请求数
request_delay: 请求间隔时间(秒)
"""
self.language = language
self.base_url = f"https://{language}.wikipedia.org/w/api.php"
self.user_agent = user_agent or "WikipediaAPIClient/1.0 (chatscholar@163.com)"
self.headers = {
"User-Agent": self.user_agent,
"Accept": "application/json"
}
# 添加并发控制
self.semaphore = asyncio.Semaphore(max_concurrent)
self.request_delay = request_delay
self.last_request_time = 0
async def _make_request(self, url, params=None):
"""
发起API请求,包含并发控制和请求延迟
Args:
url: 请求URL
params: 请求参数
Returns:
API响应数据
"""
# 使用信号量控制并发
async with self.semaphore:
# 添加请求间隔
current_time = time.time()
time_since_last_request = current_time - self.last_request_time
if time_since_last_request < self.request_delay:
await asyncio.sleep(self.request_delay - time_since_last_request)
# 设置随机延迟,避免规律性请求
jitter = random.uniform(0, 0.2)
await asyncio.sleep(jitter)
# 记录本次请求时间
self.last_request_time = time.time()
# 发起请求
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 429: # Too Many Requests
retry_after = int(response.headers.get('Retry-After', 5))
print(f"达到请求限制,等待 {retry_after} 秒后重试...")
await asyncio.sleep(retry_after)
return await self._make_request(url, params)
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return None
return await response.json()
except aiohttp.ClientError as e:
print(f"请求错误: {str(e)}")
return None
async def search(self, query: str, limit: int = 10, namespace: int = 0) -> List[Dict]:
"""
搜索维基百科文章
Args:
query: 搜索关键词
limit: 返回结果数量
namespace: 命名空间 (0表示文章, 14表示分类等)
Returns:
搜索结果列表
"""
params = {
"action": "query",
"list": "search",
"srsearch": query,
"format": "json",
"srlimit": limit,
"srnamespace": namespace,
"srprop": "snippet|titlesnippet|sectiontitle|categorysnippet|size|wordcount|timestamp|redirecttitle"
}
data = await self._make_request(self.base_url, params)
if not data:
return []
search_results = data.get("query", {}).get("search", [])
return search_results
async def get_page_content(self, title: str, section: Optional[int] = None) -> Dict:
"""
获取维基百科页面内容
Args:
title: 页面标题
section: 特定章节编号(可选)
Returns:
页面内容字典
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
params = {
"action": "parse",
"page": title,
"format": "json",
"prop": "text|langlinks|categories|links|templates|images|externallinks|sections|revid|displaytitle|iwlinks|properties"
}
# 如果指定了章节,只获取该章节内容
if section is not None:
params["section"] = section
async with session.get(self.base_url, params=params) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
return {}
data = await response.json()
if "error" in data:
print(f"API错误: {data['error'].get('info', '未知错误')}")
return {}
return data.get("parse", {})
async def get_summary(self, title: str, sentences: int = 3) -> str:
"""
获取页面摘要
Args:
title: 页面标题
sentences: 返回的句子数量
Returns:
页面摘要文本
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
params = {
"action": "query",
"prop": "extracts",
"exintro": "1",
"exsentences": sentences,
"explaintext": "1",
"titles": title,
"format": "json"
}
async with session.get(self.base_url, params=params) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
return ""
data = await response.json()
pages = data.get("query", {}).get("pages", {})
# 获取第一个页面ID的内容
for page_id in pages:
return pages[page_id].get("extract", "")
return ""
async def get_random_articles(self, count: int = 1, namespace: int = 0) -> List[Dict]:
"""
获取随机文章
Args:
count: 需要的随机文章数量
namespace: 命名空间
Returns:
随机文章列表
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
params = {
"action": "query",
"list": "random",
"rnlimit": count,
"rnnamespace": namespace,
"format": "json"
}
async with session.get(self.base_url, params=params) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
return []
data = await response.json()
return data.get("query", {}).get("random", [])
async def login(self, username: str, password: str) -> bool:
"""
使用维基百科账户登录
Args:
username: 维基百科用户名
password: 维基百科密码
Returns:
登录是否成功
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
# 获取登录令牌
params = {
"action": "query",
"meta": "tokens",
"type": "login",
"format": "json"
}
async with session.get(self.base_url, params=params) as response:
if response.status != 200:
print(f"获取登录令牌失败: HTTP {response.status}")
return False
data = await response.json()
login_token = data.get("query", {}).get("tokens", {}).get("logintoken")
if not login_token:
print("获取登录令牌失败")
return False
# 使用令牌登录
login_params = {
"action": "login",
"lgname": username,
"lgpassword": password,
"lgtoken": login_token,
"format": "json"
}
async with session.post(self.base_url, data=login_params) as login_response:
login_data = await login_response.json()
if login_data.get("login", {}).get("result") == "Success":
print(f"登录成功: {username}")
return True
else:
print(f"登录失败: {login_data.get('login', {}).get('reason', '未知原因')}")
return False
async def setup_oauth(self, consumer_token: str, consumer_secret: str,
access_token: str = None, access_secret: str = None) -> bool:
"""
设置OAuth认证
Args:
consumer_token: 消费者令牌
consumer_secret: 消费者密钥
access_token: 访问令牌(可选)
access_secret: 访问密钥(可选)
Returns:
设置是否成功
"""
try:
# 需要安装 mwoauth 库: pip install mwoauth
import mwoauth
import requests_oauthlib
# 设置OAuth
self.consumer_token = consumer_token
self.consumer_secret = consumer_secret
if access_token and access_secret:
# 如果已有访问令牌
self.auth = requests_oauthlib.OAuth1(
consumer_token,
consumer_secret,
access_token,
access_secret
)
print("OAuth设置成功")
return True
else:
# 需要获取访问令牌(这通常需要用户在网页上授权)
print("请在开发环境中完成以下OAuth授权流程:")
# 创建消费者
consumer = mwoauth.Consumer(
consumer_token, consumer_secret
)
# 初始化握手
redirect, request_token = mwoauth.initiate(
f"https://{self.language}.wikipedia.org/w/index.php",
consumer
)
print(f"请访问此URL授权应用: {redirect}")
# 这里通常会提示用户访问URL并输入授权码
# 实际应用中需要实现适当的授权流程
return False
except ImportError:
print("请安装 mwoauth 库: pip install mwoauth")
return False
except Exception as e:
print(f"设置OAuth时发生错误: {str(e)}")
return False
async def example_usage():
"""演示WikipediaAPI的使用方法"""
# 创建默认中文维基百科API客户端
wiki_zh = WikipediaAPI(language="zh")
try:
# 示例1: 基本搜索
print("\n=== 示例1: 搜索维基百科 ===")
results = await wiki_zh.search("人工智能", limit=3)
for i, result in enumerate(results, 1):
print(f"\n--- 结果 {i} ---")
print(f"标题: {result.get('title')}")
snippet = result.get('snippet', '')
# 清理HTML标签
snippet = re.sub(r'<.*?>', '', snippet)
print(f"摘要: {snippet}")
print(f"字数: {result.get('wordcount')}")
print(f"大小: {result.get('size')} 字节")
# 示例2: 获取页面摘要
print("\n=== 示例2: 获取页面摘要 ===")
summary = await wiki_zh.get_summary("深度学习", sentences=2)
print(f"深度学习摘要: {summary}")
# 示例3: 获取页面内容
print("\n=== 示例3: 获取页面内容 ===")
content = await wiki_zh.get_page_content("机器学习")
if content and "text" in content:
text = content["text"].get("*", "")
# 移除HTML标签以便控制台显示
clean_text = re.sub(r'<.*?>', '', text)
print(f"机器学习页面内容片段: {clean_text[:200]}...")
# 显示页面包含的分类数量
categories = content.get("categories", [])
print(f"分类数量: {len(categories)}")
# 显示页面包含的链接数量
links = content.get("links", [])
print(f"链接数量: {len(links)}")
# 示例4: 获取特定章节内容
print("\n=== 示例4: 获取特定章节内容 ===")
# 获取引言部分(通常是0号章节)
intro_content = await wiki_zh.get_page_content("人工智能", section=0)
if intro_content and "text" in intro_content:
intro_text = intro_content["text"].get("*", "")
clean_intro = re.sub(r'<.*?>', '', intro_text)
print(f"人工智能引言内容片段: {clean_intro[:200]}...")
# 示例5: 获取随机文章
print("\n=== 示例5: 获取随机文章 ===")
random_articles = await wiki_zh.get_random_articles(count=2)
print("随机文章:")
for i, article in enumerate(random_articles, 1):
print(f"{i}. {article.get('title')}")
# 显示随机文章的简短摘要
article_summary = await wiki_zh.get_summary(article.get('title'), sentences=1)
print(f" 摘要: {article_summary[:100]}...")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,275 @@
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
from loguru import logger
import time
import re
def force_breakdown(txt, limit, get_token_fn):
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
"""
for i in reversed(range(len(txt))):
if get_token_fn(txt[:i]) < limit:
return txt[:i], txt[i:]
return "Tiktoken未知错误", "Tiktoken未知错误"
def maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage):
""" 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
当 remain_txt_to_cut < `_min` 时,我们再把 remain_txt_to_cut_storage 中的部分文字取出
"""
_min = int(5e4)
_max = int(1e5)
# print(len(remain_txt_to_cut), len(remain_txt_to_cut_storage))
if len(remain_txt_to_cut) < _min and len(remain_txt_to_cut_storage) > 0:
remain_txt_to_cut = remain_txt_to_cut + remain_txt_to_cut_storage
remain_txt_to_cut_storage = ""
if len(remain_txt_to_cut) > _max:
remain_txt_to_cut_storage = remain_txt_to_cut[_max:] + remain_txt_to_cut_storage
remain_txt_to_cut = remain_txt_to_cut[:_max]
return remain_txt_to_cut, remain_txt_to_cut_storage
def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=False):
""" 文本切分
"""
res = []
total_len = len(txt_tocut)
fin_len = 0
remain_txt_to_cut = txt_tocut
remain_txt_to_cut_storage = ""
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
while True:
if get_token_fn(remain_txt_to_cut) <= limit:
# 如果剩余文本的token数小于限制,那么就不用切了
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
break
else:
# 如果剩余文本的token数大于限制,那么就切
lines = remain_txt_to_cut.split('\n')
# 估计一个切分点
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
estimated_line_cut = int(estimated_line_cut)
# 开始查找合适切分点的偏移cnt
cnt = 0
for cnt in reversed(range(estimated_line_cut)):
if must_break_at_empty_line:
# 首先尝试用双空行(\n\n作为切分点
if lines[cnt] != "":
continue
prev = "\n".join(lines[:cnt])
post = "\n".join(lines[cnt:])
if get_token_fn(prev) < limit:
break
if cnt == 0:
# 如果没有找到合适的切分点
if break_anyway:
# 是否允许暴力切分
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
else:
# 不允许直接报错
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
# 追加列表
res.append(prev); fin_len+=len(prev)
# 准备下一次迭代
remain_txt_to_cut = post
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
process = fin_len/total_len
logger.info(f'正在文本切分 {int(process*100)}%')
if len(remain_txt_to_cut.strip()) == 0:
break
return res
def breakdown_text_to_satisfy_token_limit_(txt, limit, llm_model="gpt-3.5-turbo"):
""" 使用多种方式尝试切分文本,以满足 token 限制
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
try:
# 第1次尝试,将双空行\n\n作为切分点
return cut(limit, get_token_fn, txt, must_break_at_empty_line=True)
except RuntimeError:
try:
# 第2次尝试,将单空行\n作为切分点
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False)
except RuntimeError:
try:
# 第3次尝试,将英文句号.)作为切分点
res = cut(limit, get_token_fn, txt.replace('.', '\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
return [r.replace('\n', '.') for r in res]
except RuntimeError as e:
try:
# 第4次尝试,将中文句号作为切分点
res = cut(limit, get_token_fn, txt.replace('', '。。\n'), must_break_at_empty_line=False)
return [r.replace('。。\n', '') for r in res]
except RuntimeError as e:
# 第5次尝试,没办法了,随便切一下吧
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
breakdown_text_to_satisfy_token_limit = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_, timeout=60)
def cut_new(limit, get_token_fn, txt_tocut, must_break_at_empty_line, must_break_at_one_empty_line=False, break_anyway=False):
""" 文本切分
"""
res = []
res_empty_line = []
total_len = len(txt_tocut)
fin_len = 0
remain_txt_to_cut = txt_tocut
remain_txt_to_cut_storage = ""
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
empty=0
while True:
if get_token_fn(remain_txt_to_cut) <= limit:
# 如果剩余文本的token数小于限制,那么就不用切了
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
res_empty_line.append(empty)
break
else:
# 如果剩余文本的token数大于限制,那么就切
lines = remain_txt_to_cut.split('\n')
# 估计一个切分点
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
estimated_line_cut = int(estimated_line_cut)
# 开始查找合适切分点的偏移cnt
cnt = 0
for cnt in reversed(range(estimated_line_cut)):
if must_break_at_empty_line:
# 首先尝试用双空行(\n\n作为切分点
if lines[cnt] != "":
continue
if must_break_at_empty_line or must_break_at_one_empty_line:
empty=1
prev = "\n".join(lines[:cnt])
post = "\n".join(lines[cnt:])
if get_token_fn(prev) < limit :
break
# empty=0
if get_token_fn(prev)>limit:
if '.' not in prev or '' not in prev:
# empty = 0
break
# if cnt
if cnt == 0:
# 如果没有找到合适的切分点
if break_anyway:
# 是否允许暴力切分
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
empty =0
else:
# 不允许直接报错
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
# 追加列表
res.append(prev); fin_len+=len(prev)
res_empty_line.append(empty)
# 准备下一次迭代
remain_txt_to_cut = post
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
process = fin_len/total_len
logger.info(f'正在文本切分 {int(process*100)}%')
if len(remain_txt_to_cut.strip()) == 0:
break
return res,res_empty_line
def breakdown_text_to_satisfy_token_limit_new_(txt, limit, llm_model="gpt-3.5-turbo"):
""" 使用多种方式尝试切分文本,以满足 token 限制
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
try:
# 第1次尝试,将双空行\n\n作为切分点
res, empty_line =cut_new(limit, get_token_fn, txt, must_break_at_empty_line=True)
return res,empty_line
except RuntimeError:
try:
# 第2次尝试,将单空行\n作为切分点
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False,must_break_at_one_empty_line=True)
return res, _
except RuntimeError:
try:
# 第3次尝试,将英文句号.)作为切分点
res, _ = cut_new(limit, get_token_fn, txt.replace('.', '\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
return [r.replace('\n', '.') for r in res],_
except RuntimeError as e:
try:
# 第4次尝试,将中文句号作为切分点
res,_ = cut_new(limit, get_token_fn, txt.replace('', '。。\n'), must_break_at_empty_line=False)
return [r.replace('。。\n', '') for r in res], _
except RuntimeError as e:
# 第5次尝试,没办法了,随便切一下吧
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
return res,_
breakdown_text_to_satisfy_token_limit_new = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_new_, timeout=60)
def cut_from_end_to_satisfy_token_limit_(txt, limit, reserve_token=500, llm_model="gpt-3.5-turbo"):
"""从后往前裁剪文本,以论文为单位进行裁剪
参数:
txt: 要处理的文本(格式化后的论文列表字符串)
limit: token数量上限
reserve_token: 需要预留的token数量,默认500
llm_model: 使用的模型名称
返回:
裁剪后的文本
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
# 计算当前文本的token数
current_tokens = get_token_fn(txt)
target_limit = limit - reserve_token
# 如果当前token数已经在限制范围内,直接返回
if current_tokens <= target_limit:
return txt
# 按论文编号分割文本
papers = re.split(r'\n(?=\d+\. \*\*)', txt)
if not papers:
return txt
# 从前往后累加论文,直到达到token限制
result = papers[0] # 保留第一篇
current_tokens = get_token_fn(result)
for paper in papers[1:]:
paper_tokens = get_token_fn(paper)
if current_tokens + paper_tokens <= target_limit:
result += "\n" + paper
current_tokens += paper_tokens
else:
break
return result
# 添加超时保护
cut_from_end_to_satisfy_token_limit = run_in_subprocess_with_timeout(cut_from_end_to_satisfy_token_limit_, timeout=20)
if __name__ == '__main__':
from crazy_functions.crazy_utils import read_and_clean_pdf_text
file_content, page_one = read_and_clean_pdf_text("build/assets/at.pdf")
from request_llms.bridge_all import model_info
for i in range(5):
file_content += file_content
logger.info(len(file_content))
TOKEN_LIMIT_PER_FRAGMENT = 2500
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)

查看文件

@@ -113,7 +113,7 @@ def translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_fi
return [txt]
else:
# raw_token_num > TOKEN_LIMIT_PER_FRAGMENT
# find a smooth token limit to achieve even seperation
# find a smooth token limit to achieve even separation
count = int(math.ceil(raw_token_num / TOKEN_LIMIT_PER_FRAGMENT))
token_limit_smooth = raw_token_num // count + count
return breakdown_text_to_satisfy_token_limit(txt, limit=token_limit_smooth, llm_model=llm_kwargs['llm_model'])

查看文件

@@ -1,6 +1,6 @@
import os
from toolbox import CatchException, report_exception, get_log_folder, gen_time_str, check_packages
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_latest_msg, disable_auto_promotion
from toolbox import write_history_to_file, promote_file_to_downloadzone, get_conf, extract_archive
from crazy_functions.pdf_fns.parse_pdf import parse_pdf, translate_pdf

查看文件

@@ -14,17 +14,17 @@ def extract_text_from_files(txt, chatbot, history):
final_result(list):文本内容
page_one(list):第一页内容/摘要
file_manifest(list):文件路径
excption(string):需要用户手动处理的信息,如没出错则保持为空
exception(string):需要用户手动处理的信息,如没出错则保持为空
"""
final_result = []
page_one = []
file_manifest = []
excption = ""
exception = ""
if txt == "":
final_result.append(txt)
return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容
return False, final_result, page_one, file_manifest, exception #如输入区内容不是文件则直接返回输入区内容
#查找输入区内容中的文件
file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf')
@@ -33,20 +33,20 @@ def extract_text_from_files(txt, chatbot, history):
file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc')
if file_doc:
excption = "word"
return False, final_result, page_one, file_manifest, excption
exception = "word"
return False, final_result, page_one, file_manifest, exception
file_num = len(pdf_manifest) + len(md_manifest) + len(word_manifest)
if file_num == 0:
final_result.append(txt)
return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容
return False, final_result, page_one, file_manifest, exception #如输入区内容不是文件则直接返回输入区内容
if file_pdf:
try: # 尝试导入依赖,如果缺少依赖,则给出安装建议
import fitz
except:
excption = "pdf"
return False, final_result, page_one, file_manifest, excption
exception = "pdf"
return False, final_result, page_one, file_manifest, exception
for index, fp in enumerate(pdf_manifest):
file_content, pdf_one = read_and_clean_pdf_text(fp) # 尝试按照章节切割PDF
file_content = file_content.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
@@ -72,8 +72,8 @@ def extract_text_from_files(txt, chatbot, history):
try: # 尝试导入依赖,如果缺少依赖,则给出安装建议
from docx import Document
except:
excption = "word_pip"
return False, final_result, page_one, file_manifest, excption
exception = "word_pip"
return False, final_result, page_one, file_manifest, exception
for index, fp in enumerate(word_manifest):
doc = Document(fp)
file_content = '\n'.join([p.text for p in doc.paragraphs])
@@ -82,4 +82,4 @@ def extract_text_from_files(txt, chatbot, history):
final_result.append(file_content)
file_manifest.append(os.path.relpath(fp, folder_word))
return True, final_result, page_one, file_manifest, excption
return True, final_result, page_one, file_manifest, exception

查看文件

@@ -1,22 +1,48 @@
import subprocess
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx']
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt', '.pptm', '.pptx', '.bat']
def convert_to_markdown(file_path: str) -> str:
"""
将支持的文件格式转换为Markdown格式
Args:
file_path: 输入文件路径
Returns:
str: 转换后的Markdown文件路径,如果转换失败则返回原始文件路径
"""
_, ext = os.path.splitext(file_path.lower())
if ext in ['.docx', '.doc', '.pptx', '.ppt', '.pptm', '.xls', '.xlsx', '.csv', 'pdf']:
try:
# 创建输出Markdown文件路径
md_path = os.path.splitext(file_path)[0] + '.md'
# 使用markitdown工具将文件转换为Markdown
command = f"markitdown {file_path} > {md_path}"
subprocess.run(command, shell=True, check=True)
print(f"已将{ext}文件转换为Markdown: {md_path}")
return md_path
except Exception as e:
print(f"{ext}转Markdown失败: {str(e)},将继续处理原文件")
return file_path
return file_path
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
def extract_text(file_path):
from llama_index.core import SimpleDirectoryReader
_, ext = os.path.splitext(file_path.lower())
# 使用 SimpleDirectoryReader 处理它支持的文件格式
if ext in supports_format:
try:
reader = SimpleDirectoryReader(input_files=[file_path])
print(f"Extracting text from {file_path} using SimpleDirectoryReader")
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
print(f"Complete: Extracting text from {file_path} using SimpleDirectoryReader")
buffer = [ doc.text for doc in documents ]
return '\n'.join(buffer)
except Exception as e:
pass
return None
else:
return '格式不支持'

查看文件

查看文件

@@ -0,0 +1,68 @@
from typing import List
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
class EndNoteFormatter:
"""EndNote参考文献格式生成器"""
def __init__(self):
pass
def create_document(self, papers: List[PaperMetadata]) -> str:
"""生成EndNote格式的参考文献文本
Args:
papers: 论文列表
Returns:
str: EndNote格式的参考文献文本
"""
endnote_text = ""
for paper in papers:
# 开始一个新条目
endnote_text += "%0 Journal Article\n" # 默认类型为期刊文章
# 根据venue_type调整条目类型
if hasattr(paper, 'venue_type') and paper.venue_type:
if paper.venue_type.lower() == 'conference':
endnote_text = endnote_text.replace("Journal Article", "Conference Paper")
elif paper.venue_type.lower() == 'preprint':
endnote_text = endnote_text.replace("Journal Article", "Electronic Article")
# 添加标题
endnote_text += f"%T {paper.title}\n"
# 添加作者
for author in paper.authors:
endnote_text += f"%A {author}\n"
# 添加年份
if paper.year:
endnote_text += f"%D {paper.year}\n"
# 添加期刊/会议名称
if hasattr(paper, 'venue_name') and paper.venue_name:
endnote_text += f"%J {paper.venue_name}\n"
elif paper.venue:
endnote_text += f"%J {paper.venue}\n"
# 添加DOI
if paper.doi:
endnote_text += f"%R {paper.doi}\n"
endnote_text += f"%U https://doi.org/{paper.doi}\n"
elif paper.url:
endnote_text += f"%U {paper.url}\n"
# 添加摘要
if paper.abstract:
endnote_text += f"%X {paper.abstract}\n"
# 添加机构
if hasattr(paper, 'institutions'):
for institution in paper.institutions:
endnote_text += f"%I {institution}\n"
# 条目之间添加空行
endnote_text += "\n"
return endnote_text

查看文件

@@ -0,0 +1,211 @@
import re
import os
import pandas as pd
from datetime import datetime
class ExcelTableFormatter:
"""聊天记录中Markdown表格转Excel生成器"""
def __init__(self):
"""初始化Excel文档对象"""
from openpyxl import Workbook
self.workbook = Workbook()
self._table_count = 0
self._current_sheet = None
def _normalize_table_row(self, row):
"""标准化表格行,处理不同的分隔符情况"""
row = row.strip()
if row.startswith('|'):
row = row[1:]
if row.endswith('|'):
row = row[:-1]
return [cell.strip() for cell in row.split('|')]
def _is_separator_row(self, row):
"""检查是否是分隔行(由 - 或 : 组成)"""
clean_row = re.sub(r'[\s|]', '', row)
return bool(re.match(r'^[-:]+$', clean_row))
def _extract_tables_from_text(self, text):
"""从文本中提取所有表格内容"""
if not isinstance(text, str):
return []
tables = []
current_table = []
is_in_table = False
for line in text.split('\n'):
line = line.strip()
if not line:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
continue
if '|' in line:
if not is_in_table:
is_in_table = True
current_table.append(line)
else:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
if is_in_table and current_table and len(current_table) >= 2:
tables.append(current_table)
return tables
def _parse_table(self, table_lines):
"""解析表格内容为结构化数据"""
try:
headers = self._normalize_table_row(table_lines[0])
separator_index = next(
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
1
)
data_rows = []
for line in table_lines[separator_index + 1:]:
cells = self._normalize_table_row(line)
# 确保单元格数量与表头一致
while len(cells) < len(headers):
cells.append('')
cells = cells[:len(headers)]
data_rows.append(cells)
if headers and data_rows:
return {
'headers': headers,
'data': data_rows
}
except Exception as e:
print(f"解析表格时发生错误: {str(e)}")
return None
def _create_sheet(self, question_num, table_num):
"""创建新的工作表"""
sheet_name = f'Q{question_num}_T{table_num}'
if len(sheet_name) > 31:
sheet_name = f'Table{self._table_count}'
if sheet_name in self.workbook.sheetnames:
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
return self.workbook.create_sheet(title=sheet_name)
def create_document(self, history):
"""
处理聊天历史中的所有表格并创建Excel文档
Args:
history: 聊天历史列表
Returns:
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
"""
has_tables = False
# 删除默认创建的工作表
default_sheet = self.workbook['Sheet']
self.workbook.remove(default_sheet)
# 遍历所有回答
for i in range(1, len(history), 2):
answer = history[i]
tables = self._extract_tables_from_text(answer)
for table_lines in tables:
parsed_table = self._parse_table(table_lines)
if parsed_table:
self._table_count += 1
sheet = self._create_sheet(i // 2 + 1, self._table_count)
# 写入表头
for col, header in enumerate(parsed_table['headers'], 1):
sheet.cell(row=1, column=col, value=header)
# 写入数据
for row_idx, row_data in enumerate(parsed_table['data'], 2):
for col_idx, value in enumerate(row_data, 1):
sheet.cell(row=row_idx, column=col_idx, value=value)
has_tables = True
return self.workbook if has_tables else None
def save_chat_tables(history, save_dir, base_name):
"""
保存聊天历史中的表格到Excel文件
Args:
history: 聊天历史列表
save_dir: 保存目录
base_name: 基础文件名
Returns:
list: 保存的文件路径列表
"""
result_files = []
try:
# 创建Excel格式
excel_formatter = ExcelTableFormatter()
workbook = excel_formatter.create_document(history)
if workbook is not None:
# 确保保存目录存在
os.makedirs(save_dir, exist_ok=True)
# 生成Excel文件路径
excel_file = os.path.join(save_dir, base_name + '.xlsx')
# 保存Excel文件
workbook.save(excel_file)
result_files.append(excel_file)
print(f"已保存表格到Excel文件: {excel_file}")
except Exception as e:
print(f"保存Excel格式失败: {str(e)}")
return result_files
# 使用示例
if __name__ == "__main__":
# 示例聊天历史
history = [
"问题1",
"""这是第一个表格:
| A | B | C |
|---|---|---|
| 1 | 2 | 3 |""",
"问题2",
"这是没有表格的回答",
"问题3",
"""回答包含多个表格:
| Name | Age |
|------|-----|
| Tom | 20 |
第二个表格:
| X | Y |
|---|---|
| 1 | 2 |"""
]
# 保存表格
save_dir = "output"
base_name = "chat_tables"
saved_files = save_chat_tables(history, save_dir, base_name)

查看文件

@@ -0,0 +1,472 @@
class HtmlFormatter:
"""聊天记录HTML格式生成器"""
def __init__(self):
self.css_styles = """
:root {
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--border-color: #e2e8f0;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
line-height: 1.8;
margin: 0;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 2rem;
border-radius: 16px;
box-shadow: var(--card-shadow);
}
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes slideIn {
from { transform: translateX(-20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
}
.container {
animation: fadeIn 0.6s ease-out;
}
.QaBox {
animation: slideIn 0.5s ease-out;
transition: all 0.3s ease;
}
.QaBox:hover {
transform: translateX(5px);
}
.Question, .Answer, .historyBox {
transition: all 0.3s ease;
}
.chat-title {
color: var(--primary-color);
font-size: 2em;
text-align: center;
margin: 1rem 0 2rem;
padding-bottom: 1rem;
border-bottom: 2px solid var(--primary-color);
}
.chat-body {
display: flex;
flex-direction: column;
gap: 1.5rem;
margin: 2rem 0;
}
.QaBox {
background: white;
padding: 1.5rem;
border-radius: 8px;
border-left: 4px solid var(--primary-color);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
margin-bottom: 1.5rem;
}
.Question {
color: var(--secondary-color);
font-weight: 500;
margin-bottom: 1rem;
}
.Answer {
color: var(--text-color);
background: var(--primary-light);
padding: 1rem;
border-radius: 6px;
}
.history-section {
margin-top: 3rem;
padding-top: 2rem;
border-top: 2px solid var(--border-color);
}
.history-title {
color: var(--secondary-color);
font-size: 1.5em;
margin-bottom: 1.5rem;
text-align: center;
}
.historyBox {
background: white;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 6px;
border: 1px solid var(--border-color);
}
@media (prefers-color-scheme: dark) {
:root {
--background-color: #0f172a;
--text-color: #e2e8f0;
--border-color: #1e293b;
}
.container, .QaBox {
background: #1e293b;
}
}
"""
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
"""生成完整的HTML文档
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
Returns:
str: 完整的HTML文档字符串
"""
chat_content = f'''
<div class="QaBox">
<div class="Question">{question}</div>
<div class="Answer markdown-body" id="answer-content">{answer}</div>
</div>
'''
references_content = ""
if ranked_papers:
references_content = '<div class="history-section"><h2 class="history-title">参考文献</h2>'
for idx, paper in enumerate(ranked_papers, 1):
authors = ', '.join(paper.authors)
# 构建引用信息
citations_info = f"被引用次数:{paper.citations}" if paper.citations is not None else "引用信息未知"
# 构建下载链接
download_links = []
if paper.doi:
# 检查是否是arXiv链接
if 'arxiv.org' in paper.doi:
# 如果DOI中包含完整的arXiv URL,直接使用
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
download_links.append(f'<a href="{arxiv_url}">arXiv链接</a>')
# 提取arXiv ID并添加PDF链接
arxiv_id = arxiv_url.split('abs/')[-1].split('v')[0]
download_links.append(f'<a href="https://arxiv.org/pdf/{arxiv_id}.pdf">PDF下载</a>')
else:
# 非arXiv的DOI使用标准格式
download_links.append(f'<a href="https://doi.org/{paper.doi}">DOI: {paper.doi}</a>')
if hasattr(paper, 'url') and paper.url and 'arxiv.org' not in str(paper.url):
# 只有当URL不是arXiv链接时才添加
download_links.append(f'<a href="{paper.url}">原文链接</a>')
download_section = ' | '.join(download_links) if download_links else "无直接下载链接"
# 构建来源信息
source_info = []
if paper.venue_type:
source_info.append(f"类型:{paper.venue_type}")
if paper.venue_name:
source_info.append(f"来源:{paper.venue_name}")
# 添加期刊指标信息
if hasattr(paper, 'if_factor') and paper.if_factor:
source_info.append(f"<span class='journal-metric'>IF: {paper.if_factor}</span>")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
source_info.append(f"<span class='journal-metric'>JCR分区: {paper.jcr_division}</span>")
if hasattr(paper, 'cas_division') and paper.cas_division:
source_info.append(f"<span class='journal-metric'>中科院分区: {paper.cas_division}</span>")
if hasattr(paper, 'venue_info') and paper.venue_info:
if paper.venue_info.get('journal_ref'):
source_info.append(f"期刊参考:{paper.venue_info['journal_ref']}")
if paper.venue_info.get('publisher'):
source_info.append(f"出版商:{paper.venue_info['publisher']}")
source_section = ' | '.join(source_info) if source_info else ""
# 构建标准引用格式
standard_citation = f"[{idx}] "
# 添加作者最多3个,超过则添加et al.
author_list = paper.authors[:3]
if len(paper.authors) > 3:
author_list.append("et al.")
standard_citation += ", ".join(author_list) + ". "
# 添加标题
standard_citation += f"<i>{paper.title}</i>"
# 添加期刊/会议名称
if paper.venue_name:
standard_citation += f". {paper.venue_name}"
# 添加年份
if paper.year:
standard_citation += f", {paper.year}"
# 添加DOI
if paper.doi:
if 'arxiv.org' in paper.doi:
# 如果是arXiv链接,直接使用arXiv URL
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
standard_citation += f". {arxiv_url}"
else:
# 非arXiv的DOI使用标准格式
standard_citation += f". DOI: {paper.doi}"
standard_citation += "."
references_content += f'''
<div class="historyBox">
<div class="entry">
<p class="paper-title"><b>[{idx}]</b> <i>{paper.title}</i></p>
<p class="paper-authors">作者:{authors}</p>
<p class="paper-year">发表年份:{paper.year if paper.year else "未知"}</p>
<p class="paper-citations">{citations_info}</p>
{f'<p class="paper-source">{source_section}</p>' if source_section else ""}
<p class="paper-abstract">摘要:{paper.abstract if paper.abstract else "无摘要"}</p>
<p class="paper-links">链接:{download_section}</p>
<div class="standard-citation">
<p class="citation-title">标准引用格式:</p>
<p class="citation-text">{standard_citation}</p>
<button class="copy-btn" onclick="copyToClipboard(this.previousElementSibling)">复制引用格式</button>
</div>
</div>
</div>
'''
references_content += '</div>'
# 添加新的CSS样式
css_additions = """
.paper-title {
font-size: 1.1em;
margin-bottom: 0.5em;
}
.paper-authors {
color: var(--secondary-color);
margin: 0.3em 0;
}
.paper-year, .paper-citations {
color: var(--text-color);
margin: 0.3em 0;
}
.paper-source {
color: var(--text-color);
font-style: italic;
margin: 0.3em 0;
}
.paper-abstract {
margin: 0.8em 0;
padding: 0.8em;
background: var(--primary-light);
border-radius: 4px;
}
.paper-links {
margin-top: 0.5em;
}
.paper-links a {
color: var(--primary-color);
text-decoration: none;
margin-right: 1em;
}
.paper-links a:hover {
text-decoration: underline;
}
.standard-citation {
margin-top: 1em;
padding: 1em;
background: #f8fafc;
border-radius: 4px;
border: 1px solid var(--border-color);
}
.citation-title {
font-weight: bold;
margin-bottom: 0.5em;
color: var(--secondary-color);
}
.citation-text {
font-family: 'Times New Roman', Times, serif;
line-height: 1.6;
margin-bottom: 0.5em;
padding: 0.5em;
background: white;
border-radius: 4px;
border: 1px solid var(--border-color);
}
.copy-btn {
background: var(--primary-color);
color: white;
border: none;
padding: 0.5em 1em;
border-radius: 4px;
cursor: pointer;
font-size: 0.9em;
transition: background-color 0.2s;
}
.copy-btn:hover {
background: #1e40af;
}
@media (prefers-color-scheme: dark) {
.standard-citation {
background: #1e293b;
}
.citation-text {
background: #0f172a;
}
}
/* 添加期刊指标样式 */
.journal-metric {
display: inline-block;
padding: 0.2em 0.6em;
margin: 0 0.3em;
background: var(--primary-light);
border-radius: 4px;
font-weight: 500;
color: var(--primary-color);
}
@media (prefers-color-scheme: dark) {
.journal-metric {
background: #1e293b;
color: #60a5fa;
}
}
"""
# 修改 js_code 部分,添加 markdown 解析功能
js_code = """
<script>
// 复制功能
function copyToClipboard(element) {
const text = element.innerText;
navigator.clipboard.writeText(text).then(function() {
const btn = element.nextElementSibling;
const originalText = btn.innerText;
btn.innerText = '已复制!';
setTimeout(() => {
btn.innerText = originalText;
}, 2000);
}).catch(function(err) {
console.error('复制失败:', err);
});
}
// Markdown解析
document.addEventListener('DOMContentLoaded', function() {
const answerContent = document.getElementById('answer-content');
if (answerContent) {
const markdown = answerContent.textContent;
answerContent.innerHTML = marked.parse(markdown);
}
});
</script>
"""
# 将新的CSS样式添加到现有样式中
self.css_styles += css_additions
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>学术对话存档</title>
<!-- 添加 marked.js -->
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<!-- 添加 GitHub Markdown CSS -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/sindresorhus/github-markdown-css@4.0.0/github-markdown.min.css">
<style>
{self.css_styles}
/* 添加 Markdown 相关样式 */
.markdown-body {{
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
padding: 1rem;
background: var(--primary-light);
border-radius: 6px;
}}
.markdown-body pre {{
background-color: #f6f8fa;
border-radius: 6px;
padding: 16px;
overflow: auto;
}}
.markdown-body code {{
background-color: rgba(175,184,193,0.2);
border-radius: 6px;
padding: 0.2em 0.4em;
font-size: 85%;
}}
.markdown-body pre code {{
background-color: transparent;
padding: 0;
}}
.markdown-body blockquote {{
border-left: 0.25em solid #d0d7de;
padding: 0 1em;
color: #656d76;
}}
.markdown-body table {{
border-collapse: collapse;
width: 100%;
margin: 1em 0;
}}
.markdown-body table th,
.markdown-body table td {{
border: 1px solid #d0d7de;
padding: 6px 13px;
}}
.markdown-body table tr:nth-child(2n) {{
background-color: #f6f8fa;
}}
@media (prefers-color-scheme: dark) {{
.markdown-body {{
background: #1e293b;
color: #e2e8f0;
}}
.markdown-body pre {{
background-color: #0f172a;
}}
.markdown-body code {{
background-color: rgba(99,110,123,0.4);
}}
.markdown-body blockquote {{
border-left-color: #30363d;
color: #8b949e;
}}
.markdown-body table th,
.markdown-body table td {{
border-color: #30363d;
}}
.markdown-body table tr:nth-child(2n) {{
background-color: #0f172a;
}}
}}
</style>
</head>
<body>
<div class="container">
<h1 class="chat-title">学术对话存档</h1>
<div class="chat-body">
{chat_content}
{references_content}
</div>
</div>
{js_code}
</body>
</html>
"""

查看文件

@@ -0,0 +1,47 @@
class MarkdownFormatter:
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
def __init__(self):
self.content = []
def _add_content(self, text: str):
"""添加正文内容"""
if text:
self.content.append(f"\n{text}\n")
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
"""创建完整的Markdown文档
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
Returns:
str: 生成的Markdown文本
"""
content = []
# 添加问答部分
content.append("## 问题")
content.append(question)
content.append("\n## 回答")
content.append(answer)
# 添加参考文献
if ranked_papers:
content.append("\n## 参考文献")
for idx, paper in enumerate(ranked_papers, 1):
authors = ', '.join(paper.authors[:3])
if len(paper.authors) > 3:
authors += ' et al.'
ref = f"[{idx}] {authors}. *{paper.title}*"
if paper.venue_name:
ref += f". {paper.venue_name}"
if paper.year:
ref += f", {paper.year}"
if paper.doi:
ref += f". [DOI: {paper.doi}](https://doi.org/{paper.doi})"
content.append(ref)
return "\n\n".join(content)

查看文件

@@ -0,0 +1,174 @@
from typing import List
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
import re
class ReferenceFormatter:
"""通用参考文献格式生成器"""
def __init__(self):
pass
def _sanitize_bibtex(self, text: str) -> str:
"""清理BibTeX字符串,处理特殊字符"""
if not text:
return ""
# 替换特殊字符
replacements = {
'&': '\\&',
'%': '\\%',
'$': '\\$',
'#': '\\#',
'_': '\\_',
'{': '\\{',
'}': '\\}',
'~': '\\textasciitilde{}',
'^': '\\textasciicircum{}',
'\\': '\\textbackslash{}',
'<': '\\textless{}',
'>': '\\textgreater{}',
'"': '``',
"'": "'",
'-': '--',
'': '---',
}
for char, replacement in replacements.items():
text = text.replace(char, replacement)
return text
def _generate_cite_key(self, paper: PaperMetadata) -> str:
"""生成引用键
格式: 第一作者姓氏_年份_第一个实词
"""
# 获取第一作者姓氏
first_author = ""
if paper.authors and len(paper.authors) > 0:
first_author = paper.authors[0].split()[-1].lower()
# 获取年份
year = str(paper.year) if paper.year else "0000"
# 从标题中获取第一个实词
title_word = ""
if paper.title:
# 移除特殊字符,分割成单词
words = re.findall(r'\w+', paper.title.lower())
# 过滤掉常见的停用词
stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
for word in words:
if word not in stop_words and len(word) > 2:
title_word = word
break
# 组合cite key
cite_key = f"{first_author}{year}{title_word}"
# 确保cite key只包含合法字符
cite_key = re.sub(r'[^a-z0-9]', '', cite_key.lower())
return cite_key
def _get_entry_type(self, paper: PaperMetadata) -> str:
"""确定BibTeX条目类型"""
if hasattr(paper, 'venue_type') and paper.venue_type:
venue_type = paper.venue_type.lower()
if venue_type == 'conference':
return 'inproceedings'
elif venue_type == 'preprint':
return 'unpublished'
elif venue_type == 'journal':
return 'article'
elif venue_type == 'book':
return 'book'
elif venue_type == 'thesis':
return 'phdthesis'
return 'article' # 默认为期刊文章
def create_document(self, papers: List[PaperMetadata]) -> str:
"""生成BibTeX格式的参考文献文本"""
bibtex_text = "% This file was automatically generated by GPT-Academic\n"
bibtex_text += "% Compatible with: EndNote, Zotero, JabRef, and LaTeX\n\n"
for paper in papers:
entry_type = self._get_entry_type(paper)
cite_key = self._generate_cite_key(paper)
bibtex_text += f"@{entry_type}{{{cite_key},\n"
# 添加标题
if paper.title:
bibtex_text += f" title = {{{self._sanitize_bibtex(paper.title)}}},\n"
# 添加作者
if paper.authors:
# 确保每个作者的姓和名正确分隔
processed_authors = []
for author in paper.authors:
names = author.split()
if len(names) > 1:
# 假设最后一个词是姓,其他的是名
surname = names[-1]
given_names = ' '.join(names[:-1])
processed_authors.append(f"{surname}, {given_names}")
else:
processed_authors.append(author)
authors = " and ".join([self._sanitize_bibtex(author) for author in processed_authors])
bibtex_text += f" author = {{{authors}}},\n"
# 添加年份
if paper.year:
bibtex_text += f" year = {{{paper.year}}},\n"
# 添加期刊/会议名称
if hasattr(paper, 'venue_name') and paper.venue_name:
if entry_type == 'inproceedings':
bibtex_text += f" booktitle = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
elif entry_type == 'article':
bibtex_text += f" journal = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
# 添加期刊相关信息
if hasattr(paper, 'venue_info'):
if 'volume' in paper.venue_info:
bibtex_text += f" volume = {{{paper.venue_info['volume']}}},\n"
if 'number' in paper.venue_info:
bibtex_text += f" number = {{{paper.venue_info['number']}}},\n"
if 'pages' in paper.venue_info:
bibtex_text += f" pages = {{{paper.venue_info['pages']}}},\n"
elif paper.venue:
venue_field = "booktitle" if entry_type == "inproceedings" else "journal"
bibtex_text += f" {venue_field} = {{{self._sanitize_bibtex(paper.venue)}}},\n"
# 添加DOI
if paper.doi:
bibtex_text += f" doi = {{{paper.doi}}},\n"
# 添加URL
if paper.url:
bibtex_text += f" url = {{{paper.url}}},\n"
elif paper.doi:
bibtex_text += f" url = {{https://doi.org/{paper.doi}}},\n"
# 添加摘要
if paper.abstract:
bibtex_text += f" abstract = {{{self._sanitize_bibtex(paper.abstract)}}},\n"
# 添加机构
if hasattr(paper, 'institutions') and paper.institutions:
institutions = " and ".join([self._sanitize_bibtex(inst) for inst in paper.institutions])
bibtex_text += f" institution = {{{institutions}}},\n"
# 添加月份
if hasattr(paper, 'month'):
bibtex_text += f" month = {{{paper.month}}},\n"
# 添加注释字段
if hasattr(paper, 'note'):
bibtex_text += f" note = {{{self._sanitize_bibtex(paper.note)}}},\n"
# 移除最后一个逗号并关闭条目
bibtex_text = bibtex_text.rstrip(',\n') + "\n}\n\n"
return bibtex_text

查看文件

@@ -0,0 +1,138 @@
from docx2pdf import convert
import os
import platform
from typing import Union
from pathlib import Path
from datetime import datetime
class WordToPdfConverter:
"""Word文档转PDF转换器"""
@staticmethod
def _replace_docx_in_filename(filename: Union[str, Path]) -> Path:
"""
将文件名中的'docx'替换为'pdf'
例如: 'docx_test.pdf' -> 'pdf_test.pdf'
"""
path = Path(filename)
new_name = path.stem.replace('docx', 'pdf')
return path.parent / f"{new_name}{path.suffix}"
@staticmethod
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
"""
将Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径
异常:
如果转换失败,将抛出相应异常
"""
try:
word_path = Path(word_path)
if pdf_path is None:
# 创建新的pdf路径,同时替换文件名中的docx
pdf_path = WordToPdfConverter._replace_docx_in_filename(word_path).with_suffix('.pdf')
else:
pdf_path = WordToPdfConverter._replace_docx_in_filename(Path(pdf_path))
# 检查操作系统
if platform.system() == 'Linux':
# Linux系统需要安装libreoffice
if not os.system('which libreoffice') == 0:
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
# 使用libreoffice进行转换
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
# 如果输出路径与默认生成的不同,则重命名
default_pdf = word_path.with_suffix('.pdf')
if default_pdf != pdf_path:
os.rename(default_pdf, pdf_path)
else:
# Windows和MacOS使用 docx2pdf
convert(word_path, pdf_path)
return str(pdf_path)
except Exception as e:
raise Exception(f"转换PDF失败: {str(e)}")
@staticmethod
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
"""
批量转换目录下的所有Word文档
参数:
word_dir: 包含Word文档的目录路径
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
返回:
生成的PDF文件路径列表
"""
word_dir = Path(word_dir)
if pdf_dir:
pdf_dir = Path(pdf_dir)
pdf_dir.mkdir(parents=True, exist_ok=True)
converted_files = []
for word_file in word_dir.glob("*.docx"):
try:
if pdf_dir:
pdf_path = pdf_dir / WordToPdfConverter._replace_docx_in_filename(
word_file.with_suffix('.pdf')
).name
else:
pdf_path = WordToPdfConverter._replace_docx_in_filename(
word_file.with_suffix('.pdf')
)
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
converted_files.append(pdf_file)
except Exception as e:
print(f"转换 {word_file} 失败: {str(e)}")
return converted_files
@staticmethod
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
"""
将docx对象直接转换为PDF
参数:
doc: python-docx的Document对象
output_dir: 可选,输出目录。如果未指定,将使用当前目录
返回:
生成的PDF文件路径
"""
try:
# 设置临时文件路径和输出路径
output_dir = Path(output_dir) if output_dir else Path.cwd()
output_dir.mkdir(parents=True, exist_ok=True)
# 生成临时word文件
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
doc.save(temp_docx)
# 转换为PDF
pdf_path = temp_docx.with_suffix('.pdf')
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
# 删除临时word文件
temp_docx.unlink()
return str(pdf_path)
except Exception as e:
if temp_docx.exists():
temp_docx.unlink()
raise Exception(f"转换PDF失败: {str(e)}")

查看文件

@@ -0,0 +1,246 @@
import re
from docx import Document
from docx.shared import Cm, Pt
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.enum.style import WD_STYLE_TYPE
from docx.oxml.ns import qn
from datetime import datetime
import docx
from docx.oxml import shared
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
class WordFormatter:
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
def __init__(self):
self.doc = Document()
self._setup_document()
self._create_styles()
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 修改页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
header_run = header_para.add_run("GPT-Academic学术对话 (体验地址https://auth.gpt-academic.top/)")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
# 创建问题样式
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
question_style.font.name = '黑体'
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
question_style.font.size = Pt(14) # 调整为14磅
question_style.font.bold = True
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
question_style.paragraph_format.space_after = Pt(6)
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建回答样式
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
answer_style.font.name = '仿宋'
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
answer_style.font.size = Pt(12) # 调整为12磅
answer_style.paragraph_format.space_before = Pt(6)
answer_style.paragraph_format.space_after = Pt(12)
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建标题样式
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
title_style.font.name = '黑体' # 改用黑体
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
title_style.font.size = Pt(22) # 调整为22磅
title_style.font.bold = True
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
title_style.paragraph_format.space_before = Pt(0)
title_style.paragraph_format.space_after = Pt(24)
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
# 添加参考文献样式
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_style.font.name = '宋体'
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
ref_style.paragraph_format.space_before = Pt(3)
ref_style.paragraph_format.space_after = Pt(3)
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
ref_style.paragraph_format.left_indent = Pt(21)
ref_style.paragraph_format.first_line_indent = Pt(-21)
# 添加参考文献标题样式
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_title_style.font.name = '黑体'
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
ref_title_style.font.size = Pt(16) # 参考文献标题与问题同样大小
ref_title_style.font.bold = True
ref_title_style.paragraph_format.space_before = Pt(24) # 增加段前距
ref_title_style.paragraph_format.space_after = Pt(12)
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
def create_document(self, question: str, answer: str, ranked_papers: list = None):
"""写入聊天历史
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
"""
try:
# 添加标题
title_para = self.doc.add_paragraph(style='Title_Custom')
title_run = title_para.add_run('GPT-Academic 对话记录')
# 添加日期
try:
date_para = self.doc.add_paragraph()
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d'))
date_run.font.name = '仿宋'
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
date_run.font.size = Pt(16)
except Exception as e:
print(f"添加日期失败: {str(e)}")
raise
self.doc.add_paragraph() # 添加空行
# 添加问答对话
try:
q_para = self.doc.add_paragraph(style='Question_Style')
q_para.add_run('问题:').bold = True
q_para.add_run(str(question))
a_para = self.doc.add_paragraph(style='Answer_Style')
a_para.add_run('回答:').bold = True
a_para.add_run(convert_markdown_to_word(str(answer)))
except Exception as e:
print(f"添加问答对话失败: {str(e)}")
raise
# 添加参考文献部分
if ranked_papers:
try:
ref_title = self.doc.add_paragraph(style='Reference_Title_Style')
ref_title.add_run("参考文献")
for idx, paper in enumerate(ranked_papers, 1):
try:
ref_para = self.doc.add_paragraph(style='Reference_Style')
ref_para.add_run(f'[{idx}] ').bold = True
# 添加作者
authors = ', '.join(paper.authors[:3])
if len(paper.authors) > 3:
authors += ' et al.'
ref_para.add_run(f'{authors}. ')
# 添加标题
title_run = ref_para.add_run(paper.title)
title_run.italic = True
if hasattr(paper, 'url') and paper.url:
try:
title_run._element.rPr.rStyle = self._create_hyperlink_style()
self._add_hyperlink(ref_para, paper.title, paper.url)
except Exception as e:
print(f"添加超链接失败: {str(e)}")
# 添加期刊/会议信息
if paper.venue_name:
ref_para.add_run(f'. {paper.venue_name}')
# 添加年份
if paper.year:
ref_para.add_run(f', {paper.year}')
# 添加DOI
if paper.doi:
ref_para.add_run('. ')
if "arxiv" in paper.url:
doi_url = paper.doi
else:
doi_url = f'https://doi.org/{paper.doi}'
self._add_hyperlink(ref_para, f'DOI: {paper.doi}', doi_url)
ref_para.add_run('.')
except Exception as e:
print(f"添加第 {idx} 篇参考文献失败: {str(e)}")
continue
except Exception as e:
print(f"添加参考文献部分失败: {str(e)}")
raise
return self.doc
except Exception as e:
print(f"Word文档创建失败: {str(e)}")
import traceback
print(f"详细错误信息: {traceback.format_exc()}")
raise
def _create_hyperlink_style(self):
"""创建超链接样式"""
styles = self.doc.styles
if 'Hyperlink' not in styles:
hyperlink_style = styles.add_style('Hyperlink', WD_STYLE_TYPE.CHARACTER)
# 使用科技蓝 (#0066CC)
hyperlink_style.font.color.rgb = 0x0066CC # 科技蓝
hyperlink_style.font.underline = True
return styles['Hyperlink']
def _add_hyperlink(self, paragraph, text, url):
"""添加超链接到段落"""
# 这个是在XML级别添加超链接
part = paragraph.part
r_id = part.relate_to(url, docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, is_external=True)
# 创建超链接XML元素
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
hyperlink.set(docx.oxml.shared.qn('r:id'), r_id)
# 创建文本运行
new_run = docx.oxml.shared.OxmlElement('w:r')
rPr = docx.oxml.shared.OxmlElement('w:rPr')
# 应用超链接样式
rStyle = docx.oxml.shared.OxmlElement('w:rStyle')
rStyle.set(docx.oxml.shared.qn('w:val'), 'Hyperlink')
rPr.append(rStyle)
# 添加文本
t = docx.oxml.shared.OxmlElement('w:t')
t.text = text
new_run.append(rPr)
new_run.append(t)
hyperlink.append(new_run)
# 将超链接添加到段落
paragraph._p.append(hyperlink)

查看文件

@@ -0,0 +1,279 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import json
from tqdm import tqdm
import random
class AdsabsSource(DataSource):
"""ADS (Astrophysics Data System) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.adsabs.harvard.edu/v1"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
"""发送HTTP请求
Args:
url: 请求URL
method: HTTP方法
data: POST请求数据
Returns:
响应内容
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
if method == "GET":
async with session.get(url) as response:
if response.status == 200:
return await response.json()
elif method == "POST":
async with session.post(url, json=data) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
def _parse_paper(self, doc: dict) -> PaperMetadata:
"""解析ADS文献数据
Args:
doc: ADS文献数据
Returns:
解析后的论文数据
"""
try:
return PaperMetadata(
title=doc.get('title', [''])[0] if doc.get('title') else '',
authors=doc.get('author', []),
abstract=doc.get('abstract', ''),
year=doc.get('year'),
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
citations=doc.get('citation_count'),
venue=doc.get('pub', ''),
institutions=doc.get('aff', []),
venue_type="journal",
venue_name=doc.get('pub', ''),
venue_info={
'volume': doc.get('volume'),
'issue': doc.get('issue'),
'pub_date': doc.get('pubdate', '')
},
source='adsabs'
)
except Exception as e:
print(f"解析文章时发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# 构建查询
if start_year:
query = f"{query} year:{start_year}-"
# 设置排序
sort_mapping = {
'relevance': 'score desc',
'date': 'date desc',
'citations': 'citation_count desc'
}
sort = sort_mapping.get(sort_by, 'score desc')
# 构建搜索请求
search_url = f"{self.base_url}/search/query"
params = {
"q": query,
"rows": limit,
"sort": sort,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
# 解析结果
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _build_query_string(self, params: dict) -> str:
"""构建查询字符串"""
return "&".join([f"{k}={v}" for k, v in params.items()])
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
"""获取指定bibcode的论文详情"""
search_url = f"{self.base_url}/search/query"
params = {
"q": f"identifier:{bibcode}",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if response and 'response' in response and response['response']['docs']:
return self._parse_paper(response['response']['docs'][0])
return None
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
"""获取相关论文"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
"rows": limit,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"author:\"{author}\""
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"pub:\"{journal}\""
return await self.search(query, limit=limit, start_year=start_year)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文"""
query = f"entdate:[NOW-{days}DAYS TO NOW]"
return await self.search(query, limit=limit, sort_by="date")
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"references(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def example_usage():
"""AdsabsSource使用示例"""
ads = AdsabsSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索黑洞相关论文 ===")
papers = await ads.search("black hole", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
# 其他示例...
except Exception as e:
print(f"发生错误: {str(e)}")
if __name__ == "__main__":
# python -m crazy_functions.review_fns.data_sources.adsabs_source
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,636 @@
import arxiv
from typing import List, Optional, Union, Literal, Dict
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.request import urlretrieve
import feedparser
from tqdm import tqdm
class ArxivSource(DataSource):
"""arXiv API实现"""
CATEGORIES = {
# 物理学
"Physics": {
"astro-ph": "天体物理学",
"cond-mat": "凝聚态物理",
"gr-qc": "广义相对论与量子宇宙学",
"hep-ex": "高能物理实验",
"hep-lat": "格点场论",
"hep-ph": "高能物理理论",
"hep-th": "高能物理理论",
"math-ph": "数学物理",
"nlin": "非线性科学",
"nucl-ex": "核实验",
"nucl-th": "核理论",
"physics": "物理学",
"quant-ph": "量子物理",
},
# 数学
"Mathematics": {
"math.AG": "代数几何",
"math.AT": "代数拓扑",
"math.AP": "分析与偏微分方程",
"math.CT": "范畴论",
"math.CA": "复分析",
"math.CO": "组合数学",
"math.AC": "交换代数",
"math.CV": "复变函数",
"math.DG": "微分几何",
"math.DS": "动力系统",
"math.FA": "泛函分析",
"math.GM": "一般数学",
"math.GN": "一般拓扑",
"math.GT": "几何拓扑",
"math.GR": "群论",
"math.HO": "数学史与数学概述",
"math.IT": "信息论",
"math.KT": "K理论与同调",
"math.LO": "逻辑",
"math.MP": "数学物理",
"math.MG": "度量几何",
"math.NT": "数论",
"math.NA": "数值分析",
"math.OA": "算子代数",
"math.OC": "最优化与控制",
"math.PR": "概率论",
"math.QA": "量子代数",
"math.RT": "表示论",
"math.RA": "环与代数",
"math.SP": "谱理论",
"math.ST": "统计理论",
"math.SG": "辛几何",
},
# 计算机科学
"Computer Science": {
"cs.AI": "人工智能",
"cs.CL": "计算语言学",
"cs.CC": "计算复杂性",
"cs.CE": "计算工程",
"cs.CG": "计算几何",
"cs.GT": "计算机博弈论",
"cs.CV": "计算机视觉",
"cs.CY": "计算机与社会",
"cs.CR": "密码学与安全",
"cs.DS": "数据结构与算法",
"cs.DB": "数据库",
"cs.DL": "数字图书馆",
"cs.DM": "离散数学",
"cs.DC": "分布式计算",
"cs.ET": "新兴技术",
"cs.FL": "形式语言与自动机理论",
"cs.GL": "一般文献",
"cs.GR": "图形学",
"cs.AR": "硬件架构",
"cs.HC": "人机交互",
"cs.IR": "信息检索",
"cs.IT": "信息论",
"cs.LG": "机器学习",
"cs.LO": "逻辑与计算机",
"cs.MS": "数学软件",
"cs.MA": "多智能体系统",
"cs.MM": "多媒体",
"cs.NI": "网络与互联网架构",
"cs.NE": "神经与进化计算",
"cs.NA": "数值分析",
"cs.OS": "操作系统",
"cs.OH": "其他计算机科学",
"cs.PF": "性能评估",
"cs.PL": "编程语言",
"cs.RO": "机器人学",
"cs.SI": "社会与信息网络",
"cs.SE": "软件工程",
"cs.SD": "声音",
"cs.SC": "符号计算",
"cs.SY": "系统与控制",
},
# 定量生物学
"Quantitative Biology": {
"q-bio.BM": "生物分子",
"q-bio.CB": "细胞行为",
"q-bio.GN": "基因组学",
"q-bio.MN": "分子网络",
"q-bio.NC": "神经计算",
"q-bio.OT": "其他",
"q-bio.PE": "群体与进化",
"q-bio.QM": "定量方法",
"q-bio.SC": "亚细胞过程",
"q-bio.TO": "组织与器官",
},
# 定量金融
"Quantitative Finance": {
"q-fin.CP": "计算金融",
"q-fin.EC": "经济学",
"q-fin.GN": "一般金融",
"q-fin.MF": "数学金融",
"q-fin.PM": "投资组合管理",
"q-fin.PR": "定价理论",
"q-fin.RM": "风险管理",
"q-fin.ST": "统计金融",
"q-fin.TR": "交易与市场微观结构",
},
# 统计学
"Statistics": {
"stat.AP": "应用统计",
"stat.CO": "计算统计",
"stat.ML": "机器学习",
"stat.ME": "方法论",
"stat.OT": "其他统计",
"stat.TH": "统计理论",
},
# 电气工程与系统科学
"Electrical Engineering and Systems Science": {
"eess.AS": "音频与语音处理",
"eess.IV": "图像与视频处理",
"eess.SP": "信号处理",
"eess.SY": "系统与控制",
},
# 经济学
"Economics": {
"econ.EM": "计量经济学",
"econ.GN": "一般经济学",
"econ.TH": "理论经济学",
}
}
def __init__(self):
"""初始化"""
self._initialize() # 调用初始化方法
# 修改排序选项映射
self.sort_options = {
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
}
self.sort_order_options = {
'ascending': arxiv.SortOrder.Ascending,
'descending': arxiv.SortOrder.Descending
}
self.default_sort = 'lastUpdatedDate'
self.default_order = 'descending'
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.client = arxiv.Client()
async def search(
self,
query: str,
limit: int = 10,
sort_by: str = None,
sort_order: str = None,
start_year: int = None
) -> List[Dict]:
"""搜索论文"""
try:
# 使用默认排序如果提供的排序选项无效
if not sort_by or sort_by not in self.sort_options:
sort_by = self.default_sort
# 使用默认排序顺序如果提供的顺序无效
if not sort_order or sort_order not in self.sort_order_options:
sort_order = self.default_order
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
search = arxiv.Search(
query=query,
max_results=limit,
sort_by=self.sort_options[sort_by],
sort_order=self.sort_order_options[sort_order]
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
"""按ID搜索论文
Args:
paper_id: 单个arXiv ID或ID列表,例如'2005.14165' 或 ['2005.14165', '2103.14030']
"""
if isinstance(paper_id, str):
paper_id = [paper_id]
search = arxiv.Search(
id_list=paper_id,
max_results=len(paper_id)
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
async def search_by_category(
self,
category: str,
limit: int = 100,
sort_by: str = 'relevance',
sort_order: str = 'descending',
start_year: int = None
) -> List[PaperMetadata]:
"""按类别搜索论文"""
query = f"cat:{category}"
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = 'relevance',
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " AND ".join([f"au:\"{author}\"" for author in authors])
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by
)
async def search_by_date_range(
self,
start_date: datetime,
end_date: datetime,
limit: int = 100,
sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
sort_order: Literal['ascending', 'descending'] = 'descending'
) -> List[PaperMetadata]:
"""按日期范围搜索论文"""
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
return await self.search(
query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文PDF
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
# 清理标题中的非法字符
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.pdf"
filepath = os.path.join(dirpath, filename)
urlretrieve(paper.url, filepath)
return filepath
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文源文件通常是LaTeX源码
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.tar.gz"
filepath = os.path.join(dirpath, filename)
source_url = paper.url.replace("/pdf/", "/src/")
urlretrieve(source_url, filepath)
return filepath
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
# arXiv API不直接提供引用信息
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
# arXiv API不直接提供引用信息
return []
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详情
Args:
paper_id: arXiv ID 或 DOI
Returns:
论文详细信息,如果未找到返回 None
"""
try:
# 如果是完整的 arXiv URL,提取 ID
if "arxiv.org" in paper_id:
paper_id = paper_id.split("/")[-1]
# 如果是 DOI 格式且是 arXiv 论文,提取 ID
elif paper_id.startswith("10.48550/arXiv."):
paper_id = paper_id.split(".")[-1]
papers = await self.search_by_id(paper_id)
return papers[0] if papers else None
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
"""解析arXiv API返回的数据"""
# 解析主要类别和次要类别
primary_category = result.primary_category
categories = result.categories
# 构建venue信息
venue_info = {
'primary_category': primary_category,
'categories': categories,
'comments': getattr(result, 'comment', None),
'journal_ref': getattr(result, 'journal_ref', None)
}
return PaperMetadata(
title=result.title,
authors=[author.name for author in result.authors],
abstract=result.summary,
year=result.published.year,
doi=result.entry_id,
url=result.pdf_url,
citations=None,
venue=f"arXiv:{primary_category}",
institutions=[],
venue_type='preprint', # arXiv论文都是预印本
venue_name='arXiv',
venue_info=venue_info,
source='arxiv' # 添加来源标记
)
async def get_latest_papers(
self,
category: str,
debug: bool = False,
batch_size: int = 50
) -> List[PaperMetadata]:
"""获取指定类别的最新论文
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
Args:
category: arXiv类别,例如
- 整个领域: 'cs'
- 具体方向: 'cs.AI'
- 多个类别: 'cs.AI+q-bio.NC'
debug: 是否为调试模式,如果为True则只返回5篇最新论文
batch_size: 批量获取论文的数量,默认50
Returns:
论文列表
Raises:
ValueError: 如果类别无效
"""
try:
# 处理类别格式
# 1. 转换为小写
# 2. 确保多个类别之间使用+连接
category = category.lower().replace(' ', '+')
# 构建RSS feed URL
feed_url = f"https://rss.arxiv.org/rss/{category}"
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
feed = feedparser.parse(feed_url)
# 检查feed是否有效
if hasattr(feed, 'status') and feed.status != 200:
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
if not feed.entries:
print(f"警告未在feed中找到任何条目") # 添加调试信息
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
if debug:
# 调试模式只获取5篇最新论文
search = arxiv.Search(
query=f'cat:{category}',
sort_by=arxiv.SortCriterion.SubmittedDate,
sort_order=arxiv.SortOrder.Descending,
max_results=5
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
# 正常模式:获取所有新论文
# 从RSS条目中提取arXiv ID
paper_ids = []
for entry in feed.entries:
try:
# RSS链接格式可能是以下几种
# - http://arxiv.org/abs/2403.xxxxx
# - http://arxiv.org/pdf/2403.xxxxx
# - https://arxiv.org/abs/2403.xxxxx
link = entry.link or entry.id
arxiv_id = link.split('/')[-1].replace('.pdf', '')
if arxiv_id:
paper_ids.append(arxiv_id)
except Exception as e:
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
continue
if not paper_ids:
print("未能从feed中提取到任何论文ID") # 添加调试信息
return []
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
# 批量获取论文详情
papers = []
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
for i in range(0, len(paper_ids), batch_size):
batch_ids = paper_ids[i:i + batch_size]
search = arxiv.Search(
id_list=batch_ids,
max_results=len(batch_ids)
)
batch_results = list(self.client.results(search))
papers.extend([self._parse_paper_data(result) for result in batch_results])
pbar.update(len(batch_results))
return papers
except Exception as e:
print(f"获取最新论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc()) # 添加完整的错误追踪
return []
async def example_usage():
"""ArxivSource使用示例"""
arxiv_source = ArxivSource()
try:
# 示例1基本搜索,使用不同的排序方式
# print("\n=== 示例1搜索最新的机器学习论文按提交时间排序===")
# papers = await arxiv_source.search(
# "ti:\"machine learning\"",
# limit=3,
# sort_by='submitted',
# sort_order='descending'
# )
# print(f"找到 {len(papers)} 篇论文")
# for i, paper in enumerate(papers, 1):
# print(f"\n--- 论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# print(f"arXiv ID: {paper.doi}")
# print(f"PDF URL: {paper.url}")
# if paper.abstract:
# print(f"\n摘要:")
# print(paper.abstract)
# print(f"发表venue: {paper.venue}")
# # 示例2按ID搜索
# print("\n=== 示例2按ID搜索论文 ===")
# paper_id = "2005.14165" # GPT-3论文
# papers = await arxiv_source.search_by_id(paper_id)
# if papers:
# paper = papers[0]
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# # 示例3按类别搜索
# print("\n=== 示例3搜索人工智能领域最新论文 ===")
# ai_papers = await arxiv_source.search_by_category(
# "cs.AI",
# limit=2,
# sort_by='updated',
# sort_order='descending'
# )
# for i, paper in enumerate(ai_papers, 1):
# print(f"\n--- AI论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表venue: {paper.venue}")
# # 示例4按作者搜索
# print("\n=== 示例4搜索特定作者的论文 ===")
# author_papers = await arxiv_source.search_by_authors(
# ["Bengio"],
# limit=2,
# sort_by='relevance'
# )
# for i, paper in enumerate(author_papers, 1):
# print(f"\n--- Bengio的论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表venue: {paper.venue}")
# # 示例5按日期范围搜索
# print("\n=== 示例5搜索特定日期范围的论文 ===")
# from datetime import datetime, timedelta
# end_date = datetime.now()
# start_date = end_date - timedelta(days=7) # 最近一周
# recent_papers = await arxiv_source.search_by_date_range(
# start_date,
# end_date,
# limit=2
# )
# for i, paper in enumerate(recent_papers, 1):
# print(f"\n--- 最近论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# # 示例6下载PDF
# print("\n=== 示例6下载论文PDF ===")
# if papers: # 使用之前搜索到的GPT-3论文
# pdf_path = await arxiv_source.download_pdf(paper_id)
# print(f"PDF已下载到: {pdf_path}")
# # 示例7下载源文件
# print("\n=== 示例7下载论文源文件 ===")
# if papers:
# source_path = await arxiv_source.download_source(paper_id)
# print(f"源文件已下载到: {source_path}")
# 示例6获取最新论文
print("\n=== 示例8获取最新论文 ===")
# 获取CS.AI领域的最新论文
print("\n--- 获取AI领域最新论文 ---")
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
for i, paper in enumerate(ai_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 获取整个计算机科学领域的最新论文
print("\n--- 获取整个CS领域最新论文 ---")
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
for i, paper in enumerate(cs_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 获取多个类别的最新论文
print("\n--- 获取AI和机器学习领域最新论文 ---")
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
for i, paper in enumerate(multi_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,102 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Optional
from dataclasses import dataclass
class PaperMetadata:
"""论文元数据"""
def __init__(
self,
title: str,
authors: List[str],
abstract: str,
year: int,
doi: str = None,
url: str = None,
citations: int = None,
venue: str = None,
institutions: List[str] = None,
venue_type: str = None, # 来源类型(journal/conference/preprint等)
venue_name: str = None, # 具体的期刊/会议名称
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
source: str = None # 新增: 论文来源标记
):
self.title = title
self.authors = authors
self.abstract = abstract
self.year = year
self.doi = doi
self.url = url
self.citations = citations
self.venue = venue
self.institutions = institutions or []
self.venue_type = venue_type # 新增
self.venue_name = venue_name # 新增
self.venue_info = venue_info or {} # 新增
self.source = source # 新增: 存储论文来源
# 新增影响因子和分区信息,初始化为None
self._if_factor = None
self._cas_division = None
self._jcr_division = None
@property
def if_factor(self) -> Optional[float]:
"""获取影响因子"""
return self._if_factor
@if_factor.setter
def if_factor(self, value: float):
"""设置影响因子"""
self._if_factor = value
@property
def cas_division(self) -> Optional[str]:
"""获取中科院分区"""
return self._cas_division
@cas_division.setter
def cas_division(self, value: str):
"""设置中科院分区"""
self._cas_division = value
@property
def jcr_division(self) -> Optional[str]:
"""获取JCR分区"""
return self._jcr_division
@jcr_division.setter
def jcr_division(self, value: str):
"""设置JCR分区"""
self._jcr_division = value
class DataSource(ABC):
"""数据源基类"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key
self._initialize()
@abstractmethod
def _initialize(self) -> None:
"""初始化数据源"""
pass
@abstractmethod
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
"""搜索论文"""
pass
@abstractmethod
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
"""获取论文详细信息"""
pass
@abstractmethod
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
pass
@abstractmethod
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
pass

文件差异因一行或多行过长而隐藏

查看文件

@@ -0,0 +1,400 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class CrossrefSource(DataSource):
"""Crossref API实现"""
CONTACT_EMAILS = [
"gpt_abc_academic@163.com",
"gpt_abc_newapi@163.com",
"gpt_abc_academic_pwd@163.com"
]
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.crossref.org"
# 随机选择一个邮箱
contact_email = random.choice(self.CONTACT_EMAILS)
self.headers = {
"Accept": "application/json",
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
}
if self.api_key:
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = None,
sort_order: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序字段
sort_order: 排序顺序
start_year: 起始年份
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
# 请求更多的结果以补偿可能被过滤掉的文章
adjusted_limit = min(limit * 3, 1000) # 设置上限以避免请求过多
params = {
"query": query,
"rows": adjusted_limit,
"select": (
"DOI,title,author,published-print,abstract,reference,"
"container-title,is-referenced-by-count,type,"
"publisher,ISSN,ISBN,issue,volume,page"
)
}
# 添加年份过滤
if start_year:
params["filter"] = f"from-pub-date:{start_year}"
# 添加排序
if sort_by:
params["sort"] = sort_by
if sort_order:
params["order"] = sort_order
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
data = await response.json()
items = data.get("message", {}).get("items", [])
if not items:
print(f"未找到相关论文")
return []
# 过滤掉没有摘要的文章
papers = []
filtered_count = 0
for work in items:
paper = self._parse_work(work)
if paper.abstract and paper.abstract.strip():
papers.append(paper)
if len(papers) >= limit: # 达到原始请求的限制后停止
break
else:
filtered_count += 1
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
print(f"返回 {len(papers)} 篇包含摘要的论文")
return papers
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={
"select": (
"DOI,title,author,published-print,abstract,reference,"
"container-title,is-referenced-by-count,type,"
"publisher,ISSN,ISBN,issue,volume,page"
)
}
) as response:
if response.status != 200:
print(f"获取论文详情失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return None
try:
data = await response.json()
return self._parse_work(data.get("message", {}))
except Exception as e:
print(f"解析论文详情时发生错误: {str(e)}")
return None
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={"select": "reference"}
) as response:
if response.status != 200:
print(f"获取参考文献失败: HTTP {response.status}")
return []
try:
data = await response.json()
# 确保我们正确处理返回的数据结构
if not isinstance(data, dict):
print(f"API返回了意外的数据格式: {type(data)}")
return []
references = data.get("message", {}).get("reference", [])
if not references:
print(f"未找到参考文献")
return []
return [
PaperMetadata(
title=ref.get("article-title", ""),
authors=[ref.get("author", "")],
year=ref.get("year"),
doi=ref.get("DOI"),
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
abstract="",
citations=None,
venue=ref.get("journal-title", ""),
institutions=[]
)
for ref in references
]
except Exception as e:
print(f"解析参考文献数据时发生错误: {str(e)}")
return []
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works",
params={
"filter": f"reference.DOI:{doi}",
"select": "DOI,title,author,published-print,abstract"
}
) as response:
if response.status != 200:
print(f"获取引用信息失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
try:
data = await response.json()
# 检查返回的数据结构
if isinstance(data, dict):
items = data.get("message", {}).get("items", [])
return [self._parse_work(work) for work in items]
else:
print(f"API返回了意外的数据格式: {type(data)}")
return []
except Exception as e:
print(f"解析引用数据时发生错误: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析Crossref返回的数据"""
# 获取摘要 - 处理可能的不同格式
abstract = ""
if isinstance(work.get("abstract"), str):
abstract = work.get("abstract", "")
elif isinstance(work.get("abstract"), dict):
abstract = work.get("abstract", {}).get("value", "")
if not abstract:
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
# 获取机构信息
institutions = []
for author in work.get("author", []):
if "affiliation" in author:
for affiliation in author["affiliation"]:
if "name" in affiliation and affiliation["name"] not in institutions:
institutions.append(affiliation["name"])
# 获取venue信息
venue_name = work.get("container-title", [None])[0]
venue_type = work.get("type", "unknown") # 文献类型
venue_info = {
"publisher": work.get("publisher"),
"issn": work.get("ISSN", []),
"isbn": work.get("ISBN", []),
"issue": work.get("issue"),
"volume": work.get("volume"),
"page": work.get("page")
}
return PaperMetadata(
title=work.get("title", [None])[0] or "",
authors=[
author.get("given", "") + " " + author.get("family", "")
for author in work.get("author", [])
],
institutions=institutions, # 添加机构信息
abstract=abstract,
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
doi=work.get("DOI"),
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
citations=work.get("is-referenced-by-count"),
venue=venue_name,
venue_type=venue_type, # 添加venue类型
venue_name=venue_name, # 添加venue名称
venue_info=venue_info, # 添加venue详细信息
source='crossref' # 添加来源标记
)
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " ".join([f"author:\"{author}\"" for author in authors])
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
start_year=start_year
)
async def search_by_date_range(
self,
start_date: datetime,
end_date: datetime,
limit: int = 100,
sort_by: str = None,
sort_order: str = None
) -> List[PaperMetadata]:
"""按日期范围搜索论文"""
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def example_usage():
"""CrossrefSource使用示例"""
crossref = CrossrefSource(api_key=None)
try:
# 示例1基本搜索,使用不同的排序方式
print("\n=== 示例1搜索最新的机器学习论文 ===")
papers = await crossref.search(
query="machine learning",
limit=3,
sort_by="published",
sort_order="desc",
start_year=2023
)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
if paper.institutions:
print(f"机构: {', '.join(paper.institutions)}")
print(f"引用次数: {paper.citations}")
print(f"发表venue: {paper.venue}")
print(f"venue类型: {paper.venue_type}")
if paper.venue_info:
print("Venue详细信息:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
# 示例2按DOI获取论文详情
print("\n=== 示例2获取特定论文详情 ===")
# 使用BERT论文的DOI
doi = "10.18653/v1/N19-1423"
paper = await crossref.get_paper_details(doi)
if paper:
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
print(f"引用次数: {paper.citations}")
# 示例3按作者搜索
print("\n=== 示例3搜索特定作者的论文 ===")
author_papers = await crossref.search_by_authors(
authors=["Yoshua Bengio"],
limit=3,
sort_by="published",
start_year=2020
)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- {i}. {paper.title} ---")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
# 示例4按日期范围搜索
print("\n=== 示例4搜索特定日期范围的论文 ===")
from datetime import datetime, timedelta
end_date = datetime.now()
start_date = end_date - timedelta(days=30) # 最近一个月
recent_papers = await crossref.search_by_date_range(
start_date=start_date,
end_date=end_date,
limit=3,
sort_by="published",
sort_order="desc"
)
for i, paper in enumerate(recent_papers, 1):
print(f"\n--- 最近发表的论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
# 示例5获取论文引用信息
print("\n=== 示例5获取论文引用信息 ===")
if paper: # 使用之前获取的BERT论文
print("\n获取引用该论文的文献:")
citations = await crossref.get_citations(paper.doi)
for i, citing_paper in enumerate(citations[:3], 1):
print(f"\n--- 引用论文 {i} ---")
print(f"标题: {citing_paper.title}")
print(f"作者: {', '.join(citing_paper.authors)}")
print(f"发表年份: {citing_paper.year}")
print("\n获取该论文引用的参考文献:")
references = await crossref.get_references(paper.doi)
for i, ref_paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {ref_paper.title}")
print(f"作者: {', '.join(ref_paper.authors)}")
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
# 示例6展示venue信息的使用
print("\n=== 示例6展示期刊/会议详细信息 ===")
if papers:
paper = papers[0]
print(f"文献类型: {paper.venue_type}")
print(f"发表venue: {paper.venue_name}")
if paper.venue_info:
print("Venue详细信息:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,449 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import json
from tqdm import tqdm
import random
class ElsevierSource(DataSource):
"""Elsevier (Scopus) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS)
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.elsevier.com/content"
self.headers = {
"X-ELS-APIKey": self.api_key,
"Accept": "application/json",
"Content-Type": "application/json",
# 添加更多必要的头部信息
"X-ELS-Insttoken": "", # 如果有机构令牌
}
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
"""发送HTTP请求
Args:
url: 请求URL
params: 查询参数
Returns:
JSON响应
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 200:
return await response.json()
else:
# 添加更详细的错误信息
error_text = await response.text()
print(f"请求失败: {response.status}")
print(f"错误详情: {error_text}")
if response.status == 401:
print(f"使用的API密钥: {self.api_key}")
# 尝试切换到另一个API密钥
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
print(f"尝试切换到新的API密钥: {new_key}")
self.api_key = new_key
self.headers["X-ELS-APIKey"] = new_key
# 重试请求
return await self._make_request(url, params)
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
params = {
"query": query,
"count": min(limit, 100),
"view": "STANDARD",
# 移除dc:description字段,因为它在STANDARD视图中不可用
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
}
# 添加年份过滤
if start_year:
params["date"] = f"{start_year}-present"
# 添加排序
if sort_by == "date":
params["sort"] = "-coverDate"
elif sort_by == "cited":
params["sort"] = "-citedby-count"
# 发送搜索请求
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
# 解析搜索结果
entries = response["search-results"].get("entry", [])
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
# 尝试为每篇论文获取摘要
for paper in papers:
if paper.doi:
paper.abstract = await self.fetch_abstract(paper.doi) or ""
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
"""解析Scopus API返回的条目"""
try:
# 获取作者列表
authors = []
creator = entry.get("dc:creator")
if creator:
authors = [creator]
# 获取发表年份
year = None
if "prism:coverDate" in entry:
try:
year = int(entry["prism:coverDate"][:4])
except:
pass
# 简化venue信息
venue_info = {
'source_id': entry.get("source-id"),
'issn': entry.get("prism:issn")
}
return PaperMetadata(
title=entry.get("dc:title", ""),
authors=authors,
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
year=year,
doi=entry.get("prism:doi"),
url=entry.get("prism:url"),
citations=int(entry.get("citedby-count", 0)),
venue=entry.get("prism:publicationName"),
institutions=[], # 移除机构信息
venue_type="",
venue_name=entry.get("prism:publicationName"),
venue_info=venue_info
)
except Exception as e:
print(f"解析条目时发生错误: {str(e)}")
return None
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
try:
params = {
"query": f"REF({doi})",
"count": min(limit, 100),
"view": "STANDARD"
}
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
entries = response["search-results"].get("entry", [])
return [self._parse_entry(entry) for entry in entries]
except Exception as e:
print(f"获取引用文献时发生错误: {str(e)}")
return []
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
try:
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}/references",
params={"view": "STANDARD"}
)
if not response or "references" not in response:
return []
references = response["references"].get("reference", [])
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
return papers
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
"""解析参考文献数据"""
try:
authors = []
if "author-list" in ref:
author_list = ref["author-list"].get("author", [])
if isinstance(author_list, list):
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
for author in author_list]
else:
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
year = None
if "prism:coverDate" in ref:
try:
year = int(ref["prism:coverDate"][:4])
except:
pass
return PaperMetadata(
title=ref.get("ce:title", ""),
authors=authors,
abstract="", # 参考文献通常不包含摘要
year=year,
doi=ref.get("prism:doi"),
url=None,
citations=None,
venue=ref.get("prism:publicationName"),
institutions=[],
venue_type="unknown",
venue_name=ref.get("prism:publicationName"),
venue_info={}
)
except Exception as e:
print(f"解析参考文献时发生错误: {str(e)}")
return None
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"AUTHOR-NAME({author})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_affiliation(
self,
affiliation: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按机构搜索论文"""
query = f"AF-ID({affiliation})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_venue(
self,
venue: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊/会议搜索论文"""
query = f"SRCTITLE({venue})"
return await self.search(query, limit=limit, start_year=start_year)
async def test_api_access(self):
"""测试API访问权限"""
print(f"\n测试API密钥: {self.api_key}")
# 测试1: 基础搜索
basic_params = {
"query": "test",
"count": 1,
"view": "STANDARD"
}
print("\n1. 测试基础搜索...")
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=basic_params
)
if response:
print("基础搜索成功")
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
# 测试2: 测试单篇文章访问
print("\n2. 测试文章详情访问...")
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
response = await self._make_request(
f"{self.base_url}/abstract/doi/{test_doi}",
params={"view": "STANDARD"} # 改为STANDARD视图
)
if response:
print("文章详情访问成功")
else:
print("文章详情访问失败")
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详细信息
注意当前API权限不支持获取详细信息,返回None
Args:
paper_id: 论文ID
Returns:
None,因为当前API权限不支持此功能
"""
return None
async def fetch_abstract(self, doi: str) -> Optional[str]:
"""获取论文摘要
使用Scopus Abstract API获取论文摘要
Args:
doi: 论文的DOI
Returns:
摘要文本,如果获取失败则返回None
"""
try:
# 使用Abstract API而不是Search API
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}",
params={
"view": "FULL" # 使用FULL视图
}
)
if response and "abstracts-retrieval-response" in response:
# 从coredata中获取摘要
coredata = response["abstracts-retrieval-response"].get("coredata", {})
return coredata.get("dc:description", "")
return None
except Exception as e:
print(f"获取摘要时发生错误: {str(e)}")
return None
async def example_usage():
"""ElsevierSource使用示例"""
elsevier = ElsevierSource()
try:
# 首先测试API访问权限
print("\n=== 测试API访问权限 ===")
await elsevier.test_api_access()
# 示例1基本搜索
print("\n=== 示例1搜索机器学习相关论文 ===")
papers = await elsevier.search("machine learning", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
print("期刊信息:")
for key, value in paper.venue_info.items():
if value: # 只打印非空值
print(f" - {key}: {value}")
# 示例2获取引用信息
if papers and papers[0].doi:
print("\n=== 示例2获取引用该论文的文献 ===")
citations = await elsevier.get_citations(papers[0].doi, limit=3)
for i, paper in enumerate(citations, 1):
print(f"\n--- 引用论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例3获取参考文献
if papers and papers[0].doi:
print("\n=== 示例3获取论文的参考文献 ===")
references = await elsevier.get_references(papers[0].doi)
for i, paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"期刊/会议: {paper.venue}")
# 示例4按作者搜索
print("\n=== 示例4按作者搜索 ===")
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例5按机构搜索
print("\n=== 示例5按机构搜索 ===")
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
for i, paper in enumerate(affiliation_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例6获取论文摘要
print("\n=== 示例6获取论文摘要 ===")
test_doi = "10.1016/j.artint.2021.103535"
abstract = await elsevier.fetch_abstract(test_doi)
if abstract:
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
else:
print("无法获取摘要")
# 在搜索结果中显示摘要
print("\n=== 示例7搜索结果中的摘要 ===")
papers = await elsevier.search("machine learning", limit=1)
for paper in papers:
print(f"标题: {paper.title}")
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,698 @@
import aiohttp
import asyncio
import base64
import json
import random
from datetime import datetime
from typing import List, Dict, Optional, Union, Any
class GitHubSource:
"""GitHub API实现"""
# 默认API密钥列表 - 可以放置多个GitHub令牌
API_KEYS = [
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
]
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
"""初始化GitHub API客户端
Args:
api_key: GitHub个人访问令牌或令牌列表
"""
if api_key is None:
self.api_keys = self.API_KEYS
elif isinstance(api_key, str):
self.api_keys = [api_key]
else:
self.api_keys = api_key
self._initialize()
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.github.com"
self.headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
"User-Agent": "GitHub-API-Python-Client"
}
# 如果有可用的API密钥,随机选择一个
if self.api_keys:
selected_key = random.choice(self.api_keys)
self.headers["Authorization"] = f"Bearer {selected_key}"
print(f"已随机选择API密钥进行认证")
else:
print("警告: 未提供API密钥,将受到GitHub API请求限制")
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
"""发送API请求
Args:
method: HTTP方法 (GET, POST, PUT, DELETE等)
endpoint: API端点
params: URL参数
data: 请求体数据
Returns:
解析后的响应JSON
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
url = f"{self.base_url}{endpoint}"
# 为调试目的打印请求信息
print(f"请求: {method} {url}")
if params:
print(f"参数: {params}")
# 发送请求
request_kwargs = {}
if params:
request_kwargs["params"] = params
if data:
request_kwargs["json"] = data
async with session.request(method, url, **request_kwargs) as response:
response_text = await response.text()
# 检查HTTP状态码
if response.status >= 400:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {response_text}")
return None
# 解析JSON响应
try:
return json.loads(response_text)
except json.JSONDecodeError:
print(f"JSON解析错误: {response_text}")
return None
# ===== 用户相关方法 =====
async def get_user(self, username: Optional[str] = None) -> Dict:
"""获取用户信息
Args:
username: 指定用户名,不指定则获取当前授权用户
Returns:
用户信息字典
"""
endpoint = "/user" if username is None else f"/users/{username}"
return await self._request("GET", endpoint)
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户的仓库列表
Args:
username: 指定用户名,不指定则获取当前授权用户
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
params = {
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_user_starred(self, username: Optional[str] = None,
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户星标的仓库
Args:
username: 指定用户名,不指定则获取当前授权用户
per_page: 每页结果数量
page: 页码
Returns:
星标仓库列表
"""
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 仓库相关方法 =====
async def get_repo(self, owner: str, repo: str) -> Dict:
"""获取仓库信息
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
仓库信息
"""
endpoint = f"/repos/{owner}/{repo}"
return await self._request("GET", endpoint)
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的分支列表
Args:
owner: 仓库所有者
repo: 仓库名
per_page: 每页结果数量
page: 页码
Returns:
分支列表
"""
endpoint = f"/repos/{owner}/{repo}/branches"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的提交历史
Args:
owner: 仓库所有者
repo: 仓库名
sha: 特定提交SHA或分支名
path: 文件路径筛选
per_page: 每页结果数量
page: 页码
Returns:
提交列表
"""
endpoint = f"/repos/{owner}/{repo}/commits"
params = {
"per_page": per_page,
"page": page
}
if sha:
params["sha"] = sha
if path:
params["path"] = path
return await self._request("GET", endpoint, params=params)
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
"""获取特定提交的详情
Args:
owner: 仓库所有者
repo: 仓库名
commit_sha: 提交SHA
Returns:
提交详情
"""
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
return await self._request("GET", endpoint)
# ===== 内容相关方法 =====
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
"""获取文件内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 文件路径
ref: 分支名、标签名或提交SHA
Returns:
文件内容信息
"""
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
response = await self._request("GET", endpoint, params=params)
if response and isinstance(response, dict) and "content" in response:
try:
# 解码Base64编码的文件内容
content = base64.b64decode(response["content"].encode()).decode()
response["decoded_content"] = content
except Exception as e:
print(f"解码文件内容时出错: {str(e)}")
return response
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
"""获取目录内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 目录路径
ref: 分支名、标签名或提交SHA
Returns:
目录内容列表
"""
# 注意此方法与get_file_content使用相同的端点,但对于目录会返回列表
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
return await self._request("GET", endpoint, params=params)
# ===== Issues相关方法 =====
async def get_issues(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Issues列表
Args:
owner: 仓库所有者
repo: 仓库名
state: Issue状态 (open, closed, all)
sort: 排序方式 (created, updated, comments)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Issues列表
"""
endpoint = f"/repos/{owner}/{repo}/issues"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
"""获取特定Issue的详情
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
Issue详情
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
return await self._request("GET", endpoint)
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
"""获取Issue的评论
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
评论列表
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
return await self._request("GET", endpoint)
# ===== Pull Requests相关方法 =====
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Pull Request列表
Args:
owner: 仓库所有者
repo: 仓库名
state: PR状态 (open, closed, all)
sort: 排序方式 (created, updated, popularity, long-running)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Pull Request列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
"""获取特定Pull Request的详情
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
Pull Request详情
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
return await self._request("GET", endpoint)
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
"""获取Pull Request中修改的文件
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
修改文件列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
return await self._request("GET", endpoint)
# ===== 搜索相关方法 =====
async def search_repositories(self, query: str, sort: str = "stars",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索仓库
Args:
query: 搜索关键词
sort: 排序方式 (stars, forks, updated)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/repositories"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_code(self, query: str, sort: str = "indexed",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索代码
Args:
query: 搜索关键词
sort: 排序方式 (indexed)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/code"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_issues(self, query: str, sort: str = "created",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索Issues和Pull Requests
Args:
query: 搜索关键词
sort: 排序方式 (created, updated, comments)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/issues"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_users(self, query: str, sort: str = "followers",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索用户
Args:
query: 搜索关键词
sort: 排序方式 (followers, repositories, joined)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/users"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 组织相关方法 =====
async def get_organization(self, org: str) -> Dict:
"""获取组织信息
Args:
org: 组织名称
Returns:
组织信息
"""
endpoint = f"/orgs/{org}"
return await self._request("GET", endpoint)
async def get_organization_repos(self, org: str, type: str = "all",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织的仓库列表
Args:
org: 组织名称
type: 仓库类型 (all, public, private, forks, sources, member, internal)
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = f"/orgs/{org}/repos"
params = {
"type": type,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织成员列表
Args:
org: 组织名称
per_page: 每页结果数量
page: 页码
Returns:
成员列表
"""
endpoint = f"/orgs/{org}/members"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 更复杂的操作 =====
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
"""获取仓库使用的编程语言及其比例
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
语言使用情况
"""
endpoint = f"/repos/{owner}/{repo}/languages"
return await self._request("GET", endpoint)
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的贡献者统计
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
贡献者统计信息
"""
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
return await self._request("GET", endpoint)
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的提交活动
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
提交活动统计
"""
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
return await self._request("GET", endpoint)
async def example_usage():
"""GitHubSource使用示例"""
# 创建客户端实例可选传入API令牌
# github = GitHubSource(api_key="your_github_token")
github = GitHubSource()
try:
# 示例1搜索热门Python仓库
print("\n=== 示例1搜索热门Python仓库 ===")
repos = await github.search_repositories(
query="language:python stars:>1000",
sort="stars",
order="desc",
per_page=5
)
if repos and "items" in repos:
for i, repo in enumerate(repos["items"], 1):
print(f"\n--- 仓库 {i} ---")
print(f"名称: {repo['full_name']}")
print(f"描述: {repo['description']}")
print(f"星标数: {repo['stargazers_count']}")
print(f"Fork数: {repo['forks_count']}")
print(f"最近更新: {repo['updated_at']}")
print(f"URL: {repo['html_url']}")
# 示例2获取特定仓库的详情
print("\n=== 示例2获取特定仓库的详情 ===")
repo_details = await github.get_repo("microsoft", "vscode")
if repo_details:
print(f"名称: {repo_details['full_name']}")
print(f"描述: {repo_details['description']}")
print(f"星标数: {repo_details['stargazers_count']}")
print(f"Fork数: {repo_details['forks_count']}")
print(f"默认分支: {repo_details['default_branch']}")
print(f"开源许可: {repo_details.get('license', {}).get('name', '')}")
print(f"语言: {repo_details['language']}")
print(f"Open Issues数: {repo_details['open_issues_count']}")
# 示例3获取仓库的提交历史
print("\n=== 示例3获取仓库的最近提交 ===")
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
if commits:
for i, commit in enumerate(commits, 1):
print(f"\n--- 提交 {i} ---")
print(f"SHA: {commit['sha'][:7]}")
print(f"作者: {commit['commit']['author']['name']}")
print(f"日期: {commit['commit']['author']['date']}")
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
# 示例4搜索代码
print("\n=== 示例4搜索代码 ===")
code_results = await github.search_code(
query="filename:README.md language:markdown pytorch in:file",
per_page=3
)
if code_results and "items" in code_results:
print(f"共找到: {code_results['total_count']} 个结果")
for i, item in enumerate(code_results["items"], 1):
print(f"\n--- 代码 {i} ---")
print(f"仓库: {item['repository']['full_name']}")
print(f"文件: {item['path']}")
print(f"URL: {item['html_url']}")
# 示例5获取文件内容
print("\n=== 示例5获取文件内容 ===")
file_content = await github.get_file_content("python", "cpython", "README.rst")
if file_content and "decoded_content" in file_content:
content = file_content["decoded_content"]
print(f"文件名: {file_content['name']}")
print(f"大小: {file_content['size']} 字节")
print(f"内容预览: {content[:200]}...")
# 示例6获取仓库使用的编程语言
print("\n=== 示例6获取仓库使用的编程语言 ===")
languages = await github.get_repository_languages("facebook", "react")
if languages:
print(f"React仓库使用的编程语言:")
for lang, bytes_of_code in languages.items():
print(f"- {lang}: {bytes_of_code} 字节")
# 示例7获取组织信息
print("\n=== 示例7获取组织信息 ===")
org_info = await github.get_organization("google")
if org_info:
print(f"名称: {org_info['name']}")
print(f"描述: {org_info.get('description', '')}")
print(f"位置: {org_info.get('location', '未指定')}")
print(f"公共仓库数: {org_info['public_repos']}")
print(f"成员数: {org_info.get('public_members', 0)}")
print(f"URL: {org_info['html_url']}")
# 示例8获取用户信息
print("\n=== 示例8获取用户信息 ===")
user_info = await github.get_user("torvalds")
if user_info:
print(f"名称: {user_info['name']}")
print(f"公司: {user_info.get('company', '')}")
print(f"博客: {user_info.get('blog', '')}")
print(f"位置: {user_info.get('location', '未指定')}")
print(f"公共仓库数: {user_info['public_repos']}")
print(f"关注者数: {user_info['followers']}")
print(f"URL: {user_info['html_url']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,142 @@
import json
import os
from typing import Dict, Optional
class JournalMetrics:
"""期刊指标管理类"""
def __init__(self):
self.journal_data: Dict = {} # 期刊名称到指标的映射
self.issn_map: Dict = {} # ISSN到指标的映射
self.name_map: Dict = {} # 标准化名称到指标的映射
self._load_journal_data()
def _normalize_journal_name(self, name: str) -> str:
"""标准化期刊名称
Args:
name: 原始期刊名称
Returns:
标准化后的期刊名称
"""
if not name:
return ""
# 转换为小写
name = name.lower()
# 移除常见的前缀和后缀
prefixes = ['the ', 'proceedings of ', 'journal of ']
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
for prefix in prefixes:
if name.startswith(prefix):
name = name[len(prefix):]
for suffix in suffixes:
if name.endswith(suffix):
name = name[:-len(suffix)]
# 移除特殊字符,保留字母、数字和空格
name = ''.join(c for c in name if c.isalnum() or c.isspace())
# 移除多余的空格
name = ' '.join(name.split())
return name
def _convert_if_value(self, if_str: str) -> Optional[float]:
"""转换IF值为float,处理特殊情况"""
try:
if if_str.startswith('<'):
# 对于<0.1这样的值,返回0.1
return float(if_str.strip('<'))
return float(if_str)
except (ValueError, AttributeError):
return None
def _load_journal_data(self):
"""加载期刊数据"""
try:
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 建立期刊名称到指标的映射
for journal in data:
# 准备指标数据
metrics = {
'if_factor': self._convert_if_value(journal.get('IF')),
'jcr_division': journal.get('Q'),
'cas_division': journal.get('B')
}
# 存储期刊名称映射(使用标准化名称)
if journal.get('journal'):
normalized_name = self._normalize_journal_name(journal['journal'])
self.journal_data[normalized_name] = metrics
self.name_map[normalized_name] = metrics
# 存储期刊缩写映射
if journal.get('jabb'):
normalized_abbr = self._normalize_journal_name(journal['jabb'])
self.journal_data[normalized_abbr] = metrics
self.name_map[normalized_abbr] = metrics
# 存储ISSN映射
if journal.get('issn'):
self.issn_map[journal['issn']] = metrics
if journal.get('eissn'):
self.issn_map[journal['eissn']] = metrics
except Exception as e:
print(f"加载期刊数据时出错: {str(e)}")
self.journal_data = {}
self.issn_map = {}
self.name_map = {}
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
"""获取期刊指标
Args:
venue_name: 期刊名称
venue_info: 期刊详细信息
Returns:
包含期刊指标的字典
"""
try:
metrics = {}
# 1. 首先尝试通过ISSN匹配
if venue_info and 'issn' in venue_info:
issn_value = venue_info['issn']
# 处理ISSN可能是列表的情况
if isinstance(issn_value, list):
# 尝试每个ISSN
for issn in issn_value:
metrics = self.issn_map.get(issn, {})
if metrics: # 如果找到匹配的指标,就停止搜索
break
else: # ISSN是字符串的情况
metrics = self.issn_map.get(issn_value, {})
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
if not metrics and venue_name:
# 标准化期刊名称
normalized_name = self._normalize_journal_name(venue_name)
metrics = self.name_map.get(normalized_name, {})
# 如果完全匹配失败,尝试部分匹配
# if not metrics:
# for db_name, db_metrics in self.name_map.items():
# if normalized_name in db_name:
# metrics = db_metrics
# break
return metrics
except Exception as e:
print(f"获取期刊指标时出错: {str(e)}")
return {}

查看文件

@@ -0,0 +1,163 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.parse import quote
class OpenAlexSource(DataSource):
"""OpenAlex API实现"""
def _initialize(self) -> None:
self.base_url = "https://api.openalex.org"
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"title.search:{query}",
"per-page": limit
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
try:
response.raise_for_status()
data = await response.json()
results = data.get("results", [])
return [self._parse_work(work) for work in results]
except Exception as e:
print(f"搜索出错: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析OpenAlex返回的数据"""
# 获取作者信息
raw_author_names = [
authorship.get("raw_author_name", "")
for authorship in work.get("authorships", [])
if authorship
]
# 处理作者名字格式
authors = [
self._reformat_name(author)
for author in raw_author_names
]
# 获取机构信息
institutions = [
inst.get("display_name", "")
for authorship in work.get("authorships", [])
for inst in authorship.get("institutions", [])
if inst
]
# 获取主要发表位置信息
primary_location = work.get("primary_location") or {}
source = primary_location.get("source") or {}
venue = source.get("display_name")
# 获取发表日期
year = work.get("publication_year")
return PaperMetadata(
title=work.get("title", ""),
authors=authors,
institutions=institutions,
abstract=work.get("abstract", ""),
year=year,
doi=work.get("doi"),
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
citations=work.get("cited_by_count"),
venue=venue
)
def _reformat_name(self, name: str) -> str:
"""重新格式化作者名字"""
if "," not in name:
return name
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
return f"{given_names} {family}"
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
params=params
) as response:
data = await response.json()
return self._parse_work(data)
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"cites:doi:{doi}",
"per-page": 100
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def example_usage():
"""OpenAlexSource使用示例"""
# 初始化OpenAlexSource
openalex = OpenAlexSource()
try:
print("正在搜索论文...")
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
papers = await openalex.search(query="artificial intelligence", limit=5)
if not papers:
print("未获取到任何论文信息")
return
print(f"找到 {len(papers)} 篇论文")
# 打印搜索结果
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
if paper.institutions:
print(f"机构: {', '.join(paper.institutions)}")
print(f"发表年份: {paper.year if paper.year else '未知'}")
print(f"DOI: {paper.doi if paper.doi else '未知'}")
print(f"URL: {paper.url if paper.url else '未知'}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
# 如果直接运行此文件,执行示例代码
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

某些文件未显示,因为此 diff 中更改的文件太多 显示更多