
This commit is contained in:
lbykkkk
2024-11-23 19:13:10 +00:00
Parent e1dc600030
Commit 50dbff3a14


@@ -1,21 +1,21 @@
import os
import logging
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from threading import Lock as ThreadLock
from typing import List, Optional, Generator, Dict, Union
import aiohttp
from shared_utils.fastapi_server import validate_path_safety
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
# Global constant configuration
MAX_HISTORY_ROUND = 5  # Maximum number of history dialogue rounds
@@ -32,6 +32,8 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
@dataclass
class ProcessingTask:
@@ -40,142 +42,231 @@ class ProcessingTask:
status: str = "pending" # pending, processing, completed, failed
error: Optional[str] = None
fragments: List[Fragment] = None
start_time: float = field(default_factory=time.time)
class ArxivRagWorker:
def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
"""初始化ArxivRagWorker"""
self.user_name = user_name
self.llm_kwargs = llm_kwargs
self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # Maximum number of papers processed concurrently
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
# Initialize the base storage directory
self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
self._setup_directories()
# Initialize processing state
# Thread-safe counters and collections
self._processing_lock = ThreadLock()
self._processed_fragments = set()
self._processed_count = 0
# Optimized thread pool configuration
cpu_count = os.cpu_count() or 1
self.thread_pool = ThreadPoolExecutor(
max_workers=min(32, cpu_count * 4),
thread_name_prefix="arxiv_worker"
)
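# The cap of 32 workers matches the upper bound the stdlib applies when max_workers is omitted;
# the cpu_count * 4 multiplier is a tuning heuristic that assumes the work is mostly I/O-bound
# (embedding/network calls) rather than CPU-bound.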
# Batch processing configuration
self._batch_size = min(20, cpu_count * 2)  # dynamically sized batches
self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
self._semaphore = None
self._loop = None
# Initialize the processing queue
self.processing_queue = {}
# Initialize worker components
self._init_workers()
def _setup_directories(self):
"""设置工作目录"""
if self.arxiv_id:
self.checkpoint_dir = self.base_dir / self.arxiv_id
self.vector_store_dir = self.checkpoint_dir / "vector_store"
self.fragment_store_dir = self.checkpoint_dir / "fragments"
else:
# If there is no arxiv_id, use the base directory
self.checkpoint_dir = self.base_dir
self.vector_store_dir = self.base_dir / "vector_store"
self.fragment_store_dir = self.base_dir / "fragments"
if os.path.exists(self.vector_store_dir):
self.loading = True
self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
self.loading = self.paper_path.exists()
# Create the required directories
for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
directory.mkdir(parents=True, exist_ok=True)
logger.info(f"Created directory: {directory}")
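# mkdir(parents=True, exist_ok=True) is idempotent, so re-initialising a worker for a cached
# paper does not fail when these directories already exist.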
def _init_workers(self):
"""初始化工作组件"""
try:
self.rag_worker = LlamaIndexRagWorker(
user_name=self.user_name,
llm_kwargs=self.llm_kwargs,
checkpoint_dir=str(self.vector_store_dir),
auto_load_checkpoint=True
)
self.arxiv_splitter = ArxivSplitter(
root_dir=str(self.checkpoint_dir / "arxiv_cache")
)
except Exception as e:
logger.error(f"Error initializing workers: {str(e)}")
raise
def _ensure_loop(self):
"""确保存在事件循环"""
if threading.current_thread() is threading.main_thread():
if self._loop is None:
self._loop = asyncio.get_event_loop()
else:
self.loading = False
# Create necessary directories
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.vector_store_dir.mkdir(parents=True, exist_ok=True)
self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Checkpoint directory: {self.checkpoint_dir}")
logger.info(f"Vector store directory: {self.vector_store_dir}")
logger.info(f"Fragment store directory: {self.fragment_store_dir}")
# Initialize the processing queue and thread pool
self.processing_queue = {}
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Initialize the RAG worker
self.rag_worker = LlamaIndexRagWorker(
user_name=user_name,
llm_kwargs=llm_kwargs,
checkpoint_dir=str(self.vector_store_dir),
auto_load_checkpoint=True
)
# Initialize the arxiv splitter
self.arxiv_splitter = ArxivSplitter(
root_dir=str(self.checkpoint_dir / "arxiv_cache")
)
# Initialize the processing queue and thread pool
self._semaphore = None
self._loop = None
@property
def loop(self):
"""获取当前事件循环"""
if self._loop is None:
self._loop = asyncio.get_event_loop()
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
return self._loop
@property
def semaphore(self):
"""延迟创建 semaphore"""
"""延迟创建semaphore"""
if self._semaphore is None:
self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
return self._semaphore
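# The semaphore is created lazily so that it is bound to the event loop actually running the
# processing coroutines; constructing it in __init__ could tie it to a different loop.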
async def _process_fragments(self, fragments: List[Fragment]) -> None:
"""并行处理论文片段"""
"""优化的并行处理论文片段"""
if not fragments:
logger.warning("No fragments to process")
return
# First add the paper overview
overview = {
"title": fragments[0].title,
"abstract": fragments[0].abstract,
"arxiv_id": fragments[0].arxiv_id,
"section_tree": fragments[0].section_tree,
start_time = time.time()
total_fragments = len(fragments)
try:
# 1. Process the paper overview
overview = self._create_overview(fragments[0])
overview_success = self._safe_add_to_vector_store_sync(overview['text'])
if not overview_success:
raise RuntimeError("Failed to add overview to vector store")
# 2. Process fragments in parallel
successful_fragments = await self._parallel_process_fragments(fragments)
# 3. Save the processing results
if successful_fragments > 0:
await self._save_results(fragments, overview['arxiv_id'], successful_fragments)
except Exception as e:
logger.error(f"Error in fragment processing: {str(e)}")
raise
finally:
self._log_processing_stats(start_time, total_fragments)
def _create_overview(self, first_fragment: Fragment) -> Dict:
"""创建论文概述"""
return {
'arxiv_id': first_fragment.arxiv_id,
'text': (
f"Paper Title: {first_fragment.title}\n"
f"ArXiv ID: {first_fragment.arxiv_id}\n"
f"Abstract: {first_fragment.abstract}\n"
f"Section Tree:{first_fragment.section_tree}\n"
f"Type: OVERVIEW"
)
}
overview_text = (
f"Paper Title: {overview['title']}\n"
f"ArXiv ID: {overview['arxiv_id']}\n"
f"Abstract: {overview['abstract']}\n"
f"Section Tree:{overview['section_tree']}\n"
f"Type: OVERVIEW"
)
async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
"""并行处理所有片段"""
successful_count = 0
loop = self._ensure_loop()
for i in range(0, len(fragments), self._batch_size):
batch = fragments[i:i + self._batch_size]
batch_futures = []
for j, fragment in enumerate(batch):
if not self._is_fragment_processed(fragment, i + j):
future = loop.run_in_executor(
self.thread_pool,
self._process_single_fragment_sync,
fragment,
i + j
)
batch_futures.append(future)
if batch_futures:
results = await asyncio.gather(*batch_futures, return_exceptions=True)
successful_count += sum(1 for r in results if isinstance(r, bool) and r)
return successful_count
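# return_exceptions=True keeps a single failed fragment from cancelling the rest of the batch;
# exceptions simply do not count towards successful_count.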
def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
"""检查片段是否已处理"""
fragment_id = f"{fragment.arxiv_id}_{index}"
with self._processing_lock:
return fragment_id in self._processed_fragments
def _safe_add_to_vector_store_sync(self, text: str) -> bool:
"""线程安全的向量存储添加"""
with self._processing_lock:
try:
self.rag_worker.add_text_to_vector_store(text)
return True
except Exception as e:
logger.error(f"Error adding to vector store: {str(e)}")
return False
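# The lock serialises writes because add_text_to_vector_store is assumed not to be thread-safe
# when called concurrently from the pool workers.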
def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
"""处理单个片段"""
fragment_id = f"{fragment.arxiv_id}_{index}"
try:
# Add the overview synchronously
self.rag_worker.add_text_to_vector_store(overview_text)
logger.info(f"Added paper overview for {overview['arxiv_id']}")
# Process the remaining fragments in parallel
tasks = []
for i, fragment in enumerate(fragments):
tasks.append(self._process_single_fragment(fragment, i))
await asyncio.gather(*tasks)
logger.info(f"Processed {len(fragments)} fragments successfully")
# Save to a local file for debugging
save_fragments_to_file(
fragments,
str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
)
except Exception as e:
logger.error(f"Error processing fragments: {str(e)}")
raise
async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
"""处理单个论文片段(改为异步方法)"""
try:
text = (
f"Paper Title: {fragment.title}\n"
f"Section: {fragment.current_section}\n"
f"Content: {fragment.content}\n"
f"Bibliography: {fragment.bibliography}\n"
f"Type: FRAGMENT"
)
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
# If add_text_to_vector_store is asynchronous, use await
self.rag_worker.add_text_to_vector_store(text)
logger.info(f"Successfully added fragment {index} to vector store")
text = self._build_fragment_text(fragment)
if self._safe_add_to_vector_store_sync(text):
with self._processing_lock:
self._processed_fragments.add(fragment_id)
self._processed_count += 1
logger.info(f"Successfully processed fragment {index}")
return True
return False
except Exception as e:
logger.error(f"Error processing fragment {index}: {str(e)}")
raise
return False
def _build_fragment_text(self, fragment: Fragment) -> str:
"""构建片段文本"""
return "".join([
f"Paper Title: {fragment.title}\n",
f"Section: {fragment.current_section}\n",
f"Content: {fragment.content}\n",
f"Bibliography: {fragment.bibliography}\n",
"Type: FRAGMENT"
])
async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
"""保存处理结果"""
if successful_count > 0:
loop = self._ensure_loop()
await loop.run_in_executor(
self.thread_pool,
save_fragments_to_file,
fragments,
str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
)
def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
"""记录处理统计信息"""
elapsed_time = time.time() - start_time
processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
logger.info(
f"Processed {self._processed_count}/{total_fragments} fragments "
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
)
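# Note: self._processed_count is never reset between papers, so when one worker instance
# processes several papers this ratio can exceed total_fragments for the current paper.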
async def process_paper(self, arxiv_id: str) -> bool:
"""处理论文主函数"""
@@ -183,47 +274,61 @@ class ArxivRagWorker:
arxiv_id = self._normalize_arxiv_id(arxiv_id)
logger.info(f"Starting to process paper: {arxiv_id}")
paper_path = self.checkpoint_dir / f"{arxiv_id}.processed"
if paper_path.exists():
if self.paper_path.exists():
logger.info(f"Paper {arxiv_id} already processed")
return True
# Create the processing task
task = ProcessingTask(arxiv_id=arxiv_id)
self.processing_queue[arxiv_id] = task
task.status = "processing"
task = self._create_processing_task(arxiv_id)
async with self.semaphore:
# Download and split the paper
fragments = await self.arxiv_splitter.process(arxiv_id)
try:
async with self.semaphore:
fragments = await self.arxiv_splitter.process(arxiv_id)
if not fragments:
raise ValueError(f"No fragments extracted from paper {arxiv_id}")
logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
await self._process_fragments(fragments)
logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
self._complete_task(task, fragments, self.paper_path)
return True
# Process the fragments
await self._process_fragments(fragments)
# Mark as completed
paper_path.touch()
task.status = "completed"
task.fragments = fragments
logger.info(f"Successfully processed paper {arxiv_id}")
return True
except Exception as e:
self._fail_task(task, str(e))
raise
except Exception as e:
logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
if arxiv_id in self.processing_queue:
self.processing_queue[arxiv_id].status = "failed"
self.processing_queue[arxiv_id].error = str(e)
return False
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
"""创建处理任务"""
task = ProcessingTask(arxiv_id=arxiv_id)
with self._processing_lock:
self.processing_queue[arxiv_id] = task
task.status = "processing"
return task
def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
"""完成任务处理"""
with self._processing_lock:
task.status = "completed"
task.fragments = fragments
paper_path.touch()
logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")
def _fail_task(self, task: ProcessingTask, error: str) -> None:
"""任务失败处理"""
with self._processing_lock:
task.status = "failed"
task.error = error
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
if not input_str:
return ""
input_str = input_str.strip().lower()
if 'arxiv.org/' in input_str:
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
@@ -233,21 +338,20 @@ class ArxivRagWorker:
async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
"""等待论文处理完成"""
start_time = time.time()
try:
start_time = datetime.now()
while True:
task = self.processing_queue.get(arxiv_id)
with self._processing_lock:
task = self.processing_queue.get(arxiv_id)
if not task:
return False
if task.status == "completed":
return True
if task.status == "failed":
return False
# Check for timeout
if (datetime.now() - start_time).total_seconds() > timeout:
if time.time() - start_time > timeout:
logger.error(f"Processing paper {arxiv_id} timed out")
return False
@@ -319,7 +423,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
web_port: Web port
"""
# On first call, prompt the user to provide an arxiv ID/URL
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
yield from update_ui(chatbot=chatbot, history=history)
return
@@ -332,32 +436,51 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
chatbot.append((txt, "正在处理论文,请稍等..."))
yield from update_ui(chatbot=chatbot, history=history)
# Create an event loop to handle the async calls
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Run the async processing function
success = loop.run_until_complete(worker.process_paper(txt))
if success:
arxiv_id = worker._normalize_arxiv_id(txt)
success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
# if success:
# # Run automatic analysis
# yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
finally:
loop.close()
# Create a new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if not success:
chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
# Use timeout control
success = False
try:
# Set the timeout to 5 minutes
success = loop.run_until_complete(
asyncio.wait_for(worker.process_paper(txt), timeout=300)
)
if success:
arxiv_id = worker._normalize_arxiv_id(txt)
success = loop.run_until_complete(
asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
)
if success:
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
else:
chatbot[-1] = (txt, "论文处理超时,请重试。")
else:
chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
except asyncio.TimeoutError:
chatbot[-1] = (txt, "论文处理超时,请重试。")
success = False
finally:
loop.close()
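# A fresh event loop is created and closed per request because this plugin runs inside a
# synchronous generator; reusing a loop already set for the thread could conflict with it.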
if not success:
yield from update_ui(chatbot=chatbot, history=history)
return
except Exception as e:
logger.error(f"Error in main process: {str(e)}")
chatbot[-1] = (txt, f"处理过程中发生错误: {str(e)}")
yield from update_ui(chatbot=chatbot, history=history)
return
yield from update_ui(chatbot=chatbot, history=history)
return
# Handle the case where the user is asking a question
# Get the user's query instruction
user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?")
if not user_query:
user_query = "What is the main research question or problem addressed in this paper about graph attention network?"
# chatbot.append((txt, "请提供您的问题。"))
@@ -380,7 +503,9 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
if len(user_query) > REMEMBER_PREVIEW:
HALF = REMEMBER_PREVIEW // 2
query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[-HALF:]
else:
query_to_remember = query_clip
else:
@@ -412,6 +537,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
yield from update_ui(chatbot=chatbot, history=history)
if __name__ == "__main__":
# Test code
llm_kwargs = {