From df717f8bba8e2e4e3466ef45bbf34951ae049925 Mon Sep 17 00:00:00 2001
From: lbykkkk
Date: Fri, 20 Sep 2024 00:06:59 +0800
Subject: [PATCH] first_version

---
 crazy_functions/Rag_Interface.py              | 159 ++++++++++++++----
 crazy_functions/rag_fns/llama_index_worker.py |  65 +++----
 crazy_functions/rag_fns/rag_file_support.py   |  47 ++++++
 3 files changed, 208 insertions(+), 63 deletions(-)
 create mode 100644 crazy_functions/rag_fns/rag_file_support.py

diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py
index d83d8ca5..198f5267 100644
--- a/crazy_functions/Rag_Interface.py
+++ b/crazy_functions/Rag_Interface.py
@@ -1,7 +1,12 @@
-from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
-from crazy_functions.crazy_utils import input_clipping
-from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
+import os
+from typing import List
+from llama_index.core import Document
+from shared_utils.fastapi_server import validate_path_safety
+from crazy_functions.crazy_utils import input_clipping, request_gpt_model_in_new_thread_with_ui_alive
+from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
+from toolbox import report_exception
+from crazy_functions.rag_fns.rag_file_support import extract_text
 
 VECTOR_STORE_TYPE = "Milvus"
 if VECTOR_STORE_TYPE == "Milvus":
@@ -21,33 +26,109 @@ MAX_CONTEXT_TOKEN_LIMIT = 4096
 REMEMBER_PREVIEW = 1000
 
 @CatchException
-def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+    """
+    Handles document uploads by extracting text and adding it to the vector store.
 
-    # 1. we retrieve rag worker from global context
+    Args:
+        files (List[str]): List of file paths to process.
+        llm_kwargs: Language model keyword arguments.
+        plugin_kwargs: Plugin keyword arguments.
+        chatbot: Chatbot instance.
+        history: Chat history.
+        system_prompt: System prompt.
+        user_request: User request.
+    """
     user_name = chatbot.get_user()
     checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
+
     if user_name in RAG_WORKER_REGISTER:
         rag_worker = RAG_WORKER_REGISTER[user_name]
     else:
         rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
-            user_name,
-            llm_kwargs,
-            checkpoint_dir=checkpoint_dir,
-            auto_load_checkpoint=True)
+            user_name,
+            llm_kwargs,
+            checkpoint_dir=checkpoint_dir,
+            auto_load_checkpoint=True
+        )
+
+    for file_path in files:
+        try:
+            validate_path_safety(file_path, user_name)
+            text = extract_text(file_path)
+            document = Document(text=text, metadata={"source": file_path})
+            rag_worker.add_documents_to_vector_store([document])
+            chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"])
+        except Exception as e:
+            report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e))
+
+    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+@CatchException
+def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+    """
+    Handles RAG-based Q&A, including special commands and document uploads.
+
+    Args:
+        txt (str): User input text.
+        llm_kwargs: Language model keyword arguments.
+        plugin_kwargs: Plugin keyword arguments.
+        chatbot: Chatbot instance.
+        history: Chat history.
+        system_prompt: System prompt.
+        user_request: User request.
+    """
+    # Define commands
+    CLEAR_VECTOR_DB_CMD = "清空向量数据库"
+    UPLOAD_DOCUMENT_CMD = "上传文档"
+
+    # 1. Retrieve RAG worker from global context
+    user_name = chatbot.get_user()
+    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
+
+    if user_name in RAG_WORKER_REGISTER:
+        rag_worker = RAG_WORKER_REGISTER[user_name]
+    else:
+        rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
+            user_name,
+            llm_kwargs,
+            checkpoint_dir=checkpoint_dir,
+            auto_load_checkpoint=True
+        )
+
     current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
     tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
-    if txt == "清空向量数据库":
-        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
-        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
-        rag_worker.purge()
-        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
+
+    # 2. Handle special commands
+    if txt.startswith(UPLOAD_DOCUMENT_CMD):
+        # Extract file paths from the user input.
+        # The user is expected to supply comma-separated file paths after the command.
+        file_paths = txt[len(UPLOAD_DOCUMENT_CMD):].strip().split(',')
+        file_paths = [path.strip() for path in file_paths if path.strip()]
+
+        if not file_paths:
+            report_exception(chatbot, history, a="上传文档", b="未提供任何文件路径。")
+            yield from update_ui(chatbot=chatbot, history=history)
+            return
+
+        chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+        yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
         return
 
-    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
-    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+    elif txt == CLEAR_VECTOR_DB_CMD:
+        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+        rag_worker.purge_vector_store()
+        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
+        return
 
-    # 2. clip history to reduce token consumption
-    # 2-1. reduce chat round
+    # 3. Normal Q&A processing
+    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
+    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+    # 4. Clip history to reduce token consumption
    txt_origin = txt
 
     if len(history) > MAX_HISTORY_ROUND * 2:
@@ -55,41 +136,49 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
     txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
     input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
 
-    # 2-2. if input is clipped, add input to vector store before retrieve
+    # 5. If input is clipped, add input to vector store before retrieve
     if input_is_clipped_flag:
-        yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
-        # save input to vector store
+        yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
+        # Save input to vector store
         rag_worker.add_text_to_vector_store(txt_origin)
-        yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
+        yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
+
         if len(txt_origin) > REMEMBER_PREVIEW:
-            HALF = REMEMBER_PREVIEW//2
+            HALF = REMEMBER_PREVIEW // 2
             i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
             if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
-                txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
" + txt[-HALF:] - else: - pass - i_say = txt_clip + txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:] else: i_say_to_remember = i_say = txt_clip else: i_say_to_remember = i_say = txt_clip - # 3. we search vector store and build prompts + # 6. Search vector store and build prompts nodes = rag_worker.retrieve_from_store_with_query(i_say) prompt = rag_worker.build_prompt(query=i_say, nodes=nodes) - # 4. it is time to query llms - if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive` + # 7. Query language model + if len(chatbot) != 0: + chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive` + model_say = yield from request_gpt_model_in_new_thread_with_ui_alive( - inputs=prompt, inputs_show_user=i_say, - llm_kwargs=llm_kwargs, chatbot=chatbot, history=history, + inputs=prompt, + inputs_show_user=i_say, + llm_kwargs=llm_kwargs, + chatbot=chatbot, + history=history, sys_prompt=system_prompt, retry_times_at_unknown_error=0 ) - # 5. remember what has been asked / answered - yield from update_ui_lastest_msg(model_say + '

' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面 + # 8. Remember Q&A + yield from update_ui_lastest_msg( + model_say + '

' + f'对话记忆中, 请稍等 ({current_context}) ...', + chatbot, history, delay=0.5 + ) rag_worker.remember_qa(i_say_to_remember, model_say) history.extend([i_say, model_say]) - yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面 + # 9. Final UI Update + yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) \ No newline at end of file diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py index 7a559927..c366156a 100644 --- a/crazy_functions/rag_fns/llama_index_worker.py +++ b/crazy_functions/rag_fns/llama_index_worker.py @@ -1,16 +1,12 @@ -import llama_index -import os import atexit from typing import List + from llama_index.core import Document -from llama_index.core.schema import TextNode -from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel -from shared_utils.connect_void_terminal import get_chat_default_kwargs -from llama_index.core import VectorStoreIndex, SimpleDirectoryReader -from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex from llama_index.core.ingestion import run_transformations -from llama_index.core import PromptTemplate -from llama_index.core.response_synthesizers import TreeSummarize +from llama_index.core.schema import TextNode + +from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex +from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel DEFAULT_QUERY_GENERATION_PROMPT = """\ Now, you have context information as below: @@ -74,7 +70,7 @@ class LlamaIndexRagWorker(SaveLoad): if auto_load_checkpoint: self.vs_index = self.load_from_checkpoint(checkpoint_dir) else: - self.vs_index = self.create_new_vs(checkpoint_dir) + self.vs_index = self.create_new_vs() atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir)) def assign_embedding_model(self): @@ -84,46 +80,59 @@ class LlamaIndexRagWorker(SaveLoad): # This function is for debugging self.vs_index.storage_context.index_store.to_dict() docstore = self.vs_index.storage_context.docstore.docs - vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ]) + vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()]) print('\n++ --------inspect_vector_store begin--------') print(vector_store_preview) print('oo --------inspect_vector_store end--------') return vector_store_preview - def add_documents_to_vector_store(self, document_list): - documents = [Document(text=t) for t in document_list] + def add_documents_to_vector_store(self, document_list: List[Document]): + """ + Adds a list of Document objects to the vector store after processing. 
+        """
+        documents = document_list
         documents_nodes = run_transformations(
-                    documents,  # type: ignore
-                    self.vs_index._transformations,
-                    show_progress=True
-                )
+            documents,  # type: ignore
+            self.vs_index._transformations,
+            show_progress=True
+        )
         self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
 
-    def add_text_to_vector_store(self, text):
+    def add_text_to_vector_store(self, text: str):
         node = TextNode(text=text)
         documents_nodes = run_transformations(
-                    [node],
-                    self.vs_index._transformations,
-                    show_progress=True
-                )
+            [node],
+            self.vs_index._transformations,
+            show_progress=True
+        )
         self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
 
     def remember_qa(self, question, answer):
         formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
         self.add_text_to_vector_store(formatted_str)
 
     def retrieve_from_store_with_query(self, query):
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
         retriever = self.vs_index.as_retriever()
         return retriever.retrieve(query)
 
     def build_prompt(self, query, nodes):
         context_str = self.generate_node_array_preview(nodes)
         return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
-    
+
     def generate_node_array_preview(self, nodes):
-        buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
-        if self.debug_mode: print(buf)
+        buf = "\n".join([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])
+        if self.debug_mode:
+            print(buf)
         return buf
+
+    def purge_vector_store(self):
+        """
+        Purges the current vector store and creates a new one.
+        """
+        self.purge()
\ No newline at end of file
diff --git a/crazy_functions/rag_fns/rag_file_support.py b/crazy_functions/rag_fns/rag_file_support.py
new file mode 100644
index 00000000..1ceb93c2
--- /dev/null
+++ b/crazy_functions/rag_fns/rag_file_support.py
@@ -0,0 +1,47 @@
+import os
+
+import markdown
+from PyPDF2 import PdfReader
+from bs4 import BeautifulSoup
+from docx import Document as DocxDocument
+
+
+def extract_text_from_pdf(file_path):
+    reader = PdfReader(file_path)
+    text = ""
+    for page in reader.pages:
+        text += (page.extract_text() or "") + "\n"  # extract_text() may return None, e.g. for image-only pages
+    return text
+
+def extract_text_from_docx(file_path):
+    doc = DocxDocument(file_path)
+    return "\n".join([para.text for para in doc.paragraphs])
+
+def extract_text_from_txt(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+def extract_text_from_md(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        md_content = f.read()
+    return markdown.markdown(md_content)
+
+def extract_text_from_html(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        soup = BeautifulSoup(f, 'html.parser')
+    return soup.get_text()
+
+def extract_text(file_path):
+    _, ext = os.path.splitext(file_path.lower())
+    if ext == '.pdf':
+        return extract_text_from_pdf(file_path)
+    elif ext in ['.docx', '.doc']:  # note: python-docx parses OOXML only; legacy .doc files will raise
+        return extract_text_from_docx(file_path)
+    elif ext == '.txt':
+        return extract_text_from_txt(file_path)
+    elif ext == '.md':
+        return extract_text_from_md(file_path)
+    elif ext == '.html':
+        return extract_text_from_html(file_path)
+    else:
+        raise ValueError(f"Unsupported file extension: {ext}")
\ No newline at end of file
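
For reference, a minimal usage sketch of the new extraction path (not part of the patch itself; the file path below is hypothetical). extract_text raises ValueError on unsupported extensions; inside the plugin, Rag问答 triggers the same path from the chat box via "上传文档 /path/a.pdf, /path/b.docx":

    from llama_index.core import Document
    from crazy_functions.rag_fns.rag_file_support import extract_text

    file_path = "docs/design_notes.pdf"  # hypothetical local file
    text = extract_text(file_path)       # ValueError if the extension is unsupported
    doc = Document(text=text, metadata={"source": file_path})
    print(f"extracted {len(text)} characters from {file_path}")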