From dcfed97054d075fcfadae862c175234ee02a2268 Mon Sep 17 00:00:00 2001 From: binary-husky Date: Sun, 8 Sep 2024 15:43:01 +0000 Subject: [PATCH] revise milvus rag --- crazy_functions/Rag_Interface.py | 17 +++++++++++++---- crazy_functions/rag_fns/milvus_worker.py | 19 ++++++++++++++++--- toolbox.py | 4 ++-- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py index 0c5c4e2b..0de42b0c 100644 --- a/crazy_functions/Rag_Interface.py +++ b/crazy_functions/Rag_Interface.py @@ -21,16 +21,25 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u # 1. we retrieve rag worker from global context user_name = chatbot.get_user() + checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag') if user_name in RAG_WORKER_REGISTER: rag_worker = RAG_WORKER_REGISTER[user_name] else: rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker( user_name, llm_kwargs, - checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'), + checkpoint_dir=checkpoint_dir, auto_load_checkpoint=True) + current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}" + tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库" + if txt == "清空向量数据库": + chatbot.append([txt, f'正在清空 ({current_context}) ...']) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + rag_worker.purge() + yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面 + return - chatbot.append([txt, '正在召回知识 ...']) + chatbot.append([txt, f'正在召回知识 ({current_context}) ...']) yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 2. clip history to reduce token consumption @@ -75,8 +84,8 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u ) # 5. remember what has been asked / answered - yield from update_ui_lastest_msg(model_say + '

' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5) # 刷新界面 + yield from update_ui_lastest_msg(model_say + '

' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面 rag_worker.remember_qa(i_say_to_remember, model_say) history.extend([i_say, model_say]) - yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0) # 刷新界面 + yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面 diff --git a/crazy_functions/rag_fns/milvus_worker.py b/crazy_functions/rag_fns/milvus_worker.py index 4cfc1678..4b5b0ad9 100644 --- a/crazy_functions/rag_fns/milvus_worker.py +++ b/crazy_functions/rag_fns/milvus_worker.py @@ -62,18 +62,31 @@ class MilvusSaveLoad(): else: return self.create_new_vs(checkpoint_dir) - def create_new_vs(self, checkpoint_dir): + def create_new_vs(self, checkpoint_dir, overwrite=False): vector_store = MilvusVectorStore( uri=os.path.join(checkpoint_dir, "milvus_demo.db"), - dim=self.embed_model.embedding_dimension() + dim=self.embed_model.embedding_dimension(), + overwrite=overwrite ) storage_context = StorageContext.from_defaults(vector_store=vector_store) index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model) return index + def purge(self): + self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True) -class MilvusRagWorker(LlamaIndexRagWorker): +class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker): + def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None: + self.debug_mode = True + self.embed_model = OpenAiEmbeddingModel(llm_kwargs) + self.user_name = user_name + self.checkpoint_dir = checkpoint_dir + if auto_load_checkpoint: + self.vs_index = self.load_from_checkpoint(checkpoint_dir) + else: + self.vs_index = self.create_new_vs(checkpoint_dir) + atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir)) def inspect_vector_store(self): # This function is for debugging diff --git a/toolbox.py b/toolbox.py index 6b2f4c10..900cf234 100644 --- a/toolbox.py +++ b/toolbox.py @@ -178,7 +178,7 @@ def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): # yield cookies, chatbot_gr, history, msg -def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1): # 刷新界面 +def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # 刷新界面 """ 刷新用户界面 """ @@ -186,7 +186,7 @@ def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, chatbot.append(["update_ui_last_msg", lastmsg]) chatbot[-1] = list(chatbot[-1]) chatbot[-1][-1] = lastmsg - yield from update_ui(chatbot=chatbot, history=history) + yield from update_ui(chatbot=chatbot, history=history, msg=msg) time.sleep(delay)