diff --git a/crazy_functional.py b/crazy_functional.py index c4df777c..001e6998 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -619,23 +619,6 @@ def get_crazy_functions(): logger.error(trimmed_format_exc()) logger.error("Load function plugin failed") - try: - from crazy_functions.Arxiv_论文对话 import Arxiv论文对话 - - function_plugins.update( - { - "Arxiv论文对话": { - "Group": "对话", - "Color": "stop", - "AsButton": False, - "Info": "将问答数据记录到向量库中,作为长期参考。", - "Function": HotReload(Arxiv论文对话), - }, - } - ) - except: - logger.error(trimmed_format_exc()) - logger.error("Load function plugin failed") diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py index debc2c77..2c29503f 100644 --- a/crazy_functions/Arxiv_论文对话.py +++ b/crazy_functions/Arxiv_论文对话.py @@ -43,22 +43,42 @@ class ProcessingTask: class ArxivRagWorker: - def __init__(self, user_name: str, llm_kwargs: Dict): + def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None): self.user_name = user_name self.llm_kwargs = llm_kwargs + self.max_concurrent_papers = MAX_CONCURRENT_PAPERS # 存储最大并发数 + self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None - # 初始化存储目录 - self.checkpoint_dir = Path(get_log_folder(user_name, plugin_name='rag_cache')) - self.vector_store_dir = self.checkpoint_dir / "vector_store" - self.fragment_store_dir = self.checkpoint_dir / "fragments" + # 初始化基础存储目录 + self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache')) + if os.path.exists(self.base_dir): + self.loading = True + else: + self.loading = False + # 如果提供了 arxiv_id,创建针对该论文的子目录 + if self.arxiv_id: + self.checkpoint_dir = self.base_dir / self.arxiv_id + self.vector_store_dir = self.checkpoint_dir / "vector_store" + self.fragment_store_dir = self.checkpoint_dir / "fragments" + else: + # 如果没有 arxiv_id,使用基础目录 + self.checkpoint_dir = self.base_dir + self.vector_store_dir = self.base_dir / "vector_store" + self.fragment_store_dir = self.base_dir / "fragments" # 创建必要的目录 + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.vector_store_dir.mkdir(parents=True, exist_ok=True) self.fragment_store_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Checkpoint directory: {self.checkpoint_dir}") logger.info(f"Vector store directory: {self.vector_store_dir}") logger.info(f"Fragment store directory: {self.fragment_store_dir}") + # 初始化处理队列和线程池 + self.processing_queue = {} + self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) + # 初始化RAG worker self.rag_worker = LlamaIndexRagWorker( user_name=user_name, @@ -68,15 +88,30 @@ class ArxivRagWorker: ) # 初始化arxiv splitter + # 初始化 arxiv splitter self.arxiv_splitter = ArxivSplitter( char_range=(1000, 1200), root_dir=str(self.checkpoint_dir / "arxiv_cache") ) + # 初始化处理队列和线程池 + self._semaphore = None + self._loop = None + + @property + def loop(self): + """获取当前事件循环""" + if self._loop is None: + self._loop = asyncio.get_event_loop() + return self._loop + + @property + def semaphore(self): + """延迟创建 semaphore""" + if self._semaphore is None: + self._semaphore = asyncio.Semaphore(self.max_concurrent_papers) + return self._semaphore + - # 初始化并行处理组件 - self.processing_queue = {} - self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS) - self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) async def _process_fragments(self, fragments: List[Fragment]) -> None: """并行处理论文片段""" @@ -106,17 +141,11 @@ class ArxivRagWorker: # 并行处理其余片段 tasks = [] for i, fragment in enumerate(fragments): - task = asyncio.get_event_loop().run_in_executor( - self.thread_pool, - self._process_single_fragment, - fragment, - i - ) - tasks.append(task) - + tasks.append(self._process_single_fragment(fragment, i)) await asyncio.gather(*tasks) logger.info(f"Processed {len(fragments)} fragments successfully") + # 保存到本地文件用于调试 save_fragments_to_file( fragments, @@ -127,8 +156,26 @@ class ArxivRagWorker: logger.error(f"Error processing fragments: {str(e)}") raise - def _process_single_fragment(self, fragment: Fragment, index: int) -> None: - """处理单个论文片段""" + async def _process_single_fragment(self, fragment: Fragment, index: int) -> None: + """处理单个论文片段(改为异步方法)""" + try: + text = ( + f"Paper Title: {fragment.title}\n" + f"ArXiv ID: {fragment.arxiv_id}\n" + f"Section: {fragment.section}\n" + f"Fragment Index: {index}\n" + f"Content: {fragment.content}\n" + f"Type: FRAGMENT" + ) + + logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}") + # 如果 add_text_to_vector_store 是异步的,使用 await + self.rag_worker.add_text_to_vector_store(text) + logger.info(f"Successfully added fragment {index} to vector store") + + except Exception as e: + logger.error(f"Error processing fragment {index}: {str(e)}") + raise """处理单个论文片段""" try: text = ( f"Paper Title: {fragment.title}\n" @@ -289,16 +336,16 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: web_port: Web端口 """ # 初始化时,提示用户需要 arxiv ID/URL - if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')): + if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0','1', '2')): chatbot.append((txt, "请先提供Arxiv论文链接或ID。")) yield from update_ui(chatbot=chatbot, history=history) return user_name = chatbot.get_user() - worker = ArxivRagWorker(user_name, llm_kwargs) + worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt) # 处理新论文的情况 - if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')): + if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading: chatbot.append((txt, "正在处理论文,请稍等...")) yield from update_ui(chatbot=chatbot, history=history) @@ -327,7 +374,8 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: # 处理用户询问的情况 # 获取用户询问指令 - user_query = plugin_kwargs.get("advanced_arg", "") + user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?") + # user_query = "What is the main research question or problem addressed in this paper about graph attention network?" if not user_query: chatbot.append((txt, "请提供您的问题。")) yield from update_ui(chatbot=chatbot, history=history) diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py index 544a8c1d..73d95d9b 100644 --- a/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py @@ -14,6 +14,7 @@ class ArxivFragment: section: str # 所属章节 is_appendix: bool # 是否是附录 importance: float = 1.0 # 重要性得分 + arxiv_id: str = "" # 添加 arxiv_id 属性 @staticmethod def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment':