diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py index 0de42b0c..d83d8ca5 100644 --- a/crazy_functions/Rag_Interface.py +++ b/crazy_functions/Rag_Interface.py @@ -4,10 +4,14 @@ from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_ VECTOR_STORE_TYPE = "Milvus" +if VECTOR_STORE_TYPE == "Milvus": + try: + from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker + except: + VECTOR_STORE_TYPE = "Simple" + if VECTOR_STORE_TYPE == "Simple": from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker -if VECTOR_STORE_TYPE == "Milvus": - from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker RAG_WORKER_REGISTER = {} diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py index 761d6943..7a559927 100644 --- a/crazy_functions/rag_fns/llama_index_worker.py +++ b/crazy_functions/rag_fns/llama_index_worker.py @@ -59,6 +59,11 @@ class SaveLoad(): def create_new_vs(self): return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model) + def purge(self): + import shutil + shutil.rmtree(self.checkpoint_dir, ignore_errors=True) + self.vs_index = self.create_new_vs() + class LlamaIndexRagWorker(SaveLoad): def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None: diff --git a/requirements.txt b/requirements.txt index 99c841e2..13c0bb56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,7 @@ zhipuai==2.0.1 tiktoken>=0.3.3 requests[socks] pydantic==2.5.2 -llama-index==0.10.47 -llama-index-vector-stores-milvus==0.1.16 -pymilvus==2.4.2 +llama-index==0.10 protobuf==3.20 transformers>=4.27.1,<4.42 scipdf_parser>=0.52