Mirrored from
https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-07 23:16:48 +00:00
up
@@ -15,7 +15,7 @@ def get_crazy_functions():
     from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
     from crazy_functions.SourceCode_Analyse import 解析一个Java项目
     from crazy_functions.SourceCode_Analyse import 解析一个前端项目
-    from crazy_functions.Arxiv_论文对话 import Rag论文对话
+    from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
     from crazy_functions.高级功能函数模板 import 高阶功能模板函数
     from crazy_functions.高级功能函数模板 import Demo_Wrap
     from crazy_functions.Latex全文润色 import Latex英文润色
@@ -31,6 +31,8 @@ def get_crazy_functions():
     from crazy_functions.Markdown_Translate import Markdown英译中
     from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
     from crazy_functions.PDF_Translate import 批量翻译PDF文档
+    from crazy_functions.批量文件询问 import 批量文件询问
+
     from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
     from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
     from crazy_functions.Latex全文润色 import Latex中文润色
@@ -74,12 +76,25 @@ def get_crazy_functions():
             "Function": HotReload(Latex翻译中文并重新编译PDF),  # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
             "Class": Arxiv_Localize,  # 新一代插件需要注册Class
         },
-        "Rag论文对话": {
+        "批量文件询问": {
             "Group": "学术",
             "Color": "stop",
             "AsButton": False,
-            "Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
-            "Function": HotReload(Rag论文对话),  # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
+            "AdvancedArgs": True,
+            "Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
+            "ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
+                            r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“请总结下面的内容:”,此时,文档内容将添加在“:”后 ",
+            "Function": HotReload(批量文件询问),
+        },
+        "Arxiv论文对话": {
+            "Group": "学术",
+            "Color": "stop",
+            "AsButton": False,
+            "AdvancedArgs": True,
+            "Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
+            "ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
+                            r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“这篇文章的方法是什么:” ",
+            "Function": HotReload(Arxiv论文对话),
         },
         "翻译README或MD": {
             "Group": "编程",
@@ -604,6 +619,23 @@ def get_crazy_functions():
         logger.error(trimmed_format_exc())
         logger.error("Load function plugin failed")

+    try:
+        from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
+
+        function_plugins.update(
+            {
+                "Arxiv论文对话": {
+                    "Group": "对话",
+                    "Color": "stop",
+                    "AsButton": False,
+                    "Info": "将问答数据记录到向量库中,作为长期参考。",
+                    "Function": HotReload(Arxiv论文对话),
+                },
+            }
+        )
+    except:
+        logger.error(trimmed_format_exc())
+        logger.error("Load function plugin failed")

@@ -1,63 +1,416 @@
-import os.path
-from toolbox import CatchException, update_ui
-from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor
+import os
+import logging
+import asyncio
+from pathlib import Path
+from typing import List, Optional, Generator, Dict, Union
+from datetime import datetime
+from dataclasses import dataclass
+from concurrent.futures import ThreadPoolExecutor
+import aiohttp
+
+from shared_utils.fastapi_server import validate_path_safety
+from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
+from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
+from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment as Fragment
+from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+from crazy_functions.crazy_utils import input_clipping
+from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
+
+# 全局常量配置
+MAX_HISTORY_ROUND = 5  # 最大历史对话轮数
+MAX_CONTEXT_TOKEN_LIMIT = 4096  # 上下文最大token数
+REMEMBER_PREVIEW = 1000  # 记忆预览长度
+VECTOR_STORE_TYPE = "Simple"  # 向量存储类型:Simple或Milvus
+MAX_CONCURRENT_PAPERS = 5  # 最大并行处理论文数
+MAX_WORKERS = 3  # 最大工作线程数
+
+# 配置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ProcessingTask:
+    """论文处理任务数据类"""
+    arxiv_id: str
+    status: str = "pending"  # pending, processing, completed, failed
+    error: Optional[str] = None
+    fragments: List[Fragment] = None
+
+
+class ArxivRagWorker:
+    def __init__(self, user_name: str, llm_kwargs: Dict):
+        self.user_name = user_name
+        self.llm_kwargs = llm_kwargs
+
+        # 初始化存储目录
+        self.checkpoint_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
+        self.vector_store_dir = self.checkpoint_dir / "vector_store"
+        self.fragment_store_dir = self.checkpoint_dir / "fragments"
+
+        # 创建必要的目录
+        self.vector_store_dir.mkdir(parents=True, exist_ok=True)
+        self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
+
+        logger.info(f"Vector store directory: {self.vector_store_dir}")
+        logger.info(f"Fragment store directory: {self.fragment_store_dir}")
+
+        # 初始化RAG worker
+        self.rag_worker = LlamaIndexRagWorker(
+            user_name=user_name,
+            llm_kwargs=llm_kwargs,
+            checkpoint_dir=str(self.vector_store_dir),
+            auto_load_checkpoint=True
+        )
+
+        # 初始化arxiv splitter
+        self.arxiv_splitter = ArxivSplitter(
+            char_range=(1000, 1200),
+            root_dir=str(self.checkpoint_dir / "arxiv_cache")
+        )
+
+        # 初始化并行处理组件
+        self.processing_queue = {}
+        self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS)
+        self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
+
+    async def _process_fragments(self, fragments: List[Fragment]) -> None:
+        """并行处理论文片段"""
+        if not fragments:
+            logger.warning("No fragments to process")
+            return
+
+        # 首先添加论文概述
+        overview = {
+            "title": fragments[0].title,
+            "abstract": fragments[0].abstract,
+            "arxiv_id": fragments[0].arxiv_id,
+        }
+
+        overview_text = (
+            f"Paper Title: {overview['title']}\n"
+            f"ArXiv ID: {overview['arxiv_id']}\n"
+            f"Abstract: {overview['abstract']}\n"
+            f"Type: OVERVIEW"
+        )
+
+        try:
+            # 同步添加概述
+            self.rag_worker.add_text_to_vector_store(overview_text)
+            logger.info(f"Added paper overview for {overview['arxiv_id']}")
+
+            # 并行处理其余片段
+            tasks = []
+            for i, fragment in enumerate(fragments):
+                task = asyncio.get_event_loop().run_in_executor(
+                    self.thread_pool,
+                    self._process_single_fragment,
+                    fragment,
+                    i
+                )
+                tasks.append(task)
+
+            await asyncio.gather(*tasks)
+            logger.info(f"Processed {len(fragments)} fragments successfully")
+
+            # 保存到本地文件用于调试
+            save_fragments_to_file(
+                fragments,
+                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
+            )
+
+        except Exception as e:
+            logger.error(f"Error processing fragments: {str(e)}")
+            raise
+
+    def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
+        """处理单个论文片段"""
+        try:
+            text = (
+                f"Paper Title: {fragment.title}\n"
+                f"ArXiv ID: {fragment.arxiv_id}\n"
+                f"Section: {fragment.section}\n"
+                f"Fragment Index: {index}\n"
+                f"Content: {fragment.content}\n"
+                f"Type: FRAGMENT"
+            )
+
+            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
+            self.rag_worker.add_text_to_vector_store(text)
+            logger.info(f"Successfully added fragment {index} to vector store")
+
+        except Exception as e:
+            logger.error(f"Error processing fragment {index}: {str(e)}")
+            raise
+
+    async def process_paper(self, arxiv_id: str) -> bool:
+        """处理论文主函数"""
+        try:
+            arxiv_id = self._normalize_arxiv_id(arxiv_id)
+            logger.info(f"Starting to process paper: {arxiv_id}")
+
+            paper_path = self.checkpoint_dir / f"{arxiv_id}.processed"
+
+            if paper_path.exists():
+                logger.info(f"Paper {arxiv_id} already processed")
+                return True
+
+            # 创建处理任务
+            task = ProcessingTask(arxiv_id=arxiv_id)
+            self.processing_queue[arxiv_id] = task
+            task.status = "processing"
+
+            async with self.semaphore:
+                # 下载和分割论文
+                fragments = await self.arxiv_splitter.process(arxiv_id)
+
+                if not fragments:
+                    raise ValueError(f"No fragments extracted from paper {arxiv_id}")
+
+                logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
+
+                # 处理片段
+                await self._process_fragments(fragments)
+
+                # 标记完成
+                paper_path.touch()
+                task.status = "completed"
+                task.fragments = fragments
+
+                logger.info(f"Successfully processed paper {arxiv_id}")
+                return True
+
+        except Exception as e:
+            logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
+            if arxiv_id in self.processing_queue:
+                self.processing_queue[arxiv_id].status = "failed"
+                self.processing_queue[arxiv_id].error = str(e)
+            return False
+
+    def _normalize_arxiv_id(self, input_str: str) -> str:
+        """规范化ArXiv ID"""
+        if 'arxiv.org/' in input_str.lower():
+            if '/pdf/' in input_str:
+                arxiv_id = input_str.split('/pdf/')[-1]
+            else:
+                arxiv_id = input_str.split('/abs/')[-1]
+            return arxiv_id.split('v')[0].strip()
+        return input_str.split('v')[0].strip()
+
+    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
+        """等待论文处理完成"""
+        try:
+            start_time = datetime.now()
+            while True:
+                task = self.processing_queue.get(arxiv_id)
+                if not task:
+                    return False
+
+                if task.status == "completed":
+                    return True
+
+                if task.status == "failed":
+                    return False
+
+                # 检查超时
+                if (datetime.now() - start_time).total_seconds() > timeout:
+                    logger.error(f"Processing paper {arxiv_id} timed out")
+                    return False
+
+                await asyncio.sleep(0.1)
+        except Exception as e:
+            logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
+            return False
+
+    def retrieve_and_generate(self, query: str) -> str:
+        """检索相关内容并生成提示词"""
+        try:
+            nodes = self.rag_worker.retrieve_from_store_with_query(query)
+            return self.rag_worker.build_prompt(query=query, nodes=nodes)
+        except Exception as e:
+            logger.error(f"Error in retrieve and generate: {str(e)}")
+            return ""

+    def remember_qa(self, question: str, answer: str) -> None:
+        """记忆问答对"""
+        try:
+            self.rag_worker.remember_qa(question, answer)
+        except Exception as e:
+            logger.error(f"Error remembering QA: {str(e)}")
+
+    async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None:
+        """自动分析论文的关键问题"""
+        key_questions = [
+            "What is the main research question or problem addressed in this paper?",
+            "What methods or approaches did the authors use to investigate the problem?",
+            "What are the key findings or results presented in the paper?",
+            "How do the findings of this paper contribute to the broader field or topic of study?",
+            "What are the limitations of this study, and what future research directions do the authors suggest?"
+        ]
+
+        results = []
+        for question in key_questions:
+            try:
+                prompt = self.retrieve_and_generate(question)
+                if prompt:
+                    response = await request_gpt_model_in_new_thread_with_ui_alive(
+                        inputs=prompt,
+                        inputs_show_user=question,
+                        llm_kwargs=self.llm_kwargs,
+                        chatbot=chatbot,
+                        history=history,
+                        sys_prompt=system_prompt
+                    )
+                    results.append(f"Q: {question}\nA: {response}\n")
+                    self.remember_qa(question, response)
+            except Exception as e:
+                logger.error(f"Error in auto analysis: {str(e)}")
+
+        # 合并所有结果
+        summary = "\n\n".join(results)
+        chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。")
+
+
 @CatchException
-def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
+             history: List, system_prompt: str, web_port: str) -> Generator:
     """
-    txt: 用户输入,通常是arxiv论文链接
-    功能:RAG论文总结和对话
+    Arxiv论文对话主函数
+    Args:
+        txt: arxiv ID/URL
+        llm_kwargs: LLM配置参数
+        plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令
+        chatbot: 对话历史
+        history: 聊天历史
+        system_prompt: 系统提示词
+        web_port: Web端口
     """
-    if_project, if_arxiv = False, False
-    if os.path.exists(txt):
-        from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter
-        splitter = SmartDocumentSplitter(
-            char_range=(1000, 1200),
-            max_workers=32 # 可选,默认会根据CPU核心数自动设置
-        )
-        if_project = True
-    else:
-        from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import SmartArxivSplitter
-        splitter = SmartArxivSplitter(
-            char_range=(1000, 1200),
-            root_dir="gpt_log/arxiv_cache"
-        )
+    # 初始化时,提示用户需要 arxiv ID/URL
+    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
+        chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
+        yield from update_ui(chatbot=chatbot, history=history)
+        return
+
+    user_name = chatbot.get_user()
+    worker = ArxivRagWorker(user_name, llm_kwargs)
+
+    # 处理新论文的情况
+    if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
+        chatbot.append((txt, "正在处理论文,请稍等..."))
+        yield from update_ui(chatbot=chatbot, history=history)
+
+        # 创建事件循环来处理异步调用
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            # 运行异步处理函数
+            success = loop.run_until_complete(worker.process_paper(txt))
+            if success:
+                arxiv_id = worker._normalize_arxiv_id(txt)
+                success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
+                if success:
+                    # 执行自动分析
+                    yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
+        finally:
+            loop.close()
+
+        if not success:
+            chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
+            yield from update_ui(chatbot=chatbot, history=history)
+            return
+
+        yield from update_ui(chatbot=chatbot, history=history)
+        return
+
+    # 处理用户询问的情况
+    # 获取用户询问指令
+    user_query = plugin_kwargs.get("advanced_arg", "")
+    if not user_query:
+        chatbot.append((txt, "请提供您的问题。"))
+        yield from update_ui(chatbot=chatbot, history=history)
+        return
+
+    # 处理历史对话长度
+    if len(history) > MAX_HISTORY_ROUND * 2:
+        history = history[-(MAX_HISTORY_ROUND * 2):]
+
+    # 处理询问指令
+    query_clip, history, flags = input_clipping(
+        user_query,
+        history,
+        max_token_limit=MAX_CONTEXT_TOKEN_LIMIT,
+        return_clip_flags=True
+    )
+
+    if flags["original_input_len"] != flags["clipped_input_len"]:
+        yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
+        if len(user_query) > REMEMBER_PREVIEW:
+            HALF = REMEMBER_PREVIEW // 2
+            query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[-HALF:]
+        else:
+            query_to_remember = query_clip
+    else:
+        query_to_remember = query_clip
+
+    chatbot.append((user_query, "正在思考中..."))
+    yield from update_ui(chatbot=chatbot, history=history)
+
+    # 生成提示词
+    prompt = worker.retrieve_and_generate(query_clip)
+    if not prompt:
+        chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
+        yield from update_ui(chatbot=chatbot, history=history)
+        return
+
+    # 获取回答
+    response = yield from request_gpt_model_in_new_thread_with_ui_alive(
+        inputs=prompt,
+        inputs_show_user=query_clip,
+        llm_kwargs=llm_kwargs,
+        chatbot=chatbot,
+        history=history,
+        sys_prompt=system_prompt
+    )
-        if_arxiv = True
-    for fragment in splitter.process(txt):
-        pass
-    # 初始化处理器
-    processor = ArxivPaperProcessor()
-    rag_handler = RagHandler()
-
-    # Step 1: 下载和提取论文
-    download_result = processor.download_and_extract(txt, chatbot, history)
-    project_folder, arxiv_id = None, None
-    for result in download_result:
-        if isinstance(result, tuple) and len(result) == 2:
-            project_folder, arxiv_id = result
-            break
-    if not project_folder or not arxiv_id:
-        return
-
-    # Step 2: 合并TEX文件
-    paper_content = processor.merge_tex_files(project_folder, chatbot, history)
-    if not paper_content:
-        return
-
-    # Step 3: RAG处理
-    chatbot.append(["正在构建知识图谱...", "处理中..."])
+
+    # 记忆问答对
+    worker.remember_qa(query_to_remember, response)
+    history.extend([user_query, response])
+
     yield from update_ui(chatbot=chatbot, history=history)

-    # 处理论文内容
-    rag_handler.process_paper_content(paper_content)
-
-    # 生成初始摘要
-    summary = rag_handler.query("请总结这篇论文的主要内容,包括研究目的、方法、结果和结论。")
-    chatbot.append(["论文摘要", summary])
-    yield from update_ui(chatbot=chatbot, history=history)
-
-    # 交互式问答
+
+if __name__ == "__main__":
+    # 测试代码
+    llm_kwargs = {
+        'api_key': os.getenv("one_api_key"),
+        'client_ip': '127.0.0.1',
+        'embed_model': 'text-embedding-3-small',
+        'llm_model': 'one-api-Qwen2.5-72B-Instruct',
+        'max_length': 4096,
+        'most_recent_uploaded': None,
+        'temperature': 1,
+        'top_p': 1
+    }
+    plugin_kwargs = {}
+    chatbot = []
+    history = []
+    system_prompt = "You are a helpful assistant."
+    web_port = "8080"
+
+    # 测试论文导入
+    arxiv_url = "https://arxiv.org/abs/2312.12345"
+    for response in Arxiv论文对话(
+        arxiv_url, llm_kwargs, plugin_kwargs,
+        chatbot, history, system_prompt, web_port
+    ):
+        print(response)
+
+    # 测试问答
+    question = "这篇论文的主要贡献是什么?"
+    for response in Arxiv论文对话(
+        question, llm_kwargs, plugin_kwargs,
+        chatbot, history, system_prompt, web_port
+    ):
+        print(response)
@@ -1,15 +1,18 @@
+import llama_index
+import os
 import atexit
-from typing import List, Dict, Optional, Any, Tuple
-
-from llama_index.core import Document
-from llama_index.core.ingestion import run_transformations
-from llama_index.core.schema import TextNode, NodeWithScore
 from loguru import logger
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
-import json
-import numpy as np
+from typing import List
+from llama_index.core import Document
+from llama_index.core.schema import TextNode
 from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
+from shared_utils.connect_void_terminal import get_chat_default_kwargs
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core import PromptTemplate
+from llama_index.core.response_synthesizers import TreeSummarize

 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
 ---------------------
@@ -60,7 +63,7 @@ class SaveLoad():
     def purge(self):
         import shutil
         shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
-        self.vs_index = self.create_new_vs(self.checkpoint_dir)
+        self.vs_index = self.create_new_vs()


 class LlamaIndexRagWorker(SaveLoad):
@@ -69,11 +72,61 @@ class LlamaIndexRagWorker(SaveLoad):
         self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
         self.user_name = user_name
         self.checkpoint_dir = checkpoint_dir
-        if auto_load_checkpoint:
-            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
+        # 确保checkpoint_dir存在
+        if checkpoint_dir:
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+        logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
+
+        # 初始化向量存储
+        if auto_load_checkpoint and self.does_checkpoint_exist():
+            logger.info("Loading existing vector store from checkpoint")
+            self.vs_index = self.load_from_checkpoint()
         else:
+            logger.info("Creating new vector store")
             self.vs_index = self.create_new_vs()
-        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
+
+        # 注册退出时保存
+        atexit.register(self.save_to_checkpoint)
+
+    def add_text_to_vector_store(self, text: str) -> None:
+        """添加文本到向量存储"""
+        try:
+            logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
+            node = TextNode(text=text)
+            nodes = run_transformations(
+                [node],
+                self.vs_index._transformations,
+                show_progress=True
+            )
+            self.vs_index.insert_nodes(nodes)
+
+            # 立即保存
+            self.save_to_checkpoint()
+
+            if self.debug_mode:
+                self.inspect_vector_store()
+
+        except Exception as e:
+            logger.error(f"Error adding text to vector store: {str(e)}")
+            raise
+
+    def save_to_checkpoint(self, checkpoint_dir=None):
+        """保存向量存储到检查点"""
+        try:
+            if checkpoint_dir is None:
+                checkpoint_dir = self.checkpoint_dir
+            logger.info(f'Saving vector store to: {checkpoint_dir}')
+            if checkpoint_dir:
+                self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
+                logger.info('Vector store saved successfully')
+            else:
+                logger.warning('No checkpoint directory specified, skipping save')
+        except Exception as e:
+            logger.error(f"Error saving checkpoint: {str(e)}")
+            raise
+

     def assign_embedding_model(self):
         pass
@@ -88,38 +141,22 @@ class LlamaIndexRagWorker(SaveLoad):
         logger.info('oo --------inspect_vector_store end--------')
         return vector_store_preview

-    def add_documents_to_vector_store(self, document_list: List[Document]):
-        """
-        Adds a list of Document objects to the vector store after processing.
-        """
-        documents = document_list
+    def add_documents_to_vector_store(self, document_list):
+        documents = [Document(text=t) for t in document_list]
         documents_nodes = run_transformations(
             documents,  # type: ignore
             self.vs_index._transformations,
             show_progress=True
         )
         self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode:
-            self.inspect_vector_store()
-
-    def add_text_to_vector_store(self, text: str):
-        node = TextNode(text=text)
-        documents_nodes = run_transformations(
-            [node],
-            self.vs_index._transformations,
-            show_progress=True
-        )
-        self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode:
-            self.inspect_vector_store()
+        if self.debug_mode: self.inspect_vector_store()

     def remember_qa(self, question, answer):
         formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
         self.add_text_to_vector_store(formatted_str)

     def retrieve_from_store_with_query(self, query):
-        if self.debug_mode:
-            self.inspect_vector_store()
+        if self.debug_mode: self.inspect_vector_store()
         retriever = self.vs_index.as_retriever()
         return retriever.retrieve(query)

@@ -131,227 +168,3 @@ class LlamaIndexRagWorker(SaveLoad):
         buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
         if self.debug_mode: logger.info(buf)
         return buf
-
-    def purge_vector_store(self):
-        """
-        Purges the current vector store and creates a new one.
-        """
-        self.purge()
-
-
-    """
-    以下是添加的新方法,原有方法保持不变
-    """
-
-    def add_text_with_metadata(self, text: str, metadata: dict) -> str:
-        """
-        添加带元数据的文本到向量存储
-
-        Args:
-            text: 文本内容
-            metadata: 元数据字典
-
-        Returns:
-            添加的节点ID
-        """
-        node = TextNode(text=text, metadata=metadata)
-        nodes = run_transformations(
-            [node],
-            self.vs_index._transformations,
-            show_progress=True
-        )
-        self.vs_index.insert_nodes(nodes)
-        return nodes[0].node_id if nodes else None
-
-    def batch_add_texts_with_metadata(self, texts: List[Tuple[str, dict]]) -> List[str]:
-        """
-        批量添加带元数据的文本
-
-        Args:
-            texts: (text, metadata)元组列表
-
-        Returns:
-            添加的节点ID列表
-        """
-        nodes = [TextNode(text=t, metadata=m) for t, m in texts]
-        transformed_nodes = run_transformations(
-            nodes,
-            self.vs_index._transformations,
-            show_progress=True
-        )
-        if transformed_nodes:
-            self.vs_index.insert_nodes(transformed_nodes)
-            return [node.node_id for node in transformed_nodes]
-        return []
-
-    def get_node_metadata(self, node_id: str) -> Optional[dict]:
-        """
-        获取节点的元数据
-
-        Args:
-            node_id: 节点ID
-
-        Returns:
-            节点的元数据字典
-        """
-        node = self.vs_index.storage_context.docstore.docs.get(node_id)
-        return node.metadata if node else None
-
-    def update_node_metadata(self, node_id: str, metadata: dict, merge: bool = True) -> bool:
-        """
-        更新节点的元数据
-
-        Args:
-            node_id: 节点ID
-            metadata: 新的元数据
-            merge: 是否与现有元数据合并
-
-        Returns:
-            是否更新成功
-        """
-        docstore = self.vs_index.storage_context.docstore
-        if node_id in docstore.docs:
-            node = docstore.docs[node_id]
-            if merge:
-                node.metadata.update(metadata)
-            else:
-                node.metadata = metadata
-            return True
-        return False
-
-    def filter_nodes_by_metadata(self, filters: Dict[str, Any]) -> List[TextNode]:
-        """
-        按元数据过滤节点
-
-        Args:
-            filters: 元数据过滤条件
-
-        Returns:
-            符合条件的节点列表
-        """
-        docstore = self.vs_index.storage_context.docstore
-        results = []
-        for node in docstore.docs.values():
-            if all(node.metadata.get(k) == v for k, v in filters.items()):
-                results.append(node)
-        return results
-
-    def retrieve_with_metadata_filter(
-        self,
-        query: str,
-        metadata_filters: Dict[str, Any],
-        top_k: int = 5
-    ) -> List[NodeWithScore]:
-        """
-        结合元数据过滤的检索
-
-        Args:
-            query: 查询文本
-            metadata_filters: 元数据过滤条件
-            top_k: 返回结果数量
-
-        Returns:
-            检索结果节点列表
-        """
-        retriever = self.vs_index.as_retriever(similarity_top_k=top_k)
-        nodes = retriever.retrieve(query)
-
-        # 应用元数据过滤
-        filtered_nodes = []
-        for node in nodes:
-            if all(node.metadata.get(k) == v for k, v in metadata_filters.items()):
-                filtered_nodes.append(node)
-
-        return filtered_nodes
-
-    def get_node_stats(self, node_id: str) -> dict:
-        """
-        获取单个节点的统计信息
-
-        Args:
-            node_id: 节点ID
-
-        Returns:
-            节点统计信息字典
-        """
-        node = self.vs_index.storage_context.docstore.docs.get(node_id)
-        if not node:
-            return {}
-
-        return {
-            "text_length": len(node.text),
-            "token_count": len(node.text.split()),
-            "has_embedding": node.embedding is not None,
-            "metadata_keys": list(node.metadata.keys()),
-        }
-
-    def get_nodes_by_content_pattern(self, pattern: str) -> List[TextNode]:
-        """
-        按内容模式查找节点
-
-        Args:
-            pattern: 正则表达式模式
-
-        Returns:
-            匹配的节点列表
-        """
-        import re
-        docstore = self.vs_index.storage_context.docstore
-        matched_nodes = []
-        for node in docstore.docs.values():
-            if re.search(pattern, node.text):
-                matched_nodes.append(node)
-        return matched_nodes
-
-    def export_nodes(
-        self,
-        output_file: str,
-        format: str = "json",
-        include_embeddings: bool = False
-    ) -> None:
-        """
-        Export nodes to file
-
-        Args:
-            output_file: Output file path
-            format: "json" or "csv"
-            include_embeddings: Whether to include embeddings
-        """
-        docstore = self.vs_index.storage_context.docstore
-
-        data = []
-        for node_id, node in docstore.docs.items():
-            node_data = {
-                "node_id": node_id,
-                "text": node.text,
-                "metadata": node.metadata,
-            }
-            if include_embeddings and node.embedding is not None:
-                node_data["embedding"] = node.embedding.tolist()
-            data.append(node_data)
-
-        if format == "json":
-            with open(output_file, 'w', encoding='utf-8') as f:
-                json.dump(data, f, ensure_ascii=False, indent=2)
-
-        elif format == "csv":
-            import csv
-            import pandas as pd
-
-            df = pd.DataFrame(data)
-            df.to_csv(output_file, index=False, quoting=csv.QUOTE_NONNUMERIC)
-
-        else:
-            raise ValueError(f"Unsupported format: {format}")
-
-    def get_statistics(self) -> Dict[str, Any]:
-        """Get vector store statistics"""
-        docstore = self.vs_index.storage_context.docstore
-        docs = list(docstore.docs.values())
-
-        return {
-            "total_nodes": len(docs),
-            "total_tokens": sum(len(node.text.split()) for node in docs),
-            "avg_text_length": np.mean([len(node.text) for node in docs]) if docs else 0,
-            "embedding_dimension": len(docs[0].embedding) if docs and docs[0].embedding is not None else 0
-        }