diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py
index 43ebf8cd..33175df0 100644
--- a/crazy_functions/Rag_Interface.py
+++ b/crazy_functions/Rag_Interface.py
@@ -1,9 +1,12 @@
-import os, glob
-
+import os,glob
 from typing import List
-from toolbox import report_exception
-from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
+
+from llama_index.core import Document
 from shared_utils.fastapi_server import validate_path_safety
+
+from toolbox import report_exception
+from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
+from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
 from crazy_functions.crazy_utils import input_clipping
 from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
 
@@ -12,44 +15,14 @@ MAX_HISTORY_ROUND = 5
 MAX_CONTEXT_TOKEN_LIMIT = 4096
 REMEMBER_PREVIEW = 1000
 
-# import vector store lib
-VECTOR_STORE_TYPE = "Milvus"
-if VECTOR_STORE_TYPE == "Milvus":
-    try:
-        from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
-    except:
-        VECTOR_STORE_TYPE = "Simple"
-if VECTOR_STORE_TYPE == "Simple":
-    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
-
 
 @CatchException
-def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker):
     """
     Handles document uploads by extracting text and adding it to the vector store.
-
-    Args:
-        files (List[str]): List of file paths to process.
-        llm_kwargs: Language model keyword arguments.
-        plugin_kwargs: Plugin keyword arguments.
-        chatbot: Chatbot instance.
-        history: Chat history.
-        system_prompt: System prompt.
-        user_request: User request.
     """
-    from llama_index.core import Document
-    from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
     user_name = chatbot.get_user()
     checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
-    if user_name in RAG_WORKER_REGISTER:
-        rag_worker = RAG_WORKER_REGISTER[user_name]
-    else:
-        rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
-            user_name,
-            llm_kwargs,
-            checkpoint_dir=checkpoint_dir,
-            auto_load_checkpoint=True
-        )
 
     for file_path in files:
         try:
@@ -73,10 +46,19 @@ def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot,
 
 @CatchException
 def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+    # import vector store lib
+    VECTOR_STORE_TYPE = "Milvus"
+    if VECTOR_STORE_TYPE == "Milvus":
+        try:
+            from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
+        except:
+            VECTOR_STORE_TYPE = "Simple"
+    if VECTOR_STORE_TYPE == "Simple":
+        from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
     # 1. we retrieve rag worker from global context
     user_name = chatbot.get_user()
     checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
-
     if user_name in RAG_WORKER_REGISTER:
         rag_worker = RAG_WORKER_REGISTER[user_name]
     else:
@@ -100,7 +82,7 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
 
         chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
         yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
-        yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
+        yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker)
         return
 
     elif txt == "清空向量数据库":
@@ -145,7 +127,6 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
     # 6. Search vector store and build prompts
     nodes = rag_worker.retrieve_from_store_with_query(i_say)
     prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
-
     # 7. Query language model
     if len(chatbot) != 0: chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
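
The net effect of the diff is that `Rag问答` resolves the per-user worker once and threads it into `handle_document_upload` as a `rag_worker` argument, rather than both functions consulting `RAG_WORKER_REGISTER` independently. A minimal sketch of the get-or-create pattern the caller keeps (the constructor arguments mirror the removed block above; the `get_rag_worker` helper name is illustrative and not part of the codebase):

```python
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

# Caches one worker per user so the vector store checkpoint is
# loaded once and reused across subsequent plugin calls.
RAG_WORKER_REGISTER = {}

def get_rag_worker(user_name, llm_kwargs, checkpoint_dir):
    # Hypothetical helper: get-or-create, mirroring the branch in Rag问答.
    if user_name not in RAG_WORKER_REGISTER:
        RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
            user_name,
            llm_kwargs,
            checkpoint_dir=checkpoint_dir,
            auto_load_checkpoint=True,
        )
    return RAG_WORKER_REGISTER[user_name]
```

Passing the resolved worker into `handle_document_upload` keeps the registry a single point of truth in `Rag问答` and spares the upload path a second lookup.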