镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 22:46:48 +00:00
up
这个提交包含在:
@@ -43,22 +43,42 @@ class ProcessingTask:
|
||||
|
||||
|
||||
class ArxivRagWorker:
|
||||
def __init__(self, user_name: str, llm_kwargs: Dict):
|
||||
def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
|
||||
self.user_name = user_name
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.max_concurrent_papers = MAX_CONCURRENT_PAPERS # 存储最大并发数
|
||||
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
|
||||
|
||||
# 初始化存储目录
|
||||
self.checkpoint_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
|
||||
self.vector_store_dir = self.checkpoint_dir / "vector_store"
|
||||
self.fragment_store_dir = self.checkpoint_dir / "fragments"
|
||||
# 初始化基础存储目录
|
||||
self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
|
||||
if os.path.exists(self.base_dir):
|
||||
self.loading = True
|
||||
else:
|
||||
self.loading = False
|
||||
# 如果提供了 arxiv_id,创建针对该论文的子目录
|
||||
if self.arxiv_id:
|
||||
self.checkpoint_dir = self.base_dir / self.arxiv_id
|
||||
self.vector_store_dir = self.checkpoint_dir / "vector_store"
|
||||
self.fragment_store_dir = self.checkpoint_dir / "fragments"
|
||||
else:
|
||||
# 如果没有 arxiv_id,使用基础目录
|
||||
self.checkpoint_dir = self.base_dir
|
||||
self.vector_store_dir = self.base_dir / "vector_store"
|
||||
self.fragment_store_dir = self.base_dir / "fragments"
|
||||
|
||||
# 创建必要的目录
|
||||
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.vector_store_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Checkpoint directory: {self.checkpoint_dir}")
|
||||
logger.info(f"Vector store directory: {self.vector_store_dir}")
|
||||
logger.info(f"Fragment store directory: {self.fragment_store_dir}")
|
||||
|
||||
# 初始化处理队列和线程池
|
||||
self.processing_queue = {}
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
# 初始化RAG worker
|
||||
self.rag_worker = LlamaIndexRagWorker(
|
||||
user_name=user_name,
|
||||
@@ -68,15 +88,30 @@ class ArxivRagWorker:
|
||||
)
|
||||
|
||||
# 初始化arxiv splitter
|
||||
# 初始化 arxiv splitter
|
||||
self.arxiv_splitter = ArxivSplitter(
|
||||
char_range=(1000, 1200),
|
||||
root_dir=str(self.checkpoint_dir / "arxiv_cache")
|
||||
)
|
||||
# 初始化处理队列和线程池
|
||||
self._semaphore = None
|
||||
self._loop = None
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
"""获取当前事件循环"""
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
return self._loop
|
||||
|
||||
@property
|
||||
def semaphore(self):
|
||||
"""延迟创建 semaphore"""
|
||||
if self._semaphore is None:
|
||||
self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
|
||||
return self._semaphore
|
||||
|
||||
|
||||
# 初始化并行处理组件
|
||||
self.processing_queue = {}
|
||||
self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
async def _process_fragments(self, fragments: List[Fragment]) -> None:
|
||||
"""并行处理论文片段"""
|
||||
@@ -106,17 +141,11 @@ class ArxivRagWorker:
|
||||
# 并行处理其余片段
|
||||
tasks = []
|
||||
for i, fragment in enumerate(fragments):
|
||||
task = asyncio.get_event_loop().run_in_executor(
|
||||
self.thread_pool,
|
||||
self._process_single_fragment,
|
||||
fragment,
|
||||
i
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
tasks.append(self._process_single_fragment(fragment, i))
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Processed {len(fragments)} fragments successfully")
|
||||
|
||||
|
||||
# 保存到本地文件用于调试
|
||||
save_fragments_to_file(
|
||||
fragments,
|
||||
@@ -127,8 +156,26 @@ class ArxivRagWorker:
|
||||
logger.error(f"Error processing fragments: {str(e)}")
|
||||
raise
|
||||
|
||||
def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
|
||||
"""处理单个论文片段"""
|
||||
async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
|
||||
"""处理单个论文片段(改为异步方法)"""
|
||||
try:
|
||||
text = (
|
||||
f"Paper Title: {fragment.title}\n"
|
||||
f"ArXiv ID: {fragment.arxiv_id}\n"
|
||||
f"Section: {fragment.section}\n"
|
||||
f"Fragment Index: {index}\n"
|
||||
f"Content: {fragment.content}\n"
|
||||
f"Type: FRAGMENT"
|
||||
)
|
||||
|
||||
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
|
||||
# 如果 add_text_to_vector_store 是异步的,使用 await
|
||||
self.rag_worker.add_text_to_vector_store(text)
|
||||
logger.info(f"Successfully added fragment {index} to vector store")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing fragment {index}: {str(e)}")
|
||||
raise """处理单个论文片段"""
|
||||
try:
|
||||
text = (
|
||||
f"Paper Title: {fragment.title}\n"
|
||||
@@ -289,16 +336,16 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
|
||||
web_port: Web端口
|
||||
"""
|
||||
# 初始化时,提示用户需要 arxiv ID/URL
|
||||
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
|
||||
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0','1', '2')):
|
||||
chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
user_name = chatbot.get_user()
|
||||
worker = ArxivRagWorker(user_name, llm_kwargs)
|
||||
worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
|
||||
|
||||
# 处理新论文的情况
|
||||
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
|
||||
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading:
|
||||
chatbot.append((txt, "正在处理论文,请稍等..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
@@ -327,7 +374,8 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
|
||||
|
||||
# 处理用户询问的情况
|
||||
# 获取用户询问指令
|
||||
user_query = plugin_kwargs.get("advanced_arg", "")
|
||||
user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?")
|
||||
# user_query = "What is the main research question or problem addressed in this paper about graph attention network?"
|
||||
if not user_query:
|
||||
chatbot.append((txt, "请提供您的问题。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
@@ -14,6 +14,7 @@ class ArxivFragment:
|
||||
section: str # 所属章节
|
||||
is_appendix: bool # 是否是附录
|
||||
importance: float = 1.0 # 重要性得分
|
||||
arxiv_id: str = "" # 添加 arxiv_id 属性
|
||||
|
||||
@staticmethod
|
||||
def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment':
|
||||
|
||||
在新工单中引用
屏蔽一个用户