From 2d12b5b27db5e58bd1f04120408dbc5403a7aead Mon Sep 17 00:00:00 2001 From: lbykkkk Date: Fri, 11 Oct 2024 01:06:17 +0800 Subject: [PATCH] RAG interactive prompts added, issues resolved --- crazy_functions/Rag_Interface.py | 25 +++++---- crazy_functions/rag_fns/llama_index_worker.py | 2 +- crazy_functions/rag_fns/rag_file_support.py | 53 +++++++------------ 3 files changed, 35 insertions(+), 45 deletions(-) diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py index 8e0cb51b..2f0bc0f4 100644 --- a/crazy_functions/Rag_Interface.py +++ b/crazy_functions/Rag_Interface.py @@ -56,9 +56,15 @@ def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, try: validate_path_safety(file_path, user_name) text = extract_text(file_path) - document = Document(text=text, metadata={"source": file_path}) - rag_worker.add_documents_to_vector_store([document]) - chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"]) + if text is None: + chatbot.append( + [f"上传文件: {os.path.basename(file_path)}", "文件解析失败,无法提取文本内容,请更换文件。"]) + else: + chatbot.append( + [f"上传文件: {os.path.basename(file_path)}", f"上传文件前50个字符为:{text[:50]}。"]) + document = Document(text=text, metadata={"source": file_path}) + rag_worker.add_documents_to_vector_store([document]) + chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"]) except Exception as e: report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e)) @@ -100,18 +106,12 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库" # 2. Handle special commands - if os.path.exists(txt): + if os.path.exists(txt) and os.path.isdir(txt): project_folder = txt validate_path_safety(project_folder, chatbot.get_user()) # Extract file paths from the user input # Assuming the user inputs file paths separated by commas after the command file_paths = [f for f in glob.glob(f'{project_folder}/**/*', recursive=True)] - - if not txt: - report_exception(chatbot, history, a="上传文档", b="未提供任何文件路径。") - yield from update_ui(chatbot=chatbot, history=history) - return - chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...']) yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 @@ -125,9 +125,12 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面 return + else: + report_exception(chatbot, history, a=f"上传文件路径错误: {txt}", b="请检查并提供正确路径。") + # 3. Normal Q&A processing chatbot.append([txt, f'正在召回知识 ({current_context}) ...']) - # yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 4. Clip history to reduce token consumption txt_origin = txt diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py index 8733cdb4..59a5827c 100644 --- a/crazy_functions/rag_fns/llama_index_worker.py +++ b/crazy_functions/rag_fns/llama_index_worker.py @@ -59,7 +59,7 @@ class SaveLoad(): def purge(self): import shutil shutil.rmtree(self.checkpoint_dir, ignore_errors=True) - self.vs_index = self.create_new_vs() + self.vs_index = self.create_new_vs(self.checkpoint_dir) class LlamaIndexRagWorker(SaveLoad): diff --git a/crazy_functions/rag_fns/rag_file_support.py b/crazy_functions/rag_fns/rag_file_support.py index 1ceb93c2..50a07615 100644 --- a/crazy_functions/rag_fns/rag_file_support.py +++ b/crazy_functions/rag_fns/rag_file_support.py @@ -1,10 +1,8 @@ import os +from llama_index.core import SimpleDirectoryReader -import markdown +# 保留你原有的自定义解析函数 from PyPDF2 import PdfReader -from bs4 import BeautifulSoup -from docx import Document as DocxDocument - def extract_text_from_pdf(file_path): reader = PdfReader(file_path) @@ -13,35 +11,24 @@ def extract_text_from_pdf(file_path): text += page.extract_text() + "\n" return text -def extract_text_from_docx(file_path): - doc = DocxDocument(file_path) - return "\n".join([para.text for para in doc.paragraphs]) - -def extract_text_from_txt(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - return f.read() - -def extract_text_from_md(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - md_content = f.read() - return markdown.markdown(md_content) - -def extract_text_from_html(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - soup = BeautifulSoup(f, 'html.parser') - return soup.get_text() - +# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑 def extract_text(file_path): _, ext = os.path.splitext(file_path.lower()) + + # 使用 SimpleDirectoryReader 处理它支持的文件格式 + if ext in ['.txt', '.md', '.pdf', '.docx', '.html']: + try: + reader = SimpleDirectoryReader(input_files=[file_path]) + documents = reader.load_data() + if len(documents) > 0: + return documents[0].text + except Exception as e: + pass + + # 如果 SimpleDirectoryReader 失败,或文件格式不支持,使用自定义解析逻辑 if ext == '.pdf': - return extract_text_from_pdf(file_path) - elif ext in ['.docx', '.doc']: - return extract_text_from_docx(file_path) - elif ext == '.txt': - return extract_text_from_txt(file_path) - elif ext == '.md': - return extract_text_from_md(file_path) - elif ext == '.html': - return extract_text_from_html(file_path) - else: - raise ValueError(f"Unsupported file extension: {ext}") \ No newline at end of file + try: + return extract_text_from_pdf(file_path) + except Exception as e: + pass + return None