Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 06:26:47 +00:00
This commit is contained in:
@@ -1,21 +1,21 @@
import os
import logging
import asyncio
from pathlib import Path
from typing import List, Optional, Generator, Dict, Union
from datetime import datetime
from dataclasses import dataclass
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
import aiohttp
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock as ThreadLock
from typing import Generator
from typing import List, Dict, Optional

from shared_utils.fastapi_server import validate_path_safety
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment

from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg

# Global constant configuration
MAX_HISTORY_ROUND = 5  # maximum number of history dialogue rounds
@@ -32,6 +32,8 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)


@dataclass
class ProcessingTask:
@@ -40,142 +42,231 @@ class ProcessingTask:
    status: str = "pending"  # pending, processing, completed, failed
    error: Optional[str] = None
    fragments: List[Fragment] = None
    start_time: float = field(default_factory=time.time)

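The hunk opens mid-class, so the leading field and docstring of ProcessingTask sit outside the visible range. A minimal sketch of how the complete dataclass plausibly reads; the `arxiv_id` field is an assumption inferred from `ProcessingTask(arxiv_id=arxiv_id)` later in the diff, and a `field(default_factory=list)` default for `fragments` would avoid None checks but is not what the diff shows:

import time
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ProcessingTask:
    arxiv_id: str                     # assumed: set via ProcessingTask(arxiv_id=...)
    status: str = "pending"           # pending, processing, completed, failed
    error: Optional[str] = None
    fragments: Optional[List] = None  # populated once processing succeeds
    start_time: float = field(default_factory=time.time)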
class ArxivRagWorker:
    def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
        """Initialize the ArxivRagWorker"""
        self.user_name = user_name
        self.llm_kwargs = llm_kwargs
        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # store the maximum concurrency
        self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None

        # Initialize the base storage directory

        # Initialize the base directory
        self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
        self._setup_directories()

        # Initialize processing state

        # Thread-safe counter and set
        self._processing_lock = ThreadLock()
        self._processed_fragments = set()
        self._processed_count = 0

        # Optimized thread pool configuration
        cpu_count = os.cpu_count() or 1
        self.thread_pool = ThreadPoolExecutor(
            max_workers=min(32, cpu_count * 4),
            thread_name_prefix="arxiv_worker"
        )

        # Batch processing configuration
        self._batch_size = min(20, cpu_count * 2)  # set the batch size dynamically
        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
        self._semaphore = None
        self._loop = None

        # Initialize the processing queue
        self.processing_queue = {}

        # Initialize the worker components
        self._init_workers()

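The pool above is sized for I/O-bound work: a multiple of the core count, capped so large hosts do not spawn dozens of mostly idle threads. A self-contained sketch of the same heuristic (the workload is a stand-in):

import os
from concurrent.futures import ThreadPoolExecutor

cpu_count = os.cpu_count() or 1
# I/O-bound tasks tolerate far more threads than cores; the cap of 32
# bounds thread-creation overhead on large machines.
with ThreadPoolExecutor(max_workers=min(32, cpu_count * 4),
                        thread_name_prefix="arxiv_worker") as pool:
    results = list(pool.map(lambda x: x * x, range(8)))  # stand-in workload
print(results)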
    def _setup_directories(self):
        """Set up the working directories"""

        if self.arxiv_id:
            self.checkpoint_dir = self.base_dir / self.arxiv_id
            self.vector_store_dir = self.checkpoint_dir / "vector_store"
            self.fragment_store_dir = self.checkpoint_dir / "fragments"
        else:
            # Fall back to the base directory when there is no arxiv_id
            self.checkpoint_dir = self.base_dir
            self.vector_store_dir = self.base_dir / "vector_store"
            self.fragment_store_dir = self.base_dir / "fragments"

        if os.path.exists(self.vector_store_dir):
            self.loading = True
        self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
        self.loading = self.paper_path.exists()
        # Create the necessary directories
        for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
            directory.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created directory: {directory}")

    def _init_workers(self):
        """Initialize the worker components"""
        try:
            self.rag_worker = LlamaIndexRagWorker(
                user_name=self.user_name,
                llm_kwargs=self.llm_kwargs,
                checkpoint_dir=str(self.vector_store_dir),
                auto_load_checkpoint=True
            )

            self.arxiv_splitter = ArxivSplitter(
                root_dir=str(self.checkpoint_dir / "arxiv_cache")
            )
        except Exception as e:
            logger.error(f"Error initializing workers: {str(e)}")
            raise

    def _ensure_loop(self):
        """Make sure an event loop exists"""
        if threading.current_thread() is threading.main_thread():
            if self._loop is None:
                self._loop = asyncio.get_event_loop()
        else:
            self.loading = False

        # Create necessary directories
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.vector_store_dir.mkdir(parents=True, exist_ok=True)
        self.fragment_store_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Checkpoint directory: {self.checkpoint_dir}")
        logger.info(f"Vector store directory: {self.vector_store_dir}")
        logger.info(f"Fragment store directory: {self.fragment_store_dir}")

        # Initialize the processing queue and thread pool
        self.processing_queue = {}
        self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)

        # Initialize the RAG worker
        self.rag_worker = LlamaIndexRagWorker(
            user_name=user_name,
            llm_kwargs=llm_kwargs,
            checkpoint_dir=str(self.vector_store_dir),
            auto_load_checkpoint=True
        )

        # Initialize the arxiv splitter
        # Initialize the arxiv splitter
        self.arxiv_splitter = ArxivSplitter(
            root_dir=str(self.checkpoint_dir / "arxiv_cache")
        )
        # Initialize the processing queue and thread pool
        self._semaphore = None
        self._loop = None

    @property
    def loop(self):
        """Get the current event loop"""
        if self._loop is None:
            self._loop = asyncio.get_event_loop()
            try:
                self._loop = asyncio.get_event_loop()
            except RuntimeError:
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
        return self._loop

    @property
    def semaphore(self):
        """Lazily create the semaphore"""
        """Lazily create the semaphore"""
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
        return self._semaphore


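The two properties implement lazy creation: the loop is fetched on first use (or freshly created if `get_event_loop` raises in a thread that has none), and the semaphore is only built once a loop is in play, so it binds to the loop actually used. A minimal sketch of the pattern, assuming nothing beyond the standard library:

import asyncio

class LazyLoopMixin:
    def __init__(self, max_concurrent: int = 4):
        self._loop = None
        self._semaphore = None
        self.max_concurrent = max_concurrent

    @property
    def loop(self):
        # Get the thread's loop; create a fresh one if the thread has none.
        if self._loop is None:
            try:
                self._loop = asyncio.get_event_loop()
            except RuntimeError:
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
        return self._loop

    @property
    def semaphore(self):
        # Defer creation so the semaphore binds to the loop actually in use.
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrent)
        return self._semaphore

m = LazyLoopMixin()
print(m.loop is m.loop, isinstance(m.semaphore, asyncio.Semaphore))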
    async def _process_fragments(self, fragments: List[Fragment]) -> None:
        """Process paper fragments in parallel"""
        """Optimized parallel processing of paper fragments"""
        if not fragments:
            logger.warning("No fragments to process")
            return

        # First add the paper overview
        overview = {
            "title": fragments[0].title,
            "abstract": fragments[0].abstract,
            "arxiv_id": fragments[0].arxiv_id,
            "section_tree": fragments[0].section_tree,
        start_time = time.time()
        total_fragments = len(fragments)

        try:
            # 1. Process the paper overview
            overview = self._create_overview(fragments[0])
            overview_success = self._safe_add_to_vector_store_sync(overview['text'])
            if not overview_success:
                raise RuntimeError("Failed to add overview to vector store")

            # 2. Process the fragments in parallel
            successful_fragments = await self._parallel_process_fragments(fragments)

            # 3. Save the processing results
            if successful_fragments > 0:
                await self._save_results(fragments, overview['arxiv_id'], successful_fragments)

        except Exception as e:
            logger.error(f"Error in fragment processing: {str(e)}")
            raise
        finally:
            self._log_processing_stats(start_time, total_fragments)

    def _create_overview(self, first_fragment: Fragment) -> Dict:
        """Create the paper overview"""
        return {
            'arxiv_id': first_fragment.arxiv_id,
            'text': (
                f"Paper Title: {first_fragment.title}\n"
                f"ArXiv ID: {first_fragment.arxiv_id}\n"
                f"Abstract: {first_fragment.abstract}\n"
                f"Section Tree:{first_fragment.section_tree}\n"
                f"Type: OVERVIEW"
            )
        }

        overview_text = (
            f"Paper Title: {overview['title']}\n"
            f"ArXiv ID: {overview['arxiv_id']}\n"
            f"Abstract: {overview['abstract']}\n"
            f"Section Tree:{overview['section_tree']}\n"
            f"Type: OVERVIEW"
        )
    async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
        """Process all fragments in parallel"""
        successful_count = 0
        loop = self._ensure_loop()

        for i in range(0, len(fragments), self._batch_size):
            batch = fragments[i:i + self._batch_size]
            batch_futures = []

            for j, fragment in enumerate(batch):
                if not self._is_fragment_processed(fragment, i + j):
                    future = loop.run_in_executor(
                        self.thread_pool,
                        self._process_single_fragment_sync,
                        fragment,
                        i + j
                    )
                    batch_futures.append(future)

            if batch_futures:
                results = await asyncio.gather(*batch_futures, return_exceptions=True)
                successful_count += sum(1 for r in results if isinstance(r, bool) and r)

        return successful_count

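The batching pattern above fans each batch out to the thread pool via `run_in_executor` and collects results with `asyncio.gather(return_exceptions=True)`, so one failed fragment cannot cancel the rest of its batch. A self-contained sketch of the same technique (the `work` function is a stand-in for `_process_single_fragment_sync`):

import asyncio
from concurrent.futures import ThreadPoolExecutor

def work(item: int) -> bool:
    # Stand-in worker; reports success or failure per item.
    return item % 7 != 0

async def process_in_batches(items, pool, batch_size=4) -> int:
    loop = asyncio.get_running_loop()
    ok = 0
    for i in range(0, len(items), batch_size):
        futures = [loop.run_in_executor(pool, work, it)
                   for it in items[i:i + batch_size]]
        # return_exceptions=True keeps one failure from aborting the whole batch.
        results = await asyncio.gather(*futures, return_exceptions=True)
        ok += sum(1 for r in results if isinstance(r, bool) and r)
    return ok

if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=4) as pool:
        print(asyncio.run(process_in_batches(list(range(20)), pool)))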
    def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
        """Check whether a fragment has already been processed"""
        fragment_id = f"{fragment.arxiv_id}_{index}"
        with self._processing_lock:
            return fragment_id in self._processed_fragments

    def _safe_add_to_vector_store_sync(self, text: str) -> bool:
        """Thread-safe addition to the vector store"""
        with self._processing_lock:
            try:
                self.rag_worker.add_text_to_vector_store(text)
                return True
            except Exception as e:
                logger.error(f"Error adding to vector store: {str(e)}")
                return False

    def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
        """Process a single fragment"""
        fragment_id = f"{fragment.arxiv_id}_{index}"
        try:
            # Synchronously add the overview
            self.rag_worker.add_text_to_vector_store(overview_text)
            logger.info(f"Added paper overview for {overview['arxiv_id']}")

            # Process the remaining fragments in parallel
            tasks = []
            for i, fragment in enumerate(fragments):
                tasks.append(self._process_single_fragment(fragment, i))
            await asyncio.gather(*tasks)
            logger.info(f"Processed {len(fragments)} fragments successfully")


            # Save to a local file for debugging
            save_fragments_to_file(
                fragments,
                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
            )

        except Exception as e:
            logger.error(f"Error processing fragments: {str(e)}")
            raise

    async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
        """Process a single paper fragment (now an async method)"""
        try:
            text = (
                f"Paper Title: {fragment.title}\n"
                f"Section: {fragment.current_section}\n"
                f"Content: {fragment.content}\n"
                f"Bibliography: {fragment.bibliography}\n"
                f"Type: FRAGMENT"
            )
            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
            # If add_text_to_vector_store is asynchronous, use await
            self.rag_worker.add_text_to_vector_store(text)
            logger.info(f"Successfully added fragment {index} to vector store")

            text = self._build_fragment_text(fragment)
            if self._safe_add_to_vector_store_sync(text):
                with self._processing_lock:
                    self._processed_fragments.add(fragment_id)
                    self._processed_count += 1
                logger.info(f"Successfully processed fragment {index}")
                return True
            return False
        except Exception as e:
            logger.error(f"Error processing fragment {index}: {str(e)}")
            raise
            """Process a single paper fragment"""
            return False

    def _build_fragment_text(self, fragment: Fragment) -> str:
        """Build the fragment text"""
        return "".join([
            f"Paper Title: {fragment.title}\n",
            f"Section: {fragment.current_section}\n",
            f"Content: {fragment.content}\n",
            f"Bibliography: {fragment.bibliography}\n",
            "Type: FRAGMENT"
        ])

    async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
        """Save the processing results"""
        if successful_count > 0:
            loop = self._ensure_loop()
            await loop.run_in_executor(
                self.thread_pool,
                save_fragments_to_file,
                fragments,
                str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
            )

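Offloading the blocking file write to the pool keeps the event loop responsive while results are persisted. A minimal sketch of the pattern, with a plain `json.dump` standing in for `save_fragments_to_file`:

import asyncio
import json
from concurrent.futures import ThreadPoolExecutor

def dump_json(data, path: str) -> None:
    # Blocking file write; must not run on the event loop thread.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f)

async def save_async(data, path: str, pool: ThreadPoolExecutor) -> None:
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(pool, dump_json, data, path)

if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=2) as pool:
        asyncio.run(save_async({"ok": True}, "demo_fragments.json", pool))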
    def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
        """Log processing statistics"""
        elapsed_time = time.time() - start_time
        processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
        logger.info(
            f"Processed {self._processed_count}/{total_fragments} fragments "
            f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
        )

    async def process_paper(self, arxiv_id: str) -> bool:
        """Main entry point for processing a paper"""
@@ -183,47 +274,61 @@ class ArxivRagWorker:
        arxiv_id = self._normalize_arxiv_id(arxiv_id)
        logger.info(f"Starting to process paper: {arxiv_id}")

        paper_path = self.checkpoint_dir / f"{arxiv_id}.processed"

        if paper_path.exists():
        if self.paper_path.exists():
            logger.info(f"Paper {arxiv_id} already processed")
            return True

        # Create the processing task
        task = ProcessingTask(arxiv_id=arxiv_id)
        self.processing_queue[arxiv_id] = task
        task.status = "processing"
        task = self._create_processing_task(arxiv_id)

        async with self.semaphore:
            # Download and split the paper
            fragments = await self.arxiv_splitter.process(arxiv_id)
        try:
            async with self.semaphore:
                fragments = await self.arxiv_splitter.process(arxiv_id)
                if not fragments:
                    raise ValueError(f"No fragments extracted from paper {arxiv_id}")

        if not fragments:
            raise ValueError(f"No fragments extracted from paper {arxiv_id}")
                logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
                await self._process_fragments(fragments)

        logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
                self._complete_task(task, fragments, self.paper_path)
                return True

        # Process the fragments
        await self._process_fragments(fragments)

        # Mark as completed
        paper_path.touch()
        task.status = "completed"
        task.fragments = fragments

        logger.info(f"Successfully processed paper {arxiv_id}")
        return True
        except Exception as e:
            self._fail_task(task, str(e))
            raise

    except Exception as e:
        logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
        if arxiv_id in self.processing_queue:
            self.processing_queue[arxiv_id].status = "failed"
            self.processing_queue[arxiv_id].error = str(e)
        return False

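process_paper throttles concurrent downloads with `async with self.semaphore`, so at most MAX_CONCURRENT_PAPERS papers are fetched and split at once. A self-contained sketch of semaphore-bounded fan-out (the `fetch` coroutine is a stand-in for `arxiv_splitter.process`):

import asyncio

async def fetch(paper_id: str) -> str:
    # Stand-in for the download/split step; simulate network latency.
    await asyncio.sleep(0.1)
    return f"fragments of {paper_id}"

async def process_all(ids, max_concurrent=3):
    sem = asyncio.Semaphore(max_concurrent)

    async def bounded(pid):
        async with sem:  # at most max_concurrent fetches in flight
            return await fetch(pid)

    return await asyncio.gather(*(bounded(p) for p in ids))

if __name__ == "__main__":
    print(asyncio.run(process_all([f"2401.{i:05d}" for i in range(6)])))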
    def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
        """Create a processing task"""
        task = ProcessingTask(arxiv_id=arxiv_id)
        with self._processing_lock:
            self.processing_queue[arxiv_id] = task
            task.status = "processing"
        return task

    def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
        """Finish a task"""
        with self._processing_lock:
            task.status = "completed"
            task.fragments = fragments
            paper_path.touch()
            logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")

    def _fail_task(self, task: ProcessingTask, error: str) -> None:
        """Mark a task as failed"""
        with self._processing_lock:
            task.status = "failed"
            task.error = error

    def _normalize_arxiv_id(self, input_str: str) -> str:
        """Normalize an ArXiv ID"""
        if 'arxiv.org/' in input_str.lower():
        if not input_str:
            return ""

        input_str = input_str.strip().lower()
        if 'arxiv.org/' in input_str:
            if '/pdf/' in input_str:
                arxiv_id = input_str.split('/pdf/')[-1]
            else:
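The hunk cuts off before the non-pdf branch and any suffix handling, but the visible logic accepts either a bare ID or an arxiv.org URL. A hedged reconstruction of the full helper; the `/abs/` branch and the `.pdf`/version stripping are assumptions about the elided lines, not code shown in this diff:

def normalize_arxiv_id(input_str: str) -> str:
    # Accept a bare ID ("2301.12345"), an abs URL, or a pdf URL.
    if not input_str:
        return ""
    s = input_str.strip().lower()
    if 'arxiv.org/' in s:
        if '/pdf/' in s:
            arxiv_id = s.split('/pdf/')[-1]
        else:
            arxiv_id = s.split('/abs/')[-1]  # assumed branch, elided in the diff
        arxiv_id = arxiv_id.replace('.pdf', '')
    else:
        arxiv_id = s
    return arxiv_id.split('v')[0]  # assumed: drop a trailing version suffix such as "v2"

print(normalize_arxiv_id("https://arxiv.org/pdf/2301.12345v2.pdf"))  # -> 2301.12345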
@@ -233,21 +338,20 @@ class ArxivRagWorker:

    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
        """Wait for a paper to finish processing"""
        start_time = time.time()
        try:
            start_time = datetime.now()
            while True:
                task = self.processing_queue.get(arxiv_id)
                with self._processing_lock:
                    task = self.processing_queue.get(arxiv_id)
                if not task:
                    return False

                if task.status == "completed":
                    return True

                if task.status == "failed":
                    return False

                # Check for timeout
                if (datetime.now() - start_time).total_seconds() > timeout:
                if time.time() - start_time > timeout:
                    logger.error(f"Processing paper {arxiv_id} timed out")
                    return False

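wait_for_paper polls the queue (now under the lock) and bails out after `timeout` seconds; the diff also swaps datetime arithmetic for a simpler time.time() delta. A compact sketch of the polling loop; the sleep interval is an assumption, since the hunk ends before the await:

import asyncio
import time

async def wait_until_done(status_of, key: str, timeout: float = 300.0,
                          poll: float = 1.0) -> bool:
    # status_of is a stand-in for the locked processing_queue lookup.
    start = time.time()
    while True:
        status = status_of(key)
        if status == "completed":
            return True
        if status in (None, "failed"):
            return False
        if time.time() - start > timeout:
            return False
        await asyncio.sleep(poll)  # yield so other coroutines make progress

if __name__ == "__main__":
    states = iter(["processing", "processing", "completed"])
    print(asyncio.run(wait_until_done(lambda _k: next(states), "2401.00001", poll=0.01)))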
@@ -319,7 +423,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
        web_port: Web port
    """
    # On first call, prompt the user for an arxiv ID/URL
    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0','1', '2')):
    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
        chatbot.append((txt, "Please provide an Arxiv paper link or ID first."))
        yield from update_ui(chatbot=chatbot, history=history)
        return
@@ -332,32 +436,51 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
        chatbot.append((txt, "Processing the paper, please wait..."))
        yield from update_ui(chatbot=chatbot, history=history)

        # Create an event loop to handle the async calls
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # Run the async processing function
            success = loop.run_until_complete(worker.process_paper(txt))
            if success:
                arxiv_id = worker._normalize_arxiv_id(txt)
                success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
            # if success:
            #     # Run the automatic analysis
            #     yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
        finally:
            loop.close()
        # Create a new event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        if not success:
            chatbot[-1] = (txt, "Paper processing failed; please check that the paper ID is correct, or try again later.")
        # Use timeout control
        success = False
        try:
            # Set the timeout to 5 minutes
            success = loop.run_until_complete(
                asyncio.wait_for(worker.process_paper(txt), timeout=300)
            )
            if success:
                arxiv_id = worker._normalize_arxiv_id(txt)
                success = loop.run_until_complete(
                    asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
                )
                if success:
                    chatbot[-1] = (txt, "Paper processing finished; you can start asking questions now.")
                else:
                    chatbot[-1] = (txt, "Paper processing timed out, please retry.")
            else:
                chatbot[-1] = (txt, "Paper processing failed; please check that the paper ID is correct, or try again later.")
        except asyncio.TimeoutError:
            chatbot[-1] = (txt, "Paper processing timed out, please retry.")
            success = False
        finally:
            loop.close()

        if not success:
            yield from update_ui(chatbot=chatbot, history=history)
            return

    except Exception as e:
        logger.error(f"Error in main process: {str(e)}")
        chatbot[-1] = (txt, f"An error occurred during processing: {str(e)}")
        yield from update_ui(chatbot=chatbot, history=history)
        return

    yield from update_ui(chatbot=chatbot, history=history)
    return

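The rewritten branch wraps each run_until_complete call in `asyncio.wait_for`, so a hung download surfaces as a TimeoutError instead of blocking the UI generator forever. A minimal sketch of driving an async coroutine with a timeout from synchronous code:

import asyncio

async def long_job() -> bool:
    await asyncio.sleep(0.2)  # stand-in for process_paper()
    return True

def run_with_timeout(coro, timeout: float) -> bool:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(asyncio.wait_for(coro, timeout=timeout))
    except asyncio.TimeoutError:
        return False
    finally:
        loop.close()

print(run_with_timeout(long_job(), timeout=5.0))  # True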
    # Handle a user question
    # Get the user's query instruction
    user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?")
    user_query = plugin_kwargs.get("advanced_arg",
                                   "What is the main research question or problem addressed in this paper?")
    if not user_query:
        user_query = "What is the main research question or problem addressed in this paper about graph attention network?"
        # chatbot.append((txt, "Please provide your question."))
@@ -380,7 +503,9 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
        yield from update_ui_lastest_msg('Long input detected, processing...', chatbot, history, delay=0)
        if len(user_query) > REMEMBER_PREVIEW:
            HALF = REMEMBER_PREVIEW // 2
            query_to_remember = user_query[:HALF] + f" ...\n...(omitted {len(user_query) - REMEMBER_PREVIEW} characters)...\n... " + user_query[-HALF:]
            query_to_remember = user_query[
                :HALF] + f" ...\n...(omitted {len(user_query) - REMEMBER_PREVIEW} characters)...\n... " + user_query[
                -HALF:]
        else:
            query_to_remember = query_clip
    else:
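The clipping keeps the head and tail of an over-long query and replaces the middle with an elision marker, so the remembered history stays bounded while both ends of the question survive. A tiny sketch of the same truncation:

def clip_middle(text: str, limit: int) -> str:
    # Keep roughly `limit` characters: half from the head, half from the tail.
    if len(text) <= limit:
        return text
    half = limit // 2
    omitted = len(text) - limit
    return text[:half] + f" ...\n...({omitted} characters omitted)...\n... " + text[-half:]

print(clip_middle("x" * 5000, 200)[:50])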
@@ -412,6 +537,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:

    yield from update_ui(chatbot=chatbot, history=history)


if __name__ == "__main__":
    # Test code
    llm_kwargs = {