revise milvus rag

这个提交包含在:
binary-husky
2024-09-08 15:43:01 +00:00
父节点 8b91d2ac0a
当前提交 dcfed97054
共有 3 个文件被更改,包括 31 次插入和 9 次删除

查看文件

@@ -21,16 +21,25 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
# 1. we retrieve rag worker from global context # 1. we retrieve rag worker from global context
user_name = chatbot.get_user() user_name = chatbot.get_user()
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in RAG_WORKER_REGISTER: if user_name in RAG_WORKER_REGISTER:
rag_worker = RAG_WORKER_REGISTER[user_name] rag_worker = RAG_WORKER_REGISTER[user_name]
else: else:
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker( rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
user_name, user_name,
llm_kwargs, llm_kwargs,
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'), checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True) auto_load_checkpoint=True)
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
tip = "提示输入“清空向量数据库”可以清空RAG向量数据库"
if txt == "清空向量数据库":
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
rag_worker.purge()
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
return
chatbot.append([txt, '正在召回知识 ...']) chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 2. clip history to reduce token consumption # 2. clip history to reduce token consumption
@@ -75,8 +84,8 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
) )
# 5. remember what has been asked / answered # 5. remember what has been asked / answered
yield from update_ui_lastest_msg(model_say + '</br></br>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5) # 刷新界面 yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
rag_worker.remember_qa(i_say_to_remember, model_say) rag_worker.remember_qa(i_say_to_remember, model_say)
history.extend([i_say, model_say]) history.extend([i_say, model_say])
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0) # 刷新界面 yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面

查看文件

@@ -62,18 +62,31 @@ class MilvusSaveLoad():
else: else:
return self.create_new_vs(checkpoint_dir) return self.create_new_vs(checkpoint_dir)
def create_new_vs(self, checkpoint_dir): def create_new_vs(self, checkpoint_dir, overwrite=False):
vector_store = MilvusVectorStore( vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"), uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension() dim=self.embed_model.embedding_dimension(),
overwrite=overwrite
) )
storage_context = StorageContext.from_defaults(vector_store=vector_store) storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model) index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
return index return index
def purge(self):
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
class MilvusRagWorker(LlamaIndexRagWorker): class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.debug_mode = True
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else:
self.vs_index = self.create_new_vs(checkpoint_dir)
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
def inspect_vector_store(self): def inspect_vector_store(self):
# This function is for debugging # This function is for debugging

查看文件

@@ -178,7 +178,7 @@ def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): #
yield cookies, chatbot_gr, history, msg yield cookies, chatbot_gr, history, msg
def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1): # 刷新界面 def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # 刷新界面
""" """
刷新用户界面 刷新用户界面
""" """
@@ -186,7 +186,7 @@ def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list,
chatbot.append(["update_ui_last_msg", lastmsg]) chatbot.append(["update_ui_last_msg", lastmsg])
chatbot[-1] = list(chatbot[-1]) chatbot[-1] = list(chatbot[-1])
chatbot[-1][-1] = lastmsg chatbot[-1][-1] = lastmsg
yield from update_ui(chatbot=chatbot, history=history) yield from update_ui(chatbot=chatbot, history=history, msg=msg)
time.sleep(delay) time.sleep(delay)