From 8b91d2ac0a4a7a122c4ad85c079efb2b82d0f50f Mon Sep 17 00:00:00 2001
From: binary-husky
Date: Sun, 8 Sep 2024 15:19:03 +0000
Subject: [PATCH] add milvus vector store

---
 TODO                                          | 11 ++-
 crazy_functions/Rag_Interface.py              |  9 +-
 crazy_functions/rag_fns/llama_index_worker.py | 10 +-
 crazy_functions/rag_fns/milvus_worker.py      | 94 +++++++++++++++++++
 request_llms/embed_models/openai_embed.py     |  8 +-
 requirements.txt                              |  4 +-
 6 files changed, 128 insertions(+), 8 deletions(-)
 create mode 100644 crazy_functions/rag_fns/milvus_worker.py

diff --git a/TODO b/TODO
index 4ab3721b..72416115 100644
--- a/TODO
+++ b/TODO
@@ -1 +1,10 @@
-RAG: forgot to trigger saving!
\ No newline at end of file
+RAG: forgot to trigger saving!
+
+
+Liu Boyin: implement RAG document vectorization with llama_index
+    RAG code references:
+    crazy_functions/rag_fns/llama_index_worker.py
+    crazy_functions/rag_fns/milvus_worker.py
+    crazy_functions/rag_fns/vector_store_index.py
+    file-reading code reference (uses glob):
+    crazy_functions/SourceCode_Analyse.py
diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py
index 9e1d9075..0c5c4e2b 100644
--- a/crazy_functions/Rag_Interface.py
+++ b/crazy_functions/Rag_Interface.py
@@ -1,7 +1,14 @@
 from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
 from crazy_functions.crazy_utils import input_clipping
 from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
-from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
+VECTOR_STORE_TYPE = "Milvus"
+
+if VECTOR_STORE_TYPE == "Simple":
+    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+elif VECTOR_STORE_TYPE == "Milvus":
+    from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
+
 
 RAG_WORKER_REGISTER = {}
 
diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py
index de1ef38d..761d6943 100644
--- a/crazy_functions/rag_fns/llama_index_worker.py
+++ b/crazy_functions/rag_fns/llama_index_worker.py
@@ -1,4 +1,7 @@
 import llama_index
+import os
+import atexit
+from typing import List
 from llama_index.core import Document
 from llama_index.core.schema import TextNode
 from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -38,6 +41,7 @@ class SaveLoad():
         return True
 
     def save_to_checkpoint(self, checkpoint_dir=None):
+        print(f'saving vector store to: {checkpoint_dir}')
         if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
         self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
 
@@ -65,7 +69,8 @@ class LlamaIndexRagWorker(SaveLoad):
         if auto_load_checkpoint:
             self.vs_index = self.load_from_checkpoint(checkpoint_dir)
         else:
-            self.vs_index = self.create_new_vs()
+            self.vs_index = self.create_new_vs(checkpoint_dir)
+        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
 
     def assign_embedding_model(self):
         pass
@@ -117,6 +122,3 @@ class LlamaIndexRagWorker(SaveLoad):
         buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
         if self.debug_mode: print(buf)
         return buf
-
-
-
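A note on the atexit hook introduced above: registering the save at construction time is what closes the "RAG: forgot to trigger saving" TODO item, because the index is now persisted on every normal interpreter exit. A minimal, self-contained sketch of the pattern (the Worker class and path below are illustrative, not part of this patch):

    import atexit

    class Worker:
        def __init__(self, checkpoint_dir):
            self.checkpoint_dir = checkpoint_dir
            # Fires on normal interpreter exit, but not on os._exit() or SIGKILL.
            atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))

        def save_to_checkpoint(self, checkpoint_dir=None):
            print(f'saving vector store to: {checkpoint_dir or self.checkpoint_dir}')

    Worker('/tmp/vs_cache')  # prints the save message when the process exits

One side effect worth knowing: atexit.register keeps a reference to the lambda, which closes over self, so a worker registered this way is never garbage-collected before shutdown.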
diff --git a/crazy_functions/rag_fns/milvus_worker.py b/crazy_functions/rag_fns/milvus_worker.py
new file mode 100644
index 00000000..4cfc1678
--- /dev/null
+++ b/crazy_functions/rag_fns/milvus_worker.py
@@ -0,0 +1,94 @@
+import llama_index
+import os
+import atexit
+from typing import List
+from llama_index.core import Document
+from llama_index.core.schema import TextNode, NodeWithScore
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
+from shared_utils.connect_void_terminal import get_chat_default_kwargs
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core import PromptTemplate
+from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core import StorageContext
+from llama_index.vector_stores.milvus import MilvusVectorStore
+from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
+DEFAULT_QUERY_GENERATION_PROMPT = """\
+Now, you have context information as below:
+---------------------
+{context_str}
+---------------------
+Answer the user request below (use the context information if necessary; otherwise, ignore it):
+---------------------
+{query_str}
+"""
+
+QUESTION_ANSWER_RECORD = """\
+{{
+    "type": "This is a previous conversation with the user",
+    "question": "{question}",
+    "answer": "{answer}",
+}}
+"""
+
+
+class MilvusSaveLoad():
+
+    def does_checkpoint_exist(self, checkpoint_dir=None):
+        import glob
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if not os.path.exists(checkpoint_dir): return False
+        if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
+        return True
+
+    def save_to_checkpoint(self, checkpoint_dir=None):
+        # Milvus Lite persists every write to its .db file on its own,
+        # so there is nothing to flush here (unlike the simple vector store).
+        print(f'saving vector store to: {checkpoint_dir}')
+
+    def load_from_checkpoint(self, checkpoint_dir=None):
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
+            print('loading checkpoint from disk')
+            from llama_index.core import StorageContext, load_index_from_storage
+            storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
+            try:
+                self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
+                return self.vs_index
+            except Exception:
+                return self.create_new_vs(checkpoint_dir)
+        else:
+            return self.create_new_vs(checkpoint_dir)
+
+    def create_new_vs(self, checkpoint_dir):
+        vector_store = MilvusVectorStore(
+            uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
+            dim=self.embed_model.embedding_dimension()
+        )
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        return index
+
+
+class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
+
+    def inspect_vector_store(self):
+        # This function is for debugging
+        try:
+            self.vs_index.storage_context.index_store.to_dict()
+            docstore = self.vs_index.storage_context.docstore.docs
+            if not docstore.items():
+                raise ValueError("cannot inspect")
+            vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
+        except Exception:
+            # Milvus keeps nodes in its own store, so preview via a dummy retrieval.
+            dummy_retrieve_res: List[NodeWithScore] = self.vs_index.as_retriever().retrieve(' ')
+            vector_store_preview = "\n".join(
+                [f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
+            )
+        print('\n++ --------inspect_vector_store begin--------')
+        print(vector_store_preview)
+        print('oo --------inspect_vector_store end--------')
+        return vector_store_preview
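For reference, the create_new_vs() path above runs against Milvus Lite: passing a plain file path as uri makes MilvusVectorStore embed the database in-process instead of connecting to a Milvus server. A minimal standalone sketch of that wiring (the MockEmbedding stand-in, the 1536 dimension, and overwrite=True are assumptions for the demo, not part of this patch):

    from llama_index.core import StorageContext, VectorStoreIndex
    from llama_index.core.embeddings import MockEmbedding
    from llama_index.core.schema import TextNode
    from llama_index.vector_stores.milvus import MilvusVectorStore

    # A local file path as `uri` selects the embedded Milvus Lite engine;
    # `dim` must equal the embedding model's output dimension.
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1536, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(
        nodes=[TextNode(text="hello milvus")],
        storage_context=storage_context,
        embed_model=MockEmbedding(embed_dim=1536),  # stand-in for OpenAiEmbeddingModel
    )
    print(index.as_retriever().retrieve("hello")[0].text)  # -> hello milvus

This is also why save_to_checkpoint() becomes a no-op in the Milvus worker: writes go straight to the .db file, so there is nothing left to persist by hand.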
class OpenAiEmbeddingModel(EmbeddingModel):
         embedding = res.data[0].embedding
         return embedding
 
-    def embedding_dimension(self, llm_kwargs):
+    def embedding_dimension(self, llm_kwargs=None):
+        # fall back to the kwargs captured when this model was constructed
+        if llm_kwargs is None:
+            llm_kwargs = self.llm_kwargs
+        if llm_kwargs is None:
+            raise RuntimeError("llm_kwargs is not provided!")
+
         from .bridge_all_embed import embed_model_info
         return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']
 
diff --git a/requirements.txt b/requirements.txt
index 757df774..99c841e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,9 @@ tiktoken>=0.3.3
 requests[socks]
 pydantic==2.5.2
 llama-index==0.10.47
-protobuf==3.18
+llama-index-vector-stores-milvus==0.1.16
+pymilvus==2.4.2
+protobuf==3.20
 transformers>=4.27.1,<4.42
 scipdf_parser>=0.52
 anthropic>=0.18.1
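The llm_kwargs=None default added to embedding_dimension() matters because create_new_vs() in milvus_worker.py calls it with no arguments. A pure-Python sketch of the resulting lookup (the embed_model_info table and model name are illustrative; the real table lives in request_llms/embed_models/bridge_all_embed.py):

    # Illustrative stand-in for the bridge_all_embed lookup table.
    embed_model_info = {
        "text-embedding-3-small": {"embed_dimension": 1536},  # example entry
    }

    def embedding_dimension(llm_kwargs=None, constructor_kwargs=None):
        # Explicit kwargs win; otherwise fall back to the kwargs the
        # model object captured at construction time (self.llm_kwargs).
        if llm_kwargs is None:
            llm_kwargs = constructor_kwargs
        if llm_kwargs is None:
            raise RuntimeError("llm_kwargs is not provided!")
        return embed_model_info[llm_kwargs["embed_model"]]["embed_dimension"]

    print(embedding_dimension({"embed_model": "text-embedding-3-small"}))  # 1536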