diff --git a/.gitignore b/.gitignore index be33f58c..5ad02c3f 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,5 @@ test.* temp.* objdump* *.min.*.js -TODO \ No newline at end of file +TODO +*.cursorrules diff --git a/crazy_functional.py b/crazy_functional.py index 248262cc..ddac5815 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -15,6 +15,7 @@ def get_crazy_functions(): from crazy_functions.SourceCode_Analyse import 解析一个Rust项目 from crazy_functions.SourceCode_Analyse import 解析一个Java项目 from crazy_functions.SourceCode_Analyse import 解析一个前端项目 + from crazy_functions.Arxiv_论文对话 import Rag论文对话 from crazy_functions.高级功能函数模板 import 高阶功能模板函数 from crazy_functions.高级功能函数模板 import Demo_Wrap from crazy_functions.Latex全文润色 import Latex英文润色 @@ -27,7 +28,6 @@ def get_crazy_functions(): from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录 from crazy_functions.辅助功能 import 清除缓存 - from crazy_functions.批量文件询问 import 批量文件询问 from crazy_functions.Markdown_Translate import Markdown英译中 from crazy_functions.批量总结PDF文档 import 批量总结PDF文档 from crazy_functions.PDF_Translate import 批量翻译PDF文档 @@ -59,34 +59,6 @@ def get_crazy_functions(): "Info": "使用自然语言实现您的想法", "Function": HotReload(虚空终端), }, - "解析整个Python项目": { - "Group": "编程", - "Color": "stop", - "AsButton": True, - "Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径", - "Function": HotReload(解析一个Python项目), - }, - "注释Python项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, - "Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径", - "Function": HotReload(注释Python项目), - "Class": SourceCodeComment_Wrap, - }, - "载入对话历史存档(先上传存档或输入路径)": { - "Group": "对话", - "Color": "stop", - "AsButton": False, - "Info": "载入对话历史存档 | 输入参数为路径", - "Function": HotReload(载入对话历史存档), - }, - "删除所有本地对话历史记录(谨慎操作)": { - "Group": "对话", - "AsButton": False, - "Info": "删除所有本地对话历史记录,谨慎操作 | 不需要输入参数", - "Function": HotReload(删除所有本地对话历史记录), - }, "清除所有缓存文件(谨慎操作)": { "Group": "对话", "Color": "stop", @@ -94,14 +66,6 @@ def get_crazy_functions(): "Info": "清除所有缓存文件,谨慎操作 | 不需要输入参数", "Function": HotReload(清除缓存), }, - "生成多种Mermaid图表(从当前对话或路径(.pdf/.md/.docx)中生产图表)": { - "Group": "对话", - "Color": "stop", - "AsButton": False, - "Info" : "基于当前对话或文件生成多种Mermaid图表,图表类型由模型判断", - "Function": None, - "Class": Mermaid_Gen - }, "Arxiv论文翻译": { "Group": "学术", "Color": "stop", @@ -110,92 +74,12 @@ def get_crazy_functions(): "Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用 "Class": Arxiv_Localize, # 新一代插件需要注册Class }, - "批量文件询问": { + "Rag论文对话": { "Group": "学术", "Color": "stop", "AsButton": False, - "AdvancedArgs": True, - "Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径", - "Function": HotReload(批量文件询问), - }, - "解析整个Matlab项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, - "Info": "解析一个Matlab项目的所有源文件(.m) | 输入参数为路径", - "Function": HotReload(解析一个Matlab项目), - }, - "解析整个C++项目头文件": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个C++项目的所有头文件(.h/.hpp) | 输入参数为路径", - "Function": HotReload(解析一个C项目的头文件), - }, - "解析整个C++项目(.cpp/.hpp/.c/.h)": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个C++项目的所有源文件(.cpp/.hpp/.c/.h)| 输入参数为路径", - "Function": HotReload(解析一个C项目), - }, - "解析整个Go项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个Go项目的所有源文件 | 输入参数为路径", - "Function": HotReload(解析一个Golang项目), - }, - "解析整个Rust项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - 
"Info": "解析一个Rust项目的所有源文件 | 输入参数为路径", - "Function": HotReload(解析一个Rust项目), - }, - "解析整个Java项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个Java项目的所有源文件 | 输入参数为路径", - "Function": HotReload(解析一个Java项目), - }, - "解析整个前端项目(js,ts,css等)": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个前端项目的所有源文件(js,ts,css等) | 输入参数为路径", - "Function": HotReload(解析一个前端项目), - }, - "解析整个Lua项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个Lua项目的所有源文件 | 输入参数为路径", - "Function": HotReload(解析一个Lua项目), - }, - "解析整个CSharp项目": { - "Group": "编程", - "Color": "stop", - "AsButton": False, # 加入下拉菜单中 - "Info": "解析一个CSharp项目的所有源文件 | 输入参数为路径", - "Function": HotReload(解析一个CSharp项目), - }, - "解析Jupyter Notebook文件": { - "Group": "编程", - "Color": "stop", - "AsButton": False, - "Info": "解析Jupyter Notebook文件 | 输入参数为路径", - "Function": HotReload(解析ipynb文件), - "AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False) - "ArgsReminder": "若输入0,则不解析notebook中的Markdown块", # 高级参数输入区的显示提示 - }, - "读Tex论文写摘要": { - "Group": "学术", - "Color": "stop", - "AsButton": False, - "Info": "读取Tex论文并写摘要 | 输入参数为路径", - "Function": HotReload(读文章写摘要), + "Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695", + "Function": HotReload(Rag论文对话), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用 }, "翻译README或MD": { "Group": "编程", diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py index be448319..ed110325 100644 --- a/crazy_functions/Arxiv_论文对话.py +++ b/crazy_functions/Arxiv_论文对话.py @@ -1,6 +1,7 @@ +import os.path + from toolbox import CatchException, update_ui from crazy_functions.rag_essay_fns.paper_processing import ArxivPaperProcessor -from crazy_functions.rag_essay_fns.rag_handler import RagHandler import asyncio @CatchException @@ -9,7 +10,24 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro txt: 用户输入,通常是arxiv论文链接 功能:RAG论文总结和对话 """ - # 初始化处理器 + if_project, if_arxiv = False, False + if os.path.exists(txt): + from crazy_functions.rag_essay_fns.document_splitter import SmartDocumentSplitter + splitter = SmartDocumentSplitter( + char_range=(1000, 1200), + max_workers=32 # 可选,默认会根据CPU核心数自动设置 + ) + if_project = True + else: + from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter + splitter = SmartArxivSplitter( + char_range=(1000, 1200), + root_dir="gpt_log/arxiv_cache" + ) + if_arxiv = True + for fragment in splitter.process(txt): + pass + # 初始化处理器 processor = ArxivPaperProcessor() rag_handler = RagHandler() diff --git a/crazy_functions/Conversation_To_File.py b/crazy_functions/Conversation_To_File.py index b8408748..e572f8fd 100644 --- a/crazy_functions/Conversation_To_File.py +++ b/crazy_functions/Conversation_To_File.py @@ -152,8 +152,6 @@ class Conversation_To_File_Wrap(GptAcademicPluginTemplate): - - def hide_cwd(str): import os current_path = os.getcwd() diff --git a/crazy_functions/rag_essay_fns/arxiv_splitter.py b/crazy_functions/rag_essay_fns/arxiv_splitter.py index e978b25d..0af46acc 100644 --- a/crazy_functions/rag_essay_fns/arxiv_splitter.py +++ b/crazy_functions/rag_essay_fns/arxiv_splitter.py @@ -20,9 +20,7 @@ class ArxivFragment: segment_type: str title: str abstract: str - section: str # 保存完整的section层级路径,如 "Introduction" 或 "Methods-Data Processing" - section_type: str # 新增:标识片段类型,如 "abstract", "section", "subsection" 等 - section_level: int # 新增:section的层级深度,abstract为0,main section为1,subsection为2,等等 + section: str is_appendix: bool @@ -116,6 +114,100 @@ class SmartArxivSplitter: return result + def 
_smart_split(self, content: str) -> List[Tuple[str, str, bool]]: + """智能分割TEX内容,确保在字符范围内并保持语义完整性""" + content = self._preprocess_content(content) + segments = [] + current_buffer = [] + current_length = 0 + current_section = "Unknown Section" + is_appendix = False + + # 保护特殊环境 + protected_blocks = {} + content = self._protect_special_environments(content, protected_blocks) + + # 按段落分割 + paragraphs = re.split(r'\n\s*\n', content) + + for para in paragraphs: + para = para.strip() + if not para: + continue + + # 恢复特殊环境 + para = self._restore_special_environments(para, protected_blocks) + + # 更新章节信息 + section_info = self._get_section_info(para, content) + if section_info: + current_section, is_appendix = section_info + + # 判断是否是特殊环境 + if self._is_special_environment(para): + # 处理当前缓冲区 + if current_buffer: + segments.append(( + '\n'.join(current_buffer), + current_section, + is_appendix + )) + current_buffer = [] + current_length = 0 + + # 添加特殊环境作为独立片段 + segments.append((para, current_section, is_appendix)) + continue + + # 处理普通段落 + sentences = self._split_into_sentences(para) + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + + sent_length = len(sentence) + new_length = current_length + sent_length + (1 if current_buffer else 0) + + if new_length <= self.max_chars: + current_buffer.append(sentence) + current_length = new_length + else: + # 如果当前缓冲区达到最小长度要求 + if current_length >= self.min_chars: + segments.append(( + '\n'.join(current_buffer), + current_section, + is_appendix + )) + current_buffer = [sentence] + current_length = sent_length + else: + # 尝试将过长的句子分割 + split_sentences = self._split_long_sentence(sentence) + for split_sent in split_sentences: + if current_length + len(split_sent) <= self.max_chars: + current_buffer.append(split_sent) + current_length += len(split_sent) + 1 + else: + segments.append(( + '\n'.join(current_buffer), + current_section, + is_appendix + )) + current_buffer = [split_sent] + current_length = len(split_sent) + + # 处理剩余的缓冲区 + if current_buffer: + segments.append(( + '\n'.join(current_buffer), + current_section, + is_appendix + )) + + return segments + def _split_into_sentences(self, text: str) -> List[str]: """将文本分割成句子""" return re.split(r'(?<=[.!?。!?])\s+', text) @@ -194,7 +286,7 @@ class SmartArxivSplitter: content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content) return content.strip() - def process_paper(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]: + def process(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]: """处理单篇arxiv论文""" try: arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url) @@ -318,31 +410,16 @@ class SmartArxivSplitter: return title.strip(), abstract.strip() - def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, str, int, bool]]: - """获取段落所属的章节信息,返回(section_path, section_type, level, is_appendix)""" - current_path = [] - section_type = "content" - level = 0 + def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]: + """获取段落所属的章节信息""" + section = "Unknown Section" is_appendix = False - # 定义section层级的正则模式 - section_patterns = { - r'\\chapter\{([^}]+)\}': 1, - r'\\section\{([^}]+)\}': 1, - r'\\subsection\{([^}]+)\}': 2, - r'\\subsubsection\{([^}]+)\}': 3 - } - # 查找所有章节标记 all_sections = [] - for pattern, sec_level in section_patterns.items(): + for pattern in self.section_patterns: for match in re.finditer(pattern, content): - all_sections.append((match.start(), match.group(1), sec_level)) - - # 检查是否是摘要 - 
abstract_match = re.search(r'\\begin{abstract}.*?' + re.escape(para), content, re.DOTALL) - if abstract_match: - return "Abstract", "abstract", 0, False + all_sections.append((match.start(), match.group(2))) # 查找appendix标记 appendix_pos = content.find(r'\appendix') @@ -350,118 +427,19 @@ class SmartArxivSplitter: # 确定当前章节 para_pos = content.find(para) if para_pos >= 0: - is_appendix = appendix_pos >= 0 and para_pos > appendix_pos - current_sections = [] - current_level = 0 - - # 按位置排序所有section标记 - for sec_pos, sec_title, sec_level in sorted(all_sections): + current_section = None + for sec_pos, sec_title in sorted(all_sections): if sec_pos > para_pos: break - # 如果遇到更高层级的section,清除所有更低层级的section - if sec_level <= current_level: - current_sections = [s for s in current_sections if s[1] < sec_level] - current_sections.append((sec_title, sec_level)) - current_level = sec_level + current_section = sec_title - # 构建section路径 - if current_sections: - current_path = [s[0] for s in sorted(current_sections, key=lambda x: x[1])] - section_path = "-".join(current_path) - level = max(s[1] for s in current_sections) - section_type = "section" if level == 1 else "subsection" - return section_path, section_type, level, is_appendix + if current_section: + section = current_section + is_appendix = appendix_pos >= 0 and para_pos > appendix_pos - return "Unknown Section", "content", 0, is_appendix + return section, is_appendix - def _smart_split(self, content: str) -> List[Tuple[str, str, str, int, bool]]: - """智能分割TEX内容,确保在字符范围内并保持语义完整性""" - content = self._preprocess_content(content) - segments = [] - current_buffer = [] - current_length = 0 - current_section_info = ("Unknown Section", "content", 0, False) - - # 保护特殊环境 - protected_blocks = {} - content = self._protect_special_environments(content, protected_blocks) - - # 按段落分割 - paragraphs = re.split(r'\n\s*\n', content) - - for para in paragraphs: - para = para.strip() - if not para: - continue - - # 恢复特殊环境 - para = self._restore_special_environments(para, protected_blocks) - - # 更新章节信息 - section_info = self._get_section_info(para, content) - if section_info: - current_section_info = section_info - - # 判断是否是特殊环境 - if self._is_special_environment(para): - # 处理当前缓冲区 - if current_buffer: - segments.append(( - '\n'.join(current_buffer), - *current_section_info - )) - current_buffer = [] - current_length = 0 - - # 添加特殊环境作为独立片段 - segments.append((para, *current_section_info)) - continue - - # 处理普通段落 - sentences = self._split_into_sentences(para) - for sentence in sentences: - sentence = sentence.strip() - if not sentence: - continue - - sent_length = len(sentence) - new_length = current_length + sent_length + (1 if current_buffer else 0) - - if new_length <= self.max_chars: - current_buffer.append(sentence) - current_length = new_length - else: - # 如果当前缓冲区达到最小长度要求 - if current_length >= self.min_chars: - segments.append(( - '\n'.join(current_buffer), - *current_section_info - )) - current_buffer = [sentence] - current_length = sent_length - else: - # 尝试将过长的句子分割 - split_sentences = self._split_long_sentence(sentence) - for split_sent in split_sentences: - if current_length + len(split_sent) <= self.max_chars: - current_buffer.append(split_sent) - current_length += len(split_sent) + 1 - else: - segments.append(( - '\n'.join(current_buffer), - *current_section_info - )) - current_buffer = [split_sent] - current_length = len(split_sent) - - # 处理剩余的缓冲区 - if current_buffer: - segments.append(( - '\n'.join(current_buffer), - *current_section_info - )) - - return segments 
+ return None def _process_single_tex(self, file_path: str) -> List[ArxivFragment]: """处理单个TEX文件""" @@ -481,12 +459,12 @@ class SmartArxivSplitter: segments = self._smart_split(content) fragments = [] - for i, (segment_content, section_path, section_type, level, is_appendix) in enumerate(segments): + for i, (segment_content, section, is_appendix) in enumerate(segments): if segment_content.strip(): segment_type = 'text' for env_type, patterns in self.special_envs.items(): if any(re.search(pattern, segment_content, re.DOTALL) - for pattern in patterns): + for pattern in patterns): segment_type = env_type break @@ -499,9 +477,7 @@ class SmartArxivSplitter: segment_type=segment_type, title=title, abstract=abstract, - section=section_path, - section_type=section_type, - section_level=level, + section=section, is_appendix=is_appendix )) @@ -511,7 +487,6 @@ class SmartArxivSplitter: logging.error(f"Error processing file {file_path}: {e}") return [] - def main(): """使用示例""" # 创建分割器实例 @@ -521,10 +496,11 @@ def main(): ) # 处理论文 - for fragment in splitter.process_paper("2411.03663"): + for fragment in splitter.process("2411.03663"): print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}") print(f"Length: {len(fragment.content)}") print(f"Section: {fragment.section}") + print(f"Title: {fragment.file_path}") print(fragment.content) print("-" * 80) diff --git a/crazy_functions/rag_essay_fns/paper_processing.py b/crazy_functions/rag_essay_fns/paper_processing.py index 5e60dc67..1e46407a 100644 --- a/crazy_functions/rag_essay_fns/paper_processing.py +++ b/crazy_functions/rag_essay_fns/paper_processing.py @@ -1,7 +1,6 @@ from typing import Tuple, Optional, Generator, List from toolbox import update_ui, update_ui_lastest_msg, get_conf import os, tarfile, requests, time, re - class ArxivPaperProcessor: """Arxiv论文处理器类""" diff --git a/crazy_functions/rag_essay_fns/rag_handler.py b/crazy_functions/rag_essay_fns/rag_handler.py index 89631dae..0a8d4742 100644 --- a/crazy_functions/rag_essay_fns/rag_handler.py +++ b/crazy_functions/rag_essay_fns/rag_handler.py @@ -81,5 +81,84 @@ class RagHandler: ) ) return response + except Exception as e: + return f"查询出错: {str(e)}" + + +class RagHandler: + def __init__(self): + # 初始化工作目录 + self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache') + if not os.path.exists(self.working_dir): + os.makedirs(self.working_dir) + + # 初始化 LightRAG + self.rag = LightRAG( + working_dir=self.working_dir, + llm_model_func=self._llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1536, # OpenAI embedding 维度 + max_token_size=8192, + func=self._embedding_func, + ), + ) + + async def _llm_model_func(self, prompt: str, system_prompt: str = None, + history_messages: List = None, **kwargs) -> str: + """LLM 模型函数""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if history_messages: + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=messages, + temperature=kwargs.get("temperature", 0), + max_tokens=kwargs.get("max_tokens", 1000) + ) + return response.choices[0].message.content + + async def _embedding_func(self, texts: List[str]) -> np.ndarray: + """Embedding 函数""" + response = await openai.Embedding.acreate( + model="text-embedding-ada-002", + input=texts + ) + embeddings = [item["embedding"] for item in response["data"]] + return np.array(embeddings) + + def 
process_paper_content(self, paper_content: Dict) -> None: + """处理论文内容,构建知识图谱""" + # 处理标题和摘要 + content_list = [] + if paper_content['title']: + content_list.append(f"Title: {paper_content['title']}") + if paper_content['abstract']: + content_list.append(f"Abstract: {paper_content['abstract']}") + + # 添加分段内容 + content_list.extend(paper_content['segments']) + + # 插入到 RAG 系统 + self.rag.insert(content_list) + + def query(self, question: str, mode: str = "hybrid") -> str: + """查询论文内容 + mode: 查询模式,可选 naive/local/global/hybrid + """ + try: + response = self.rag.query( + question, + param=QueryParam( + mode=mode, + top_k=5, # 返回相关度最高的5个结果 + max_token_for_text_unit=2048, # 每个文本单元的最大token数 + response_type="detailed" # 返回详细回答 + ) + ) + return response except Exception as e: return f"查询出错: {str(e)}" \ No newline at end of file diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py index 59a5827c..df1809e4 100644 --- a/crazy_functions/rag_fns/llama_index_worker.py +++ b/crazy_functions/rag_fns/llama_index_worker.py @@ -1,10 +1,10 @@ import atexit -from loguru import logger -from typing import List +from typing import List, Dict, Optional, Any, Tuple from llama_index.core import Document from llama_index.core.ingestion import run_transformations -from llama_index.core.schema import TextNode +from llama_index.core.schema import TextNode, NodeWithScore +from loguru import logger from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel @@ -135,4 +135,170 @@ class LlamaIndexRagWorker(SaveLoad): """ Purges the current vector store and creates a new one. """ - self.purge() \ No newline at end of file + self.purge() + + + + """ + 以下是添加的新方法,原有方法保持不变 + """ + +def add_text_with_metadata(self, text: str, metadata: dict) -> str: + """ + 添加带元数据的文本到向量存储 + + Args: + text: 文本内容 + metadata: 元数据字典 + + Returns: + 添加的节点ID + """ + node = TextNode(text=text, metadata=metadata) + nodes = run_transformations( + [node], + self.vs_index._transformations, + show_progress=True + ) + self.vs_index.insert_nodes(nodes) + return nodes[0].node_id if nodes else None + +def batch_add_texts_with_metadata(self, texts: List[Tuple[str, dict]]) -> List[str]: + """ + 批量添加带元数据的文本 + + Args: + texts: (text, metadata)元组列表 + + Returns: + 添加的节点ID列表 + """ + nodes = [TextNode(text=t, metadata=m) for t, m in texts] + transformed_nodes = run_transformations( + nodes, + self.vs_index._transformations, + show_progress=True + ) + if transformed_nodes: + self.vs_index.insert_nodes(transformed_nodes) + return [node.node_id for node in transformed_nodes] + return [] + +def get_node_metadata(self, node_id: str) -> Optional[dict]: + """ + 获取节点的元数据 + + Args: + node_id: 节点ID + + Returns: + 节点的元数据字典 + """ + node = self.vs_index.storage_context.docstore.docs.get(node_id) + return node.metadata if node else None + +def update_node_metadata(self, node_id: str, metadata: dict, merge: bool = True) -> bool: + """ + 更新节点的元数据 + + Args: + node_id: 节点ID + metadata: 新的元数据 + merge: 是否与现有元数据合并 + + Returns: + 是否更新成功 + """ + docstore = self.vs_index.storage_context.docstore + if node_id in docstore.docs: + node = docstore.docs[node_id] + if merge: + node.metadata.update(metadata) + else: + node.metadata = metadata + return True + return False + +def filter_nodes_by_metadata(self, filters: Dict[str, Any]) -> List[TextNode]: + """ + 按元数据过滤节点 + + Args: + filters: 元数据过滤条件 + + Returns: + 符合条件的节点列表 + """ + docstore = 
self.vs_index.storage_context.docstore + results = [] + for node in docstore.docs.values(): + if all(node.metadata.get(k) == v for k, v in filters.items()): + results.append(node) + return results + +def retrieve_with_metadata_filter( + self, + query: str, + metadata_filters: Dict[str, Any], + top_k: int = 5 +) -> List[NodeWithScore]: + """ + 结合元数据过滤的检索 + + Args: + query: 查询文本 + metadata_filters: 元数据过滤条件 + top_k: 返回结果数量 + + Returns: + 检索结果节点列表 + """ + retriever = self.vs_index.as_retriever(similarity_top_k=top_k) + nodes = retriever.retrieve(query) + + # 应用元数据过滤 + filtered_nodes = [] + for node in nodes: + if all(node.metadata.get(k) == v for k, v in metadata_filters.items()): + filtered_nodes.append(node) + + return filtered_nodes + +def get_node_stats(self, node_id: str) -> dict: + """ + 获取单个节点的统计信息 + + Args: + node_id: 节点ID + + Returns: + 节点统计信息字典 + """ + node = self.vs_index.storage_context.docstore.docs.get(node_id) + if not node: + return {} + + return { + "text_length": len(node.text), + "token_count": len(node.text.split()), + "has_embedding": node.embedding is not None, + "metadata_keys": list(node.metadata.keys()), + } + +def get_nodes_by_content_pattern(self, pattern: str) -> List[TextNode]: + """ + 按内容模式查找节点 + + Args: + pattern: 正则表达式模式 + + Returns: + 匹配的节点列表 + """ + import re + docstore = self.vs_index.storage_context.docstore + matched_nodes = [] + for node in docstore.docs.values(): + if re.search(pattern, node.text): + matched_nodes.append(node) + return matched_nodes \ No newline at end of file diff --git a/crazy_functions/rag_fns/rag_file_support.py b/crazy_functions/rag_fns/rag_file_support.py index 98ba3bee..f826fab1 100644 --- a/crazy_functions/rag_fns/rag_file_support.py +++ b/crazy_functions/rag_fns/rag_file_support.py @@ -1,22 +1,45 @@ import os from llama_index.core import SimpleDirectoryReader -supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt', - '.pptm', '.pptx'] +supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt', + '.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m'] +def read_docx_doc(file_path): + if file_path.split(".")[-1] == "docx": + from docx import Document + doc = Document(file_path) + file_content = "\n".join([para.text for para in doc.paragraphs]) + else: + try: + import win32com.client + word = win32com.client.Dispatch("Word.Application") + word.visible = False + # 打开文件 + doc = word.Documents.Open(os.getcwd() + '/' + file_path) + # file_content = doc.Content.Text + doc = word.ActiveDocument + file_content = doc.Range().Text + doc.Close() + word.Quit() + except: + raise RuntimeError('请先将.doc文档转换为.docx文档。') + return file_content # 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑 +import os + def extract_text(file_path): _, ext = os.path.splitext(file_path.lower()) # 使用 SimpleDirectoryReader 处理它支持的文件格式 - if ext in supports_format: - try: - reader = SimpleDirectoryReader(input_files=[file_path]) - documents = reader.load_data() - if len(documents) > 0: - return documents[0].text - except Exception as e: - pass + if ext in ['.docx', '.doc']: + return read_docx_doc(file_path) + try: + reader = SimpleDirectoryReader(input_files=[file_path]) + documents = reader.load_data() + if len(documents) > 0: + return documents[0].text + except Exception as e: + pass return None diff --git a/instruction.txt b/instruction.txt index 64e8221f..595cbafa 100644 --- a/instruction.txt +++ b/instruction.txt @@ -102,6 +102,7 @@ │ ├── 总结音视频.py 
│ ├── 批量总结PDF文档.py │ ├── 批量总结PDF文档pdfminer.py +│ ├── 批量文件询问.py │ ├── 批量翻译PDF文档_NOUGAT.py │ ├── 数学动画生成manim.py │ ├── 理解PDF文档内容.py @@ -175,16 +176,39 @@ └── setup.py -3、我需要开发一个rag插件,请帮我实现一个插件,插件的名称是rag论文总结,插件主入口在crazy_functions/Arxiv_论文对话.py中的Rag论文对话函数,插件的功能步骤分为文件处理和RAG两个步骤 - 文件处理步骤流程和要求按顺序如下,请参考gpt_academic已实现的功能复用现有函数即可: - a. 支持从 arXiv 下载论文源码、检查本地项目路径、扫描 .tex 文件,此步骤可参考crazy_functions/Latex_Function.py。 - b、在项目中找到主要的 LaTeX 文件,将多个 TEX 文件合并成一个大的 TEX 文件,便于统一处理,此步骤可参考crazy_functions/Latex_Function.py。 - c、将合并后的文档进行精细切分,包括读取标题和摘要,此步骤可参考crazy_functions/Latex_Function.py。 - d、将文档按照 token 限制(1024)进行进一步分段,此步骤可参考crazy_functions/Latex_Function.py。 +3、我需要开发一个rag插件,请帮我实现一个插件,插件的名称是rag论文总结,插件主入口在crazy_functions/Arxiv_论文对话.py中的Rag论文对话函数,插件的功能步骤分为文件处理和RAG两个步骤,以下是具体的一些要求: +I. 函数头如下: +@CatchException +def rag论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, user_request: str): +II. 函数返回可参考crazy_functions/批量文件询问.py中的“批量文件询问”函数,主要采用yield方式 3、对于RAG,我希望采用light_rag的方案,参考已有方案其主要的功能实现是: 主要功能包括: - e 参考- `chunking_by_token_size`,利用`_handle_entity_relation_summary`函数对d步骤生成的文本块进行实体或关系的摘要。 + a. 分别为project和arxiv创建rag_handler,project类的fragment类内容为 + @dataclass +class DocFragment: + """文本片段数据类""" + file_path: str # 原始文件路径 + content: str # 片段内容 + segment_index: int # 片段序号 + total_segments: int # 总片段数 + rel_path: str # 相对路径 + arxiv的fragment内容为: + @dataclass +class ArxivFragment: + """Arxiv论文片段数据类""" + file_path: str + content: str + segment_index: int + total_segments: int + rel_path: str + segment_type: str + title: str + abstract: str + section: str + is_appendix: bool + b 如果目录下不存在抽取好的实体或关系的摘要,利用`_handle_entity_relation_summary`函数对d步骤生成的文本块进行实体或关系的摘要,并将其存储在project或者arxiv的路径下,路径为获取fragment.file_path的前三级目录(按照“/”区分每一级),如果原目录存在抽取好的,请直接使用,不再重复抽取。 f 利用`_handle_single_entity_extraction` 和 `_handle_single_relationship_extraction`:从记录中提取单个实体或关系信息。 g `_merge_nodes_then_upsert` 和 `_merge_edges_then_upsert`:合并并插入节点或边。 h `extract_entities`:处理多个文本块,提取实体和关系,并存储在知识图谱和向量数据库中。 diff --git a/request_llms/embed_models/openai_embed.py b/request_llms/embed_models/openai_embed.py index 9d565173..6f83bde1 100644 --- a/request_llms/embed_models/openai_embed.py +++ b/request_llms/embed_models/openai_embed.py @@ -1,4 +1,3 @@ -from llama_index.embeddings.openai import OpenAIEmbedding from openai import OpenAI from toolbox import get_conf from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate diff --git a/requirements.txt b/requirements.txt index 7a3d9f86..fb0b1550 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,12 +12,14 @@ transformers>=4.27.1,<4.42 scipdf_parser>=0.52 spacy==3.7.4 anthropic>=0.18.1 +sentence-transformers python-markdown-math pymdown-extensions websocket-client beautifulsoup4 prompt_toolkit latex2mathml +scikit-learn python-docx mdtex2html dashscope @@ -43,4 +45,4 @@ llama-index-embeddings-azure-openai==0.1.10 llama-index-embeddings-openai==0.1.10 llama-parse==0.4.9 mdit-py-plugins>=0.3.3 -linkify-it-py==2.0.3 \ No newline at end of file +linkify-it-py==2.0.3
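
For reference, below is a minimal, illustrative sketch of how the interfaces introduced in this patch (SmartArxivSplitter.process, the ArxivFragment dataclass, and RagHandler.process_paper_content / query) could be combined outside the Gradio plugin flow. It is not part of the patch: the wrapper function name summarize_arxiv_paper is hypothetical, and it assumes the usual gpt_academic configuration (OpenAI key, ARXIV_CACHE_DIR, lightrag installed) is already in place.

from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter
from crazy_functions.rag_essay_fns.rag_handler import RagHandler


def summarize_arxiv_paper(arxiv_id: str, question: str) -> str:
    # Split the paper into ~1000-1200 character fragments, mirroring the
    # settings used in the new Rag论文对话 entry point.
    splitter = SmartArxivSplitter(
        char_range=(1000, 1200),
        root_dir="gpt_log/arxiv_cache",
    )
    fragments = list(splitter.process(arxiv_id))

    # Build the dict expected by RagHandler.process_paper_content:
    # title, abstract, and the list of fragment contents.
    paper_content = {
        "title": fragments[0].title if fragments else "",
        "abstract": fragments[0].abstract if fragments else "",
        "segments": [f.content for f in fragments],
    }

    # Index the fragments into the LightRAG store, then run a hybrid query.
    handler = RagHandler()
    handler.process_paper_content(paper_content)
    return handler.query(question, mode="hybrid")


if __name__ == "__main__":
    print(summarize_arxiv_paper("2411.03663", "What is the main contribution of this paper?"))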