这个提交包含在:
lbykkkk
2024-11-23 11:31:11 +00:00
父节点 b2d6536974
当前提交 241c9641bb

查看文件

@@ -51,10 +51,7 @@ class ArxivRagWorker:
# 初始化基础存储目录 # 初始化基础存储目录
self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache')) self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
if os.path.exists(self.base_dir):
self.loading = True
else:
self.loading = False
# 如果提供了 arxiv_id,创建针对该论文的子目录 # 如果提供了 arxiv_id,创建针对该论文的子目录
if self.arxiv_id: if self.arxiv_id:
self.checkpoint_dir = self.base_dir / self.arxiv_id self.checkpoint_dir = self.base_dir / self.arxiv_id
@@ -67,6 +64,10 @@ class ArxivRagWorker:
self.fragment_store_dir = self.base_dir / "fragments" self.fragment_store_dir = self.base_dir / "fragments"
# 创建必要的目录 # 创建必要的目录
if os.path.exists(self.vector_store_dir):
self.loading = True
else:
self.loading = False
self.checkpoint_dir.mkdir(parents=True, exist_ok=True) self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.vector_store_dir.mkdir(parents=True, exist_ok=True) self.vector_store_dir.mkdir(parents=True, exist_ok=True)
self.fragment_store_dir.mkdir(parents=True, exist_ok=True) self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
@@ -399,11 +400,11 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
# 处理用户询问的情况 # 处理用户询问的情况
# 获取用户询问指令 # 获取用户询问指令
user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?") user_query = plugin_kwargs.get("advanced_arg", "What is the main research question or problem addressed in this paper?")
# user_query = "What is the main research question or problem addressed in this paper about graph attention network?" user_query = "What is the main research question or problem addressed in this paper about graph attention network?"
if not user_query: # if not user_query:
chatbot.append((txt, "请提供您的问题。")) # chatbot.append((txt, "请提供您的问题。"))
yield from update_ui(chatbot=chatbot, history=history) # yield from update_ui(chatbot=chatbot, history=history)
return # return
# 处理历史对话长度 # 处理历史对话长度
if len(history) > MAX_HISTORY_ROUND * 2: if len(history) > MAX_HISTORY_ROUND * 2: