diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index 1713ef4a..771caae5 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -1,21 +1,21 @@
-import os
-import logging
 import asyncio
-from pathlib import Path
-from typing import List, Optional, Generator, Dict, Union
-from datetime import datetime
-from dataclasses import dataclass
+import logging
+import os
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
-import aiohttp
+from dataclasses import dataclass, field
+from pathlib import Path
+from threading import Lock as ThreadLock
+from typing import Generator
+from typing import List, Dict, Optional
 
-from shared_utils.fastapi_server import validate_path_safety
-from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
-from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
-from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
-
-from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
 from crazy_functions.crazy_utils import input_clipping
 from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
+from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
+from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
 
 # 全局常量配置
 MAX_HISTORY_ROUND = 5  # 最大历史对话轮数
@@ -32,6 +32,8 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class ProcessingTask:
@@ -40,142 +42,231 @@ class ProcessingTask:
     status: str = "pending"  # pending, processing, completed, failed
     error: Optional[str] = None
     fragments: List[Fragment] = None
+    start_time: float = field(default_factory=time.time)
 
 
 class ArxivRagWorker:
     def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
+        """初始化ArxivRagWorker"""
         self.user_name = user_name
         self.llm_kwargs = llm_kwargs
-        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # 存储最大并发数
         self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
-        # 初始化基础存储目录
+
+        # 初始化基础目录
        self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
+        self._setup_directories()
+
+        # 初始化处理状态
+
+        # 线程安全的计数器和集合
+        self._processing_lock = ThreadLock()
+        self._processed_fragments = set()
+        self._processed_count = 0
+
+        # 优化的线程池配置
+        cpu_count = os.cpu_count() or 1
+        self.thread_pool = ThreadPoolExecutor(
+            max_workers=min(32, cpu_count * 4),
+            thread_name_prefix="arxiv_worker"
+        )
+
+        # 批处理配置
+        self._batch_size = min(20, cpu_count * 2)  # 动态设置批大小
+        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
+        self._semaphore = None
+        self._loop = None
+
+        # 初始化处理队列
+        self.processing_queue = {}
+
+        # 初始化工作组件
+        self._init_workers()
+
+    def _setup_directories(self):
+        """设置工作目录"""
         if self.arxiv_id:
             self.checkpoint_dir = self.base_dir / self.arxiv_id
             self.vector_store_dir = self.checkpoint_dir / "vector_store"
             self.fragment_store_dir = self.checkpoint_dir / "fragments"
         else:
-            # 如果没有 arxiv_id,使用基础目录
             self.checkpoint_dir = self.base_dir
             self.vector_store_dir = self.base_dir / "vector_store"
             self.fragment_store_dir = self.base_dir / "fragments"
-        if os.path.exists(self.vector_store_dir):
-            self.loading = True
+        self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
+        self.loading = self.paper_path.exists()
+        # 创建必要的目录
+        for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
+            directory.mkdir(parents=True, exist_ok=True)
+            logger.info(f"Created directory: {directory}")
+
+    def _init_workers(self):
+        """初始化工作组件"""
+        try:
+            self.rag_worker = LlamaIndexRagWorker(
+                user_name=self.user_name,
+                llm_kwargs=self.llm_kwargs,
+                checkpoint_dir=str(self.vector_store_dir),
+                auto_load_checkpoint=True
+            )
+
+            self.arxiv_splitter = ArxivSplitter(
+                root_dir=str(self.checkpoint_dir / "arxiv_cache")
+            )
+        except Exception as e:
+            logger.error(f"Error initializing workers: {str(e)}")
+            raise
+
+    def _ensure_loop(self):
+        """确保存在事件循环"""
+        if threading.current_thread() is threading.main_thread():
+            if self._loop is None:
+                self._loop = asyncio.get_event_loop()
         else:
-            self.loading = False
-
-        # Create necessary directories
-        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
-        self.vector_store_dir.mkdir(parents=True, exist_ok=True)
-        self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
-
-        logger.info(f"Checkpoint directory: {self.checkpoint_dir}")
-        logger.info(f"Vector store directory: {self.vector_store_dir}")
-        logger.info(f"Fragment store directory: {self.fragment_store_dir}")
-
-        # 初始化处理队列和线程池
-        self.processing_queue = {}
-        self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
-
-        # 初始化RAG worker
-        self.rag_worker = LlamaIndexRagWorker(
-            user_name=user_name,
-            llm_kwargs=llm_kwargs,
-            checkpoint_dir=str(self.vector_store_dir),
-            auto_load_checkpoint=True
-        )
-
-        # 初始化arxiv splitter
-        # 初始化 arxiv splitter
-        self.arxiv_splitter = ArxivSplitter(
-            root_dir=str(self.checkpoint_dir / "arxiv_cache")
-        )
-        # 初始化处理队列和线程池
-        self._semaphore = None
-        self._loop = None
-
-    @property
-    def loop(self):
-        """获取当前事件循环"""
-        if self._loop is None:
-            self._loop = asyncio.get_event_loop()
+            try:
+                self._loop = asyncio.get_event_loop()
+            except RuntimeError:
+                self._loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(self._loop)
         return self._loop
 
     @property
     def semaphore(self):
-        """延迟创建 semaphore"""
+        """延迟创建semaphore"""
         if self._semaphore is None:
             self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
         return self._semaphore
-
-
 
     async def _process_fragments(self, fragments: List[Fragment]) -> None:
-        """并行处理论文片段"""
+        """优化的并行处理论文片段"""
         if not fragments:
             logger.warning("No fragments to process")
             return
 
-        # 首先添加论文概述
-        overview = {
-            "title": fragments[0].title,
-            "abstract": fragments[0].abstract,
-            "arxiv_id": fragments[0].arxiv_id,
-            "section_tree": fragments[0].section_tree,
+        start_time = time.time()
+        total_fragments = len(fragments)
+
+        try:
+            # 1. 处理论文概述
+            overview = self._create_overview(fragments[0])
+            overview_success = self._safe_add_to_vector_store_sync(overview['text'])
+            if not overview_success:
+                raise RuntimeError("Failed to add overview to vector store")
+
+            # 2. 并行处理片段
+            successful_fragments = await self._parallel_process_fragments(fragments)
+
+            # 3. 保存处理结果
+            if successful_fragments > 0:
+                await self._save_results(fragments, overview['arxiv_id'], successful_fragments)
+
+        except Exception as e:
+            logger.error(f"Error in fragment processing: {str(e)}")
+            raise
+        finally:
+            self._log_processing_stats(start_time, total_fragments)
+
+    def _create_overview(self, first_fragment: Fragment) -> Dict:
+        """创建论文概述"""
+        return {
+            'arxiv_id': first_fragment.arxiv_id,
+            'text': (
+                f"Paper Title: {first_fragment.title}\n"
+                f"ArXiv ID: {first_fragment.arxiv_id}\n"
+                f"Abstract: {first_fragment.abstract}\n"
+                f"Section Tree:{first_fragment.section_tree}\n"
+                f"Type: OVERVIEW"
+            )
         }
 
-        overview_text = (
-            f"Paper Title: {overview['title']}\n"
-            f"ArXiv ID: {overview['arxiv_id']}\n"
-            f"Abstract: {overview['abstract']}\n"
-            f"Section Tree:{overview['section_tree']}\n"
-            f"Type: OVERVIEW"
-        )
+    async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
+        """并行处理所有片段"""
+        successful_count = 0
+        loop = self._ensure_loop()
+        for i in range(0, len(fragments), self._batch_size):
+            batch = fragments[i:i + self._batch_size]
+            batch_futures = []
+
+            for j, fragment in enumerate(batch):
+                if not self._is_fragment_processed(fragment, i + j):
+                    future = loop.run_in_executor(
+                        self.thread_pool,
+                        self._process_single_fragment_sync,
+                        fragment,
+                        i + j
+                    )
+                    batch_futures.append(future)
+
+            if batch_futures:
+                results = await asyncio.gather(*batch_futures, return_exceptions=True)
+                successful_count += sum(1 for r in results if isinstance(r, bool) and r)
+
+        return successful_count
+
+    def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
+        """检查片段是否已处理"""
+        fragment_id = f"{fragment.arxiv_id}_{index}"
+        with self._processing_lock:
+            return fragment_id in self._processed_fragments
+
+    def _safe_add_to_vector_store_sync(self, text: str) -> bool:
+        """线程安全的向量存储添加"""
+        with self._processing_lock:
+            try:
+                self.rag_worker.add_text_to_vector_store(text)
+                return True
+            except Exception as e:
+                logger.error(f"Error adding to vector store: {str(e)}")
+                return False
+
+    def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
+        """处理单个片段"""
+        fragment_id = f"{fragment.arxiv_id}_{index}"
         try:
-            # 同步添加概述
-            self.rag_worker.add_text_to_vector_store(overview_text)
-            logger.info(f"Added paper overview for {overview['arxiv_id']}")
-
-            # 并行处理其余片段
-            tasks = []
-            for i, fragment in enumerate(fragments):
-                tasks.append(self._process_single_fragment(fragment, i))
-            await asyncio.gather(*tasks)
-            logger.info(f"Processed {len(fragments)} fragments successfully")
-
-
-            # 保存到本地文件用于调试
-            save_fragments_to_file(
-                fragments,
-                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
-            )
-
-        except Exception as e:
-            logger.error(f"Error processing fragments: {str(e)}")
-            raise
-
-    async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
-        """处理单个论文片段(改为异步方法)"""
-        try:
-            text = (
-                f"Paper Title: {fragment.title}\n"
-                f"Section: {fragment.current_section}\n"
-                f"Content: {fragment.content}\n"
-                f"Bibliography: {fragment.bibliography}\n"
-                f"Type: FRAGMENT"
-            )
-            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
-            # 如果 add_text_to_vector_store 是异步的,使用 await
-            self.rag_worker.add_text_to_vector_store(text)
-            logger.info(f"Successfully added fragment {index} to vector store")
-
+            text = self._build_fragment_text(fragment)
+            if self._safe_add_to_vector_store_sync(text):
+                with self._processing_lock:
+                    self._processed_fragments.add(fragment_id)
+                    self._processed_count += 1
+                    logger.info(f"Successfully processed fragment {index}")
+                return True
+            return False
         except Exception as e:
             logger.error(f"Error processing fragment {index}: {str(e)}")
-            raise    """处理单个论文片段"""
+            return False
+
+    def _build_fragment_text(self, fragment: Fragment) -> str:
+        """构建片段文本"""
+        return "".join([
+            f"Paper Title: {fragment.title}\n",
+            f"Section: {fragment.current_section}\n",
+            f"Content: {fragment.content}\n",
+            f"Bibliography: {fragment.bibliography}\n",
+            "Type: FRAGMENT"
+        ])
+
+    async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
+        """保存处理结果"""
+        if successful_count > 0:
+            loop = self._ensure_loop()
+            await loop.run_in_executor(
+                self.thread_pool,
+                save_fragments_to_file,
+                fragments,
+                str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
+            )
+
+    def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
+        """记录处理统计信息"""
+        elapsed_time = time.time() - start_time
+        processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
+        logger.info(
+            f"Processed {self._processed_count}/{total_fragments} fragments "
+            f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
+        )
 
     async def process_paper(self, arxiv_id: str) -> bool:
         """处理论文主函数"""
@@ -183,47 +274,61 @@ class ArxivRagWorker:
             arxiv_id = self._normalize_arxiv_id(arxiv_id)
             logger.info(f"Starting to process paper: {arxiv_id}")
 
-            paper_path = self.checkpoint_dir / f"{arxiv_id}.processed"
-
-            if paper_path.exists():
+            if self.paper_path.exists():
                 logger.info(f"Paper {arxiv_id} already processed")
                 return True
 
-            # 创建处理任务
-            task = ProcessingTask(arxiv_id=arxiv_id)
-            self.processing_queue[arxiv_id] = task
-            task.status = "processing"
+            task = self._create_processing_task(arxiv_id)
 
-            async with self.semaphore:
-                # 下载和分割论文
-                fragments = await self.arxiv_splitter.process(arxiv_id)
+            try:
+                async with self.semaphore:
+                    fragments = await self.arxiv_splitter.process(arxiv_id)
+                    if not fragments:
+                        raise ValueError(f"No fragments extracted from paper {arxiv_id}")
 
-                if not fragments:
-                    raise ValueError(f"No fragments extracted from paper {arxiv_id}")
+                    logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
+                    await self._process_fragments(fragments)
 
-                logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
+                    self._complete_task(task, fragments, self.paper_path)
+                    return True
 
-                # 处理片段
-                await self._process_fragments(fragments)
-
-                # 标记完成
-                paper_path.touch()
-                task.status = "completed"
-                task.fragments = fragments
-
-                logger.info(f"Successfully processed paper {arxiv_id}")
-                return True
+            except Exception as e:
+                self._fail_task(task, str(e))
+                raise
 
         except Exception as e:
             logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
-            if arxiv_id in self.processing_queue:
-                self.processing_queue[arxiv_id].status = "failed"
-                self.processing_queue[arxiv_id].error = str(e)
             return False
 
+    def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
+        """创建处理任务"""
+        task = ProcessingTask(arxiv_id=arxiv_id)
+        with self._processing_lock:
+            self.processing_queue[arxiv_id] = task
+            task.status = "processing"
+        return task
+
+    def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
+        """完成任务处理"""
+        with self._processing_lock:
+            task.status = "completed"
+            task.fragments = fragments
+            paper_path.touch()
+            logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")
+
+    def _fail_task(self, task: ProcessingTask, error: str) -> None:
+        """任务失败处理"""
+        with self._processing_lock:
+            task.status = "failed"
+            task.error = error
+
     def _normalize_arxiv_id(self, input_str: str) -> str:
         """规范化ArXiv ID"""
-        if 'arxiv.org/' in input_str.lower():
+        if not input_str:
+            return ""
+
+        input_str = input_str.strip().lower()
+        if 'arxiv.org/' in input_str:
             if '/pdf/' in input_str:
                 arxiv_id = input_str.split('/pdf/')[-1]
             else:
@@ -233,21 +338,20 @@ class ArxivRagWorker:
 
     async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
         """等待论文处理完成"""
+        start_time = time.time()
         try:
-            start_time = datetime.now()
             while True:
-                task = self.processing_queue.get(arxiv_id)
+                with self._processing_lock:
+                    task = self.processing_queue.get(arxiv_id)
                 if not task:
                     return False
 
                 if task.status == "completed":
                     return True
-
                 if task.status == "failed":
                     return False
 
-                # 检查超时
-                if (datetime.now() - start_time).total_seconds() > timeout:
+                if time.time() - start_time > timeout:
                     logger.error(f"Processing paper {arxiv_id} timed out")
                     return False
 
@@ -319,7 +423,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
         web_port: Web端口
     """
     # 初始化时,提示用户需要 arxiv ID/URL
-    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0','1', '2')):
+    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
         chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
         yield from update_ui(chatbot=chatbot, history=history)
         return
@@ -332,32 +436,51 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
         chatbot.append((txt, "正在处理论文,请稍等..."))
         yield from update_ui(chatbot=chatbot, history=history)
 
-        # 创建事件循环来处理异步调用
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
         try:
-            # 运行异步处理函数
-            success = loop.run_until_complete(worker.process_paper(txt))
-            if success:
-                arxiv_id = worker._normalize_arxiv_id(txt)
-                success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
-                # if success:
-                #     # 执行自动分析
-                #     yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
-        finally:
-            loop.close()
+            # 创建新的事件循环
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
 
-        if not success:
-            chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
+            # 使用超时控制
+            success = False
+            try:
+                # 设置超时时间为5分钟
+                success = loop.run_until_complete(
+                    asyncio.wait_for(worker.process_paper(txt), timeout=300)
+                )
+                if success:
+                    arxiv_id = worker._normalize_arxiv_id(txt)
+                    success = loop.run_until_complete(
+                        asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
+                    )
+                    if success:
+                        chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
+                    else:
+                        chatbot[-1] = (txt, "论文处理超时,请重试。")
+                else:
+                    chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
+            except asyncio.TimeoutError:
+                chatbot[-1] = (txt, "论文处理超时,请重试。")
+                success = False
+            finally:
+                loop.close()
+
+            if not success:
+                yield from update_ui(chatbot=chatbot, history=history)
+                return
+
+        except Exception as e:
+            logger.error(f"Error in main process: {str(e)}")
+            chatbot[-1] = (txt, f"处理过程中发生错误: {str(e)}")
             yield from update_ui(chatbot=chatbot, history=history)
             return
 
         yield from update_ui(chatbot=chatbot, history=history)
         return
 
-    # 处理用户询问的情况
     # 获取用户询问指令
-    user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?")
+    user_query = plugin_kwargs.get("advanced_arg",
+                                   "What is the main research question or problem addressed in this paper?")
     if not user_query:
         user_query = "What is the main research question or problem addressed in this paper about graph attention network?"
         # chatbot.append((txt, "请提供您的问题。"))
@@ -380,7 +503,9 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
         yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
         if len(user_query) > REMEMBER_PREVIEW:
             HALF = REMEMBER_PREVIEW // 2
-            query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[-HALF:]
+            query_to_remember = user_query[
+                                :HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[
+                                         -HALF:]
         else:
             query_to_remember = query_clip
     else:
@@ -412,6 +537,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
 
     yield from update_ui(chatbot=chatbot, history=history)
 
+
 if __name__ == "__main__":
     # 测试代码
     llm_kwargs = {
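
Illustrative usage sketch (not part of the patch above): a minimal, hypothetical driver showing how the refactored ArxivRagWorker.process_paper coroutine could be run from synchronous code, mirroring the event-loop creation and asyncio.wait_for timeout handling that the patch adds to Arxiv论文对话. The user_name value, the empty llm_kwargs, and the example arXiv ID are placeholders, not project defaults.

import asyncio

from crazy_functions.Arxiv_论文对话 import ArxivRagWorker


def ingest_paper(arxiv_id: str) -> bool:
    # Placeholder arguments; in the plugin, llm_kwargs comes from the runtime config.
    worker = ArxivRagWorker(user_name="demo_user", llm_kwargs={}, arxiv_id=arxiv_id)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Cap download, splitting and indexing at 5 minutes, as the plugin entry point does.
        return loop.run_until_complete(
            asyncio.wait_for(worker.process_paper(arxiv_id), timeout=300)
        )
    except asyncio.TimeoutError:
        return False
    finally:
        loop.close()


if __name__ == "__main__":
    # Any arXiv ID or URL accepted by _normalize_arxiv_id works here.
    print(ingest_paper("1706.03762"))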