From 7f0ffa58f01a3a18e8a47767ca398d7b2a1edcc9 Mon Sep 17 00:00:00 2001 From: Boyin Liu <143202311+lbykkkk@users.noreply.github.com> Date: Mon, 14 Oct 2024 22:48:24 +0800 Subject: [PATCH] Boyin rag (#1983) * first_version * rag document support * RAG interactive prompts added, issues resolved * Resolve conflicts * Resolve conflicts * Resolve conflicts * more file format support * move import * Resolve LlamaIndexRagWorker bug * new resolve * Address import LlamaIndexRagWorker problem * change import order --------- Co-authored-by: binary-husky --- crazy_functions/Rag_Interface.py | 122 +++++++++++++----- crazy_functions/rag_fns/llama_index_worker.py | 60 +++++---- crazy_functions/rag_fns/rag_file_support.py | 22 ++++ 3 files changed, 148 insertions(+), 56 deletions(-) create mode 100644 crazy_functions/rag_fns/rag_file_support.py diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py index 1bc740ad..7f43f7d1 100644 --- a/crazy_functions/Rag_Interface.py +++ b/crazy_functions/Rag_Interface.py @@ -1,3 +1,9 @@ +import os,glob +from typing import List + +from shared_utils.fastapi_server import validate_path_safety + +from toolbox import report_exception from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg from crazy_functions.crazy_utils import input_clipping from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive @@ -7,6 +13,37 @@ MAX_HISTORY_ROUND = 5 MAX_CONTEXT_TOKEN_LIMIT = 4096 REMEMBER_PREVIEW = 1000 +@CatchException +def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker): + """ + Handles document uploads by extracting text and adding it to the vector store. + """ + from llama_index.core import Document + from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format + user_name = chatbot.get_user() + checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag') + + for file_path in files: + try: + validate_path_safety(file_path, user_name) + text = extract_text(file_path) + if text is None: + chatbot.append( + [f"上传文件: {os.path.basename(file_path)}", f"文件解析失败,无法提取文本内容,请更换文件。失败原因可能为:1.文档格式过于复杂;2. 
不支持的文件格式,支持的文件格式后缀有:" + ", ".join(supports_format)]) + else: + chatbot.append( + [f"上传文件: {os.path.basename(file_path)}", f"上传文件前50个字符为:{text[:50]}。"]) + document = Document(text=text, metadata={"source": file_path}) + rag_worker.add_documents_to_vector_store([document]) + chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"]) + except Exception as e: + report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e)) + + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + + + +# Main Q&A function with document upload support @CatchException def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request): @@ -27,24 +64,43 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u rag_worker = RAG_WORKER_REGISTER[user_name] else: rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker( - user_name, - llm_kwargs, - checkpoint_dir=checkpoint_dir, - auto_load_checkpoint=True) + user_name, + llm_kwargs, + checkpoint_dir=checkpoint_dir, + auto_load_checkpoint=True + ) + current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}" tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库" - if txt == "清空向量数据库": - chatbot.append([txt, f'正在清空 ({current_context}) ...']) - yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 - rag_worker.purge() - yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面 + + # 2. Handle special commands + if os.path.exists(txt) and os.path.isdir(txt): + project_folder = txt + validate_path_safety(project_folder, chatbot.get_user()) + # Extract file paths from the user input + # Assuming the user inputs file paths separated by commas after the command + file_paths = [f for f in glob.glob(f'{project_folder}/**/*', recursive=True)] + chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...']) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + + yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker) return - chatbot.append([txt, f'正在召回知识 ({current_context}) ...']) - yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + elif txt == "清空向量数据库": + chatbot.append([txt, f'正在清空 ({current_context}) ...']) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + rag_worker.purge_vector_store() + yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面 + return - # 2. clip history to reduce token consumption - # 2-1. reduce chat round + else: + report_exception(chatbot, history, a=f"上传文件路径错误: {txt}", b="请检查并提供正确路径。") + + # 3. Normal Q&A processing + chatbot.append([txt, f'正在召回知识 ({current_context}) ...']) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + + # 4. Clip history to reduce token consumption txt_origin = txt if len(history) > MAX_HISTORY_ROUND * 2: @@ -52,41 +108,47 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True) input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"]) - # 2-2. if input is clipped, add input to vector store before retrieve + # 5. 
If input is clipped, add input to vector store before retrieve if input_is_clipped_flag: - yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面 - # save input to vector store + yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面 + # Save input to vector store rag_worker.add_text_to_vector_store(txt_origin) - yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面 + yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面 + if len(txt_origin) > REMEMBER_PREVIEW: - HALF = REMEMBER_PREVIEW//2 + HALF = REMEMBER_PREVIEW // 2 i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:] if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF: - txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:] - else: - pass - i_say = txt_clip + txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:] else: i_say_to_remember = i_say = txt_clip else: i_say_to_remember = i_say = txt_clip - # 3. we search vector store and build prompts + # 6. Search vector store and build prompts nodes = rag_worker.retrieve_from_store_with_query(i_say) prompt = rag_worker.build_prompt(query=i_say, nodes=nodes) + # 7. Query language model + if len(chatbot) != 0: + chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive` - # 4. it is time to query llms - if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive` model_say = yield from request_gpt_model_in_new_thread_with_ui_alive( - inputs=prompt, inputs_show_user=i_say, - llm_kwargs=llm_kwargs, chatbot=chatbot, history=history, + inputs=prompt, + inputs_show_user=i_say, + llm_kwargs=llm_kwargs, + chatbot=chatbot, + history=history, sys_prompt=system_prompt, retry_times_at_unknown_error=0 ) - # 5. remember what has been asked / answered - yield from update_ui_lastest_msg(model_say + '
</br></br>
' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面 + # 8. Remember Q&A + yield from update_ui_lastest_msg( + model_say + '
</br></br>
' + f'对话记忆中, 请稍等 ({current_context}) ...', + chatbot, history, delay=0.5 + ) rag_worker.remember_qa(i_say_to_remember, model_say) history.extend([i_say, model_say]) - yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面 + # 9. Final UI Update + yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) \ No newline at end of file diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py index f6f7f0ab..59a5827c 100644 --- a/crazy_functions/rag_fns/llama_index_worker.py +++ b/crazy_functions/rag_fns/llama_index_worker.py @@ -1,17 +1,13 @@ -import llama_index -import os import atexit from loguru import logger from typing import List + from llama_index.core import Document -from llama_index.core.schema import TextNode -from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel -from shared_utils.connect_void_terminal import get_chat_default_kwargs -from llama_index.core import VectorStoreIndex, SimpleDirectoryReader -from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex from llama_index.core.ingestion import run_transformations -from llama_index.core import PromptTemplate -from llama_index.core.response_synthesizers import TreeSummarize +from llama_index.core.schema import TextNode + +from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex +from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel DEFAULT_QUERY_GENERATION_PROMPT = """\ Now, you have context information as below: @@ -63,7 +59,7 @@ class SaveLoad(): def purge(self): import shutil shutil.rmtree(self.checkpoint_dir, ignore_errors=True) - self.vs_index = self.create_new_vs() + self.vs_index = self.create_new_vs(self.checkpoint_dir) class LlamaIndexRagWorker(SaveLoad): @@ -75,7 +71,7 @@ class LlamaIndexRagWorker(SaveLoad): if auto_load_checkpoint: self.vs_index = self.load_from_checkpoint(checkpoint_dir) else: - self.vs_index = self.create_new_vs(checkpoint_dir) + self.vs_index = self.create_new_vs() atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir)) def assign_embedding_model(self): @@ -91,40 +87,52 @@ class LlamaIndexRagWorker(SaveLoad): logger.info('oo --------inspect_vector_store end--------') return vector_store_preview - def add_documents_to_vector_store(self, document_list): - documents = [Document(text=t) for t in document_list] + def add_documents_to_vector_store(self, document_list: List[Document]): + """ + Adds a list of Document objects to the vector store after processing. 
+ """ + documents = document_list documents_nodes = run_transformations( - documents, # type: ignore - self.vs_index._transformations, - show_progress=True - ) + documents, # type: ignore + self.vs_index._transformations, + show_progress=True + ) self.vs_index.insert_nodes(documents_nodes) - if self.debug_mode: self.inspect_vector_store() + if self.debug_mode: + self.inspect_vector_store() - def add_text_to_vector_store(self, text): + def add_text_to_vector_store(self, text: str): node = TextNode(text=text) documents_nodes = run_transformations( - [node], - self.vs_index._transformations, - show_progress=True - ) + [node], + self.vs_index._transformations, + show_progress=True + ) self.vs_index.insert_nodes(documents_nodes) - if self.debug_mode: self.inspect_vector_store() + if self.debug_mode: + self.inspect_vector_store() def remember_qa(self, question, answer): formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer) self.add_text_to_vector_store(formatted_str) def retrieve_from_store_with_query(self, query): - if self.debug_mode: self.inspect_vector_store() + if self.debug_mode: + self.inspect_vector_store() retriever = self.vs_index.as_retriever() return retriever.retrieve(query) def build_prompt(self, query, nodes): context_str = self.generate_node_array_preview(nodes) return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query) - + def generate_node_array_preview(self, nodes): buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])) if self.debug_mode: logger.info(buf) return buf + + def purge_vector_store(self): + """ + Purges the current vector store and creates a new one. + """ + self.purge() \ No newline at end of file diff --git a/crazy_functions/rag_fns/rag_file_support.py b/crazy_functions/rag_fns/rag_file_support.py new file mode 100644 index 00000000..98ba3bee --- /dev/null +++ b/crazy_functions/rag_fns/rag_file_support.py @@ -0,0 +1,22 @@ +import os +from llama_index.core import SimpleDirectoryReader + +supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt', + '.pptm', '.pptx'] + + +# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑 +def extract_text(file_path): + _, ext = os.path.splitext(file_path.lower()) + + # 使用 SimpleDirectoryReader 处理它支持的文件格式 + if ext in supports_format: + try: + reader = SimpleDirectoryReader(input_files=[file_path]) + documents = reader.load_data() + if len(documents) > 0: + return documents[0].text + except Exception as e: + pass + + return None
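
Usage sketch (not part of the patch): the snippet below shows how the pieces introduced here — extract_text() from rag_fns/rag_file_support.py and the reworked LlamaIndexRagWorker — might be wired together outside the Gradio UI, roughly mirroring what handle_document_upload() and Rag问答() do. The function name index_and_query, the "demo_user" user name, the checkpoint path and the my_llm_kwargs parameter are illustrative placeholders, not part of the repository; only the imports and method signatures follow the diff above.

from llama_index.core import Document

from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.rag_fns.rag_file_support import extract_text


def index_and_query(file_paths, question, my_llm_kwargs, checkpoint_dir="./rag_demo_checkpoint"):
    # Build (or resume) a vector store for one user, as Rag问答 does per chatbot user.
    worker = LlamaIndexRagWorker("demo_user", my_llm_kwargs,
                                 checkpoint_dir=checkpoint_dir, auto_load_checkpoint=True)

    # Ingest the files, mirroring handle_document_upload() without the UI plumbing.
    docs = []
    for path in file_paths:
        text = extract_text(path)  # returns None for unsupported or unparseable files
        if text is not None:
            docs.append(Document(text=text, metadata={"source": path}))
    if docs:
        worker.add_documents_to_vector_store(docs)

    # Retrieve matching nodes and assemble the final prompt, as the normal Q&A path does.
    nodes = worker.retrieve_from_store_with_query(question)
    return worker.build_prompt(query=question, nodes=nodes)

One design note visible in the diff: extract_text() deliberately swallows reader errors and returns None instead of raising, so a caller such as handle_document_upload() can report a single unsupported or unparseable file to the user and keep processing the rest of the batch.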