diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index 125df8cc..3f1f4f9c 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -46,10 +46,10 @@ class ArxivRagWorker:
     def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
         self.user_name = user_name
         self.llm_kwargs = llm_kwargs
-        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
+        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # maximum number of papers processed concurrently
         self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
 
         # Initialize base storage directory
         self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
 
@@ -57,6 +57,7 @@ class ArxivRagWorker:
             self.vector_store_dir = self.checkpoint_dir / "vector_store"
             self.fragment_store_dir = self.checkpoint_dir / "fragments"
         else:
+            # Fall back to the base directory when no arxiv_id is given
             self.checkpoint_dir = self.base_dir
             self.vector_store_dir = self.base_dir / "vector_store"
             self.fragment_store_dir = self.base_dir / "fragments"
@@ -75,11 +76,11 @@ class ArxivRagWorker:
         logger.info(f"Vector store directory: {self.vector_store_dir}")
         logger.info(f"Fragment store directory: {self.fragment_store_dir}")
 
         # Initialize processing queue and thread pool
         self.processing_queue = {}
         self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
 
         # Initialize RAG worker
         self.rag_worker = LlamaIndexRagWorker(
             user_name=user_name,
             llm_kwargs=llm_kwargs,
@@ -87,24 +88,97 @@ class ArxivRagWorker:
             auto_load_checkpoint=True
         )
 
         # Initialize arxiv splitter
         self.arxiv_splitter = ArxivSplitter(
             root_dir=str(self.checkpoint_dir / "arxiv_cache")
         )
 
-    async def _async_get_fragments(self, arxiv_id: str) -> List[Fragment]:
-        """Async helper to get fragments"""
-        return await self.arxiv_splitter.process(arxiv_id)
+        # Placeholders for the lazily created semaphore and event loop
+        self._semaphore = None
+        self._loop = None
+
+    @property
+    def loop(self):
+        """Get the current event loop"""
+        if self._loop is None:
+            self._loop = asyncio.get_event_loop()
+        return self._loop
+
+    @property
+    def semaphore(self):
+        """Create the semaphore lazily"""
+        if self._semaphore is None:
+            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
+        return self._semaphore
+
+    async def _process_fragments(self, fragments: List[Fragment]) -> None:
+        """Process paper fragments concurrently"""
+        if not fragments:
+            logger.warning("No fragments to process")
+            return
+
+        # Add the paper overview first
+        overview = {
+            "title": fragments[0].title,
+            "abstract": fragments[0].abstract,
+            "arxiv_id": fragments[0].arxiv_id,
+            "section_tree": fragments[0].section_tree,
+        }
+
+        overview_text = (
+            f"Paper Title: {overview['title']}\n"
+            f"ArXiv ID: {overview['arxiv_id']}\n"
+            f"Abstract: {overview['abstract']}\n"
+            f"Section Tree: {overview['section_tree']}\n"
+            f"Type: OVERVIEW"
+        )
 
-    def _get_fragments_sync(self, arxiv_id: str) -> List[Fragment]:
-        """Synchronous wrapper for async fragment retrieval"""
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self._async_get_fragments(arxiv_id))
-        finally:
-            loop.close()
-    def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
-        """Process a single paper fragment"""
+            # Add the overview synchronously
+            self.rag_worker.add_text_to_vector_store(overview_text)
+            logger.info(f"Added paper overview for {overview['arxiv_id']}")
+
+            # Process the remaining fragments concurrently
+            tasks = []
+            for i, fragment in enumerate(fragments):
+                tasks.append(self._process_single_fragment(fragment, i))
+            await asyncio.gather(*tasks)
+            logger.info(f"Processed {len(fragments)} fragments successfully")
+
+            # Save fragments to a local file for debugging
+            save_fragments_to_file(
+                fragments,
+                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
+            )
+
+        except Exception as e:
+            logger.error(f"Error processing fragments: {str(e)}")
+            raise
+
+    async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
+        """Process a single paper fragment (now an async method)"""
+        try:
+            text = (
+                f"Paper Title: {fragment.title}\n"
+                f"Abstract: {fragment.abstract}\n"
+                f"ArXiv ID: {fragment.arxiv_id}\n"
+                f"Section: {fragment.current_section}\n"
+                f"Section Tree: {fragment.section_tree}\n"
+                f"Content: {fragment.content}\n"
+                f"Bibliography: {fragment.bibliography}\n"
+                f"Type: FRAGMENT"
+            )
+            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
+            # If add_text_to_vector_store were async, this call would need an await
+            self.rag_worker.add_text_to_vector_store(text)
+            logger.info(f"Successfully added fragment {index} to vector store")
+
+        except Exception as e:
+            logger.error(f"Error processing fragment {index}: {str(e)}")
+            raise
         """Process a single paper fragment"""
         try:
             text = (
                 f"Paper Title: {fragment.title}\n"
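Review note on the new _process_fragments: add_text_to_vector_store is a synchronous call (the in-line comment in the hunk acknowledges this), so awaiting asyncio.gather over coroutines that invoke it directly still performs the writes one at a time on the event loop. If real overlap is wanted without making the RAG worker async, the blocking call can be handed to a worker thread. Below is a minimal, self-contained sketch of that pattern; the add_text_to_vector_store stub is a stand-in for the real LlamaIndexRagWorker method, and asyncio.to_thread requires Python 3.9+:

    import asyncio
    import time

    def add_text_to_vector_store(text: str) -> None:
        time.sleep(0.2)  # stand-in for the real blocking vector-store write

    async def process_fragment(text: str, index: int) -> None:
        # to_thread moves the blocking call onto a worker thread, so the
        # gather() below genuinely overlaps the writes instead of serialising them
        await asyncio.to_thread(add_text_to_vector_store, text)
        print(f"fragment {index} stored")

    async def main() -> None:
        fragments = [f"fragment text {i}" for i in range(5)]
        await asyncio.gather(*(process_fragment(t, i) for i, t in enumerate(fragments)))

    asyncio.run(main())  # finishes in roughly 0.2 s rather than 1 s

Whether the underlying vector store tolerates concurrent writers would need to be verified before adopting this.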
@@ -124,62 +198,8 @@ class ArxivRagWorker:
             logger.error(f"Error processing fragment {index}: {str(e)}")
             raise
 
-    def _process_fragments(self, fragments: List[Fragment]) -> None:
-        """Process paper fragments in parallel using thread pool"""
-        if not fragments:
-            logger.warning("No fragments to process")
-            return
-
-        # First add paper overview
-        overview = {
-            "title": fragments[0].title,
-            "abstract": fragments[0].abstract,
-            "arxiv_id": fragments[0].arxiv_id,
-            "section_tree": fragments[0].section_tree,
-        }
-
-        overview_text = (
-            f"Paper Title: {overview['title']}\n"
-            f"ArXiv ID: {overview['arxiv_id']}\n"
-            f"Abstract: {overview['abstract']}\n"
-            f"Section Tree:{overview['section_tree']}\n"
-            f"Type: OVERVIEW"
-        )
-
-        try:
-            # Add overview synchronously
-            self.rag_worker.add_text_to_vector_store(overview_text)
-            logger.info(f"Added paper overview for {overview['arxiv_id']}")
-
-            # Process fragments in parallel using thread pool
-            with ThreadPoolExecutor(max_workers=10) as executor:
-                # Submit all fragments for processing
-                futures = [
-                    executor.submit(self._process_single_fragment, fragment, i)
-                    for i, fragment in enumerate(fragments)
-                ]
-
-                # Wait for all tasks to complete and handle any exceptions
-                for future in futures:
-                    try:
-                        future.result()
-                    except Exception as e:
-                        logger.error(f"Error processing fragment: {str(e)}")
-
-                logger.info(f"Processed {len(fragments)} fragments successfully")
-
-                # Save to local file for debugging
-                save_fragments_to_file(
-                    fragments,
-                    str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
-                )
-
-        except Exception as e:
-            logger.error(f"Error processing fragments: {str(e)}")
-            raise
-
-    def process_paper(self, arxiv_id: str) -> bool:
-        """Process paper main function - mixed sync/async version"""
+    async def process_paper(self, arxiv_id: str) -> bool:
+        """Main entry point for processing a paper"""
         try:
             arxiv_id = self._normalize_arxiv_id(arxiv_id)
             logger.info(f"Starting to process paper: {arxiv_id}")
@@ -190,29 +210,30 @@ class ArxivRagWorker:
                 logger.info(f"Paper {arxiv_id} already processed")
                 return True
 
             # Create processing task
             task = ProcessingTask(arxiv_id=arxiv_id)
             self.processing_queue[arxiv_id] = task
             task.status = "processing"
 
-            # Download and split paper using the sync wrapper
-            fragments = self._get_fragments_sync(arxiv_id)
+            async with self.semaphore:
+                # Download and split the paper
+                fragments = await self.arxiv_splitter.process(arxiv_id)
 
-            if not fragments:
-                raise ValueError(f"No fragments extracted from paper {arxiv_id}")
+                if not fragments:
+                    raise ValueError(f"No fragments extracted from paper {arxiv_id}")
 
-            logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
+                logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
 
-            # Process fragments
-            self._process_fragments(fragments)
+                # Process the fragments
+                await self._process_fragments(fragments)
 
-            # Mark as completed
-            paper_path.touch()
-            task.status = "completed"
-            task.fragments = fragments
+                # Mark as completed
+                paper_path.touch()
+                task.status = "completed"
+                task.fragments = fragments
 
-            logger.info(f"Successfully processed paper {arxiv_id}")
-            return True
+                logger.info(f"Successfully processed paper {arxiv_id}")
+                return True
 
         except Exception as e:
             logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
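The async with self.semaphore block above is what enforces MAX_CONCURRENT_PAPERS, and the lazy semaphore property pairs with it: on Python versions before 3.10, asyncio.Semaphore() bound itself to the event loop that was current at construction time, so deferring creation until a loop is running avoids cross-loop errors. A minimal sketch of the bounding behaviour (the constant value and the arXiv IDs are illustrative only):

    import asyncio

    MAX_CONCURRENT_PAPERS = 2  # illustrative; the real constant lives in the plugin module

    async def process_paper(arxiv_id: str, semaphore: asyncio.Semaphore) -> bool:
        async with semaphore:
            # at most MAX_CONCURRENT_PAPERS coroutines are inside this block
            # at any moment; the rest queue up at the `async with`
            await asyncio.sleep(0.1)  # stand-in for download + split + indexing
            return True

    async def main() -> None:
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS)
        ids = ["1706.03762", "2005.14165", "2303.08774", "2304.03442"]
        results = await asyncio.gather(*(process_paper(a, semaphore) for a in ids))
        print(results)  # [True, True, True, True]

    asyncio.run(main())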
ValueError(f"No fragments extracted from paper {arxiv_id}") + if not fragments: + raise ValueError(f"No fragments extracted from paper {arxiv_id}") - logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}") + logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}") - # Process fragments - self._process_fragments(fragments) + # 处理片段 + await self._process_fragments(fragments) - # Mark as completed - paper_path.touch() - task.status = "completed" - task.fragments = fragments + # 标记完成 + paper_path.touch() + task.status = "completed" + task.fragments = fragments - logger.info(f"Successfully processed paper {arxiv_id}") - return True + logger.info(f"Successfully processed paper {arxiv_id}") + return True except Exception as e: logger.error(f"Error processing paper {arxiv_id}: {str(e)}") @@ -220,8 +241,19 @@ class ArxivRagWorker: self.processing_queue[arxiv_id].status = "failed" self.processing_queue[arxiv_id].error = str(e) return False - def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool: - """Wait for paper processing to complete - synchronous version""" + + def _normalize_arxiv_id(self, input_str: str) -> str: + """规范化ArXiv ID""" + if 'arxiv.org/' in input_str.lower(): + if '/pdf/' in input_str: + arxiv_id = input_str.split('/pdf/')[-1] + else: + arxiv_id = input_str.split('/abs/')[-1] + return arxiv_id.split('v')[0].strip() + return input_str.split('v')[0].strip() + + async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool: + """等待论文处理完成""" try: start_time = datetime.now() while True: @@ -235,27 +267,16 @@ class ArxivRagWorker: if task.status == "failed": return False - # Check timeout + # 检查超时 if (datetime.now() - start_time).total_seconds() > timeout: logger.error(f"Processing paper {arxiv_id} timed out") return False - time.sleep(0.1) + await asyncio.sleep(0.1) except Exception as e: logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}") return False - def _normalize_arxiv_id(self, input_str: str) -> str: - """Normalize ArXiv ID""" - if 'arxiv.org/' in input_str.lower(): - if '/pdf/' in input_str: - arxiv_id = input_str.split('/pdf/')[-1] - else: - arxiv_id = input_str.split('/abs/')[-1] - return arxiv_id.split('v')[0].strip() - return input_str.split('v')[0].strip() - - def retrieve_and_generate(self, query: str) -> str: """检索相关内容并生成提示词""" try: @@ -332,10 +353,20 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: chatbot.append((txt, "正在处理论文,请稍等...")) yield from update_ui(chatbot=chatbot, history=history) - success = worker.process_paper(txt) - if success: - arxiv_id = worker._normalize_arxiv_id(txt) - success = worker.wait_for_paper(arxiv_id) + # 创建事件循环来处理异步调用 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # 运行异步处理函数 + success = loop.run_until_complete(worker.process_paper(txt)) + if success: + arxiv_id = worker._normalize_arxiv_id(txt) + success = loop.run_until_complete(worker.wait_for_paper(arxiv_id)) + if success: + # 执行自动分析 + yield from worker.auto_analyze_paper(chatbot, history, system_prompt) + finally: + loop.close() if not success: chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")