From 21626a44d55c006722c6cacb861b2d5f2f2a9ac1 Mon Sep 17 00:00:00 2001
From: lbykkkk
Date: Sat, 16 Nov 2024 00:35:31 +0800
Subject: [PATCH] up

---
 crazy_functions/Arxiv_论文对话.py              |   28 +-
 .../rag_essay_fns/arxiv_splitter.py           |  510 --------
 .../rag_essay_fns/paper_processing.py         |  311 -----
 crazy_functions/rag_essay_fns/rag_handler.py  |  164 ---
 crazy_functions/rag_fns/arxiv_fns/__init__.py |    0
 .../rag_fns/arxiv_fns/arxiv_downloader.py     |  111 ++
 .../rag_fns/arxiv_fns/arxiv_fragment.py       |   55 +
 .../rag_fns/arxiv_fns/arxiv_splitter.py       |  449 +++++++
 .../rag_fns/arxiv_fns/latex_patterns.py       |  395 ++++++
 .../rag_fns/arxiv_fns/tex_processor.py        | 1099 +++++++++++++++++
 .../rag_fns/light_rag/core/storage.py         |  272 ++--
 crazy_functions/rag_fns/light_rag/example.py  |  160 ++-
 12 files changed, 2385 insertions(+), 1169 deletions(-)
 delete mode 100644 crazy_functions/rag_essay_fns/arxiv_splitter.py
 delete mode 100644 crazy_functions/rag_essay_fns/paper_processing.py
 delete mode 100644 crazy_functions/rag_essay_fns/rag_handler.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/__init__.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/latex_patterns.py
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/tex_processor.py

diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index ed110325..86afe134 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -1,8 +1,8 @@
 import os.path
 
 from toolbox import CatchException, update_ui
-from crazy_functions.rag_essay_fns.paper_processing import ArxivPaperProcessor
-import asyncio
+from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor
+
 
 @CatchException
 def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
@@ -12,14 +12,14 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
     """
     if_project, if_arxiv = False, False
     if os.path.exists(txt):
-        from crazy_functions.rag_essay_fns.document_splitter import SmartDocumentSplitter
+        from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter
         splitter = SmartDocumentSplitter(
             char_range=(1000, 1200),
             max_workers=32  # 可选,默认会根据CPU核心数自动设置
         )
         if_project = True
     else:
-        from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter
+        from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter as SmartArxivSplitter
         splitter = SmartArxivSplitter(
             char_range=(1000, 1200),
             root_dir="gpt_log/arxiv_cache"
         )
@@ -61,23 +61,3 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
     yield from update_ui(chatbot=chatbot, history=history)
 
     # 交互式问答
-    chatbot.append(["知识图谱构建完成", "您可以开始提问了。支持以下类型的问题:\n1. 论文的具体内容\n2. 研究方法的细节\n3. 实验结果分析\n4. 
与其他工作的比较"]) - yield from update_ui(chatbot=chatbot, history=history) - - # 等待用户提问并回答 - while True: - question = yield from wait_user_input() - if not question: - break - - # 根据问题类型选择不同的查询模式 - if "比较" in question or "关系" in question: - mode = "global" # 使用全局模式处理比较类问题 - elif "具体" in question or "细节" in question: - mode = "local" # 使用局部模式处理细节问题 - else: - mode = "hybrid" # 默认使用混合模式 - - response = rag_handler.query(question, mode=mode) - chatbot.append([question, response]) - yield from update_ui(chatbot=chatbot, history=history) \ No newline at end of file diff --git a/crazy_functions/rag_essay_fns/arxiv_splitter.py b/crazy_functions/rag_essay_fns/arxiv_splitter.py deleted file mode 100644 index 0af46acc..00000000 --- a/crazy_functions/rag_essay_fns/arxiv_splitter.py +++ /dev/null @@ -1,510 +0,0 @@ -import os -import re -import requests -import tarfile -import logging -from dataclasses import dataclass -from typing import Generator, List, Tuple, Optional, Dict, Set -from pathlib import Path -from concurrent.futures import ThreadPoolExecutor, as_completed - - -@dataclass -class ArxivFragment: - """Arxiv论文片段数据类""" - file_path: str - content: str - segment_index: int - total_segments: int - rel_path: str - segment_type: str - title: str - abstract: str - section: str - is_appendix: bool - - -class SmartArxivSplitter: - def __init__(self, - char_range: Tuple[int, int], - root_dir: str = "gpt_log/arxiv_cache", - proxies: Optional[Dict[str, str]] = None, - max_workers: int = 4): - - self.min_chars, self.max_chars = char_range - self.root_dir = Path(root_dir) - self.root_dir.mkdir(parents=True, exist_ok=True) - self.proxies = proxies or {} - self.max_workers = max_workers - - # 定义特殊环境模式 - self._init_patterns() - - # 配置日志 - logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s') - - def _init_patterns(self): - """初始化LaTeX环境和命令模式""" - self.special_envs = { - 'math': [r'\\begin{(equation|align|gather|eqnarray)\*?}.*?\\end{\1\*?}', - r'\$\$.*?\$\$', r'\$[^$]+\$'], - 'table': [r'\\begin{(table|tabular)\*?}.*?\\end{\1\*?}'], - 'figure': [r'\\begin{figure\*?}.*?\\end{figure\*?}'], - 'algorithm': [r'\\begin{(algorithm|algorithmic)}.*?\\end{\1}'] - } - - self.section_patterns = [ - r'\\(sub)*section\{([^}]+)\}', - r'\\chapter\{([^}]+)\}' - ] - - self.include_patterns = [ - r'\\(input|include|subfile)\{([^}]+)\}' - ] - - def _find_main_tex_file(self, directory: str) -> Optional[str]: - """查找主TEX文件""" - tex_files = list(Path(directory).rglob("*.tex")) - if not tex_files: - return None - - # 按以下优先级查找: - # 1. 包含documentclass的文件 - # 2. 文件名为main.tex - # 3. 
最大的tex文件 - for tex_file in tex_files: - try: - content = self._read_file(str(tex_file)) - if content and r'\documentclass' in content: - return str(tex_file) - if tex_file.name.lower() == 'main.tex': - return str(tex_file) - except Exception: - continue - - return str(max(tex_files, key=lambda x: x.stat().st_size)) - - def _resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]: - """递归解析tex文件中的include/input命令""" - if processed is None: - processed = set() - - if tex_file in processed: - return [] - - processed.add(tex_file) - result = [tex_file] - content = self._read_file(tex_file) - - if not content: - return result - - base_dir = Path(tex_file).parent - for pattern in self.include_patterns: - for match in re.finditer(pattern, content): - included_file = match.group(2) - if not included_file.endswith('.tex'): - included_file += '.tex' - - # 构建完整路径 - full_path = str(base_dir / included_file) - if os.path.exists(full_path) and full_path not in processed: - result.extend(self._resolve_includes(full_path, processed)) - - return result - - def _smart_split(self, content: str) -> List[Tuple[str, str, bool]]: - """智能分割TEX内容,确保在字符范围内并保持语义完整性""" - content = self._preprocess_content(content) - segments = [] - current_buffer = [] - current_length = 0 - current_section = "Unknown Section" - is_appendix = False - - # 保护特殊环境 - protected_blocks = {} - content = self._protect_special_environments(content, protected_blocks) - - # 按段落分割 - paragraphs = re.split(r'\n\s*\n', content) - - for para in paragraphs: - para = para.strip() - if not para: - continue - - # 恢复特殊环境 - para = self._restore_special_environments(para, protected_blocks) - - # 更新章节信息 - section_info = self._get_section_info(para, content) - if section_info: - current_section, is_appendix = section_info - - # 判断是否是特殊环境 - if self._is_special_environment(para): - # 处理当前缓冲区 - if current_buffer: - segments.append(( - '\n'.join(current_buffer), - current_section, - is_appendix - )) - current_buffer = [] - current_length = 0 - - # 添加特殊环境作为独立片段 - segments.append((para, current_section, is_appendix)) - continue - - # 处理普通段落 - sentences = self._split_into_sentences(para) - for sentence in sentences: - sentence = sentence.strip() - if not sentence: - continue - - sent_length = len(sentence) - new_length = current_length + sent_length + (1 if current_buffer else 0) - - if new_length <= self.max_chars: - current_buffer.append(sentence) - current_length = new_length - else: - # 如果当前缓冲区达到最小长度要求 - if current_length >= self.min_chars: - segments.append(( - '\n'.join(current_buffer), - current_section, - is_appendix - )) - current_buffer = [sentence] - current_length = sent_length - else: - # 尝试将过长的句子分割 - split_sentences = self._split_long_sentence(sentence) - for split_sent in split_sentences: - if current_length + len(split_sent) <= self.max_chars: - current_buffer.append(split_sent) - current_length += len(split_sent) + 1 - else: - segments.append(( - '\n'.join(current_buffer), - current_section, - is_appendix - )) - current_buffer = [split_sent] - current_length = len(split_sent) - - # 处理剩余的缓冲区 - if current_buffer: - segments.append(( - '\n'.join(current_buffer), - current_section, - is_appendix - )) - - return segments - - def _split_into_sentences(self, text: str) -> List[str]: - """将文本分割成句子""" - return re.split(r'(?<=[.!?。!?])\s+', text) - - def _split_long_sentence(self, sentence: str) -> List[str]: - """智能分割过长的句子""" - if len(sentence) <= self.max_chars: - return [sentence] - - result = [] - while sentence: - # 在最大长度位置寻找合适的分割点 
- split_pos = self._find_split_position(sentence[:self.max_chars]) - if split_pos <= 0: - split_pos = self.max_chars - - result.append(sentence[:split_pos]) - sentence = sentence[split_pos:].strip() - - return result - - def _find_split_position(self, text: str) -> int: - """找到合适的句子分割位置""" - # 优先在标点符号处分割 - punctuation_match = re.search(r'[,,;;]\s*', text[::-1]) - if punctuation_match: - return len(text) - punctuation_match.end() - - # 其次在空白字符处分割 - space_match = re.search(r'\s+', text[::-1]) - if space_match: - return len(text) - space_match.end() - - return -1 - - def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: - """保护特殊环境内容""" - for env_type, patterns in self.special_envs.items(): - for pattern in patterns: - content = re.sub( - pattern, - lambda m: self._store_protected_block(m.group(0), protected_blocks), - content, - flags=re.DOTALL - ) - return content - - def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str: - """存储受保护的内容块""" - placeholder = f"PROTECTED_{len(protected_blocks)}" - protected_blocks[placeholder] = content - return placeholder - - def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: - """恢复特殊环境内容""" - for placeholder, original in protected_blocks.items(): - content = content.replace(placeholder, original) - return content - - def _is_special_environment(self, text: str) -> bool: - """判断是否是特殊环境""" - for patterns in self.special_envs.values(): - for pattern in patterns: - if re.search(pattern, text, re.DOTALL): - return True - return False - - def _preprocess_content(self, content: str) -> str: - """预处理TEX内容""" - # 移除注释 - content = re.sub(r'(?m)%.*$', '', content) - # 规范化空白字符 - content = re.sub(r'\s+', ' ', content) - content = re.sub(r'\n\s*\n', '\n\n', content) - # 移除不必要的命令 - content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content) - return content.strip() - - def process(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]: - """处理单篇arxiv论文""" - try: - arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url) - paper_dir = self._download_and_extract(arxiv_id) - - # 查找主tex文件 - main_tex = self._find_main_tex_file(paper_dir) - if not main_tex: - raise RuntimeError(f"No main tex file found in {paper_dir}") - - # 获取所有相关tex文件 - tex_files = self._resolve_includes(main_tex) - - # 处理所有tex文件 - fragments = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_file = { - executor.submit(self._process_single_tex, file_path): file_path - for file_path in tex_files - } - - for future in as_completed(future_to_file): - try: - fragments.extend(future.result()) - except Exception as e: - logging.error(f"Error processing file: {e}") - - # 重新计算片段索引 - fragments.sort(key=lambda x: (x.rel_path, x.segment_index)) - total_fragments = len(fragments) - - for i, fragment in enumerate(fragments): - fragment.segment_index = i - fragment.total_segments = total_fragments - yield fragment - - except Exception as e: - logging.error(f"Error processing paper {arxiv_id_or_url}: {e}") - raise RuntimeError(f"Failed to process paper: {str(e)}") - - def _normalize_arxiv_id(self, input_str: str) -> str: - """规范化arxiv ID""" - if input_str.startswith('https://arxiv.org/'): - if '/pdf/' in input_str: - return input_str.split('/pdf/')[-1].split('v')[0] - return input_str.split('/abs/')[-1].split('v')[0] - return input_str.split('v')[0] - - def _download_and_extract(self, arxiv_id: str) -> str: - """下载并解压arxiv论文源码""" - paper_dir = self.root_dir / 
arxiv_id - tar_path = paper_dir / f"{arxiv_id}.tar.gz" - - # 检查缓存 - if paper_dir.exists() and any(paper_dir.iterdir()): - logging.info(f"Using cached version for {arxiv_id}") - return str(paper_dir) - - paper_dir.mkdir(exist_ok=True) - - urls = [ - f"https://arxiv.org/src/{arxiv_id}", - f"https://arxiv.org/e-print/{arxiv_id}" - ] - - for url in urls: - try: - logging.info(f"Downloading from {url}") - response = requests.get(url, proxies=self.proxies) - if response.status_code == 200: - tar_path.write_bytes(response.content) - with tarfile.open(tar_path, 'r:gz') as tar: - tar.extractall(path=paper_dir) - return str(paper_dir) - except Exception as e: - logging.warning(f"Download failed for {url}: {e}") - continue - - raise RuntimeError(f"Failed to download paper {arxiv_id}") - - def _read_file(self, file_path: str) -> Optional[str]: - """使用多种编码尝试读取文件""" - encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] - for encoding in encodings: - try: - with open(file_path, 'r', encoding=encoding) as f: - return f.read() - except UnicodeDecodeError: - continue - logging.warning(f"Failed to read file {file_path} with all encodings") - return None - - def _extract_metadata(self, content: str) -> Tuple[str, str]: - """提取论文标题和摘要""" - title = "" - abstract = "" - - # 提取标题 - title_patterns = [ - r'\\title{([^}]*)}', - r'\\Title{([^}]*)}' - ] - for pattern in title_patterns: - match = re.search(pattern, content) - if match: - title = match.group(1) - title = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', title) - break - - # 提取摘要 - abstract_patterns = [ - r'\\begin{abstract}(.*?)\\end{abstract}', - r'\\abstract{([^}]*)}' - ] - for pattern in abstract_patterns: - match = re.search(pattern, content, re.DOTALL) - if match: - abstract = match.group(1).strip() - abstract = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', abstract) - break - - return title.strip(), abstract.strip() - - def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]: - """获取段落所属的章节信息""" - section = "Unknown Section" - is_appendix = False - - # 查找所有章节标记 - all_sections = [] - for pattern in self.section_patterns: - for match in re.finditer(pattern, content): - all_sections.append((match.start(), match.group(2))) - - # 查找appendix标记 - appendix_pos = content.find(r'\appendix') - - # 确定当前章节 - para_pos = content.find(para) - if para_pos >= 0: - current_section = None - for sec_pos, sec_title in sorted(all_sections): - if sec_pos > para_pos: - break - current_section = sec_title - - if current_section: - section = current_section - is_appendix = appendix_pos >= 0 and para_pos > appendix_pos - - return section, is_appendix - - return None - - def _process_single_tex(self, file_path: str) -> List[ArxivFragment]: - """处理单个TEX文件""" - try: - content = self._read_file(file_path) - if not content: - return [] - - # 提取元数据 - is_main = r'\documentclass' in content - title = "" - abstract = "" - if is_main: - title, abstract = self._extract_metadata(content) - - # 智能分割内容 - segments = self._smart_split(content) - fragments = [] - - for i, (segment_content, section, is_appendix) in enumerate(segments): - if segment_content.strip(): - segment_type = 'text' - for env_type, patterns in self.special_envs.items(): - if any(re.search(pattern, segment_content, re.DOTALL) - for pattern in patterns): - segment_type = env_type - break - - fragments.append(ArxivFragment( - file_path=file_path, - content=segment_content, - segment_index=i, - total_segments=len(segments), - rel_path=os.path.relpath(file_path, str(self.root_dir)), - 
segment_type=segment_type, - title=title, - abstract=abstract, - section=section, - is_appendix=is_appendix - )) - - return fragments - - except Exception as e: - logging.error(f"Error processing file {file_path}: {e}") - return [] - -def main(): - """使用示例""" - # 创建分割器实例 - splitter = SmartArxivSplitter( - char_range=(1000, 1200), - root_dir="gpt_log/arxiv_cache" - ) - - # 处理论文 - for fragment in splitter.process("2411.03663"): - print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}") - print(f"Length: {len(fragment.content)}") - print(f"Section: {fragment.section}") - print(f"Title: {fragment.file_path}") - - print(fragment.content) - print("-" * 80) - - -if __name__ == "__main__": - main() diff --git a/crazy_functions/rag_essay_fns/paper_processing.py b/crazy_functions/rag_essay_fns/paper_processing.py deleted file mode 100644 index 1e46407a..00000000 --- a/crazy_functions/rag_essay_fns/paper_processing.py +++ /dev/null @@ -1,311 +0,0 @@ -from typing import Tuple, Optional, Generator, List -from toolbox import update_ui, update_ui_lastest_msg, get_conf -import os, tarfile, requests, time, re -class ArxivPaperProcessor: - """Arxiv论文处理器类""" - - def __init__(self): - self.supported_encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] - self.arxiv_cache_dir = get_conf("ARXIV_CACHE_DIR") - - def download_and_extract(self, txt: str, chatbot, history) -> Generator[Optional[Tuple[str, str]], None, None]: - """ - Step 1: 下载和提取arxiv论文 - 返回: 生成器: (project_folder, arxiv_id) - """ - try: - if txt == "": - chatbot.append(("", "请输入arxiv论文链接或ID")) - yield from update_ui(chatbot=chatbot, history=history) - return - - project_folder, arxiv_id = self.arxiv_download(txt, chatbot, history) - if project_folder is None or arxiv_id is None: - return - - if not os.path.exists(project_folder): - chatbot.append((txt, f"找不到项目文件夹: {project_folder}")) - yield from update_ui(chatbot=chatbot, history=history) - return - - # 期望的返回值 - yield project_folder, arxiv_id - - except Exception as e: - print(e) - # yield from update_ui_lastest_msg( - # "下载失败,请手动下载latex源码:请前往arxiv打开此论文下载页面,点other Formats,然后download source。", - # chatbot=chatbot, history=history) - return - - def arxiv_download(self, txt: str, chatbot, history) -> Tuple[str, str]: - """ - 下载arxiv论文并提取 - 返回: (project_folder, arxiv_id) - """ - def is_float(s: str) -> bool: - try: - float(s) - return True - except ValueError: - return False - - if txt.startswith('https://arxiv.org/pdf/'): - arxiv_id = txt.split('/')[-1] # 2402.14207v2.pdf - txt = arxiv_id.split('v')[0] # 2402.14207 - - if ('.' in txt) and ('/' not in txt) and is_float(txt): # is arxiv ID - txt = 'https://arxiv.org/abs/' + txt.strip() - if ('.' 
in txt) and ('/' not in txt) and is_float(txt[:10]): # is arxiv ID - txt = 'https://arxiv.org/abs/' + txt[:10] - - if not txt.startswith('https://arxiv.org'): - chatbot.append((txt, "不是有效的arxiv链接或ID")) - # yield from update_ui(chatbot=chatbot, history=history) - return None, None # 返回两个值,即使其中一个为None - - chatbot.append([f"检测到arxiv文档连接", '尝试下载 ...']) - # yield from update_ui(chatbot=chatbot, history=history) - - url_ = txt # https://arxiv.org/abs/1707.06690 - - if not txt.startswith('https://arxiv.org/abs/'): - msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}。" - # yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面 - return None, None # 返回两个值,即使其中一个为None - - arxiv_id = url_.split('/')[-1].split('v')[0] - - dst = os.path.join(self.arxiv_cache_dir, arxiv_id, f'{arxiv_id}.tar.gz') - project_folder = os.path.join(self.arxiv_cache_dir, arxiv_id) - - success = self.download_arxiv_paper(url_, dst, chatbot, history) - - # if os.path.exists(dst) and get_conf('allow_cache'): - # # yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面 - # success = True - # else: - # # yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面 - # success = self.download_arxiv_paper(url_, dst, chatbot, history) - # # yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面 - - if not success: - # chatbot.append([f"下载失败 {arxiv_id}", ""]) - # yield from update_ui(chatbot=chatbot, history=history) - raise tarfile.ReadError(f"论文下载失败 {arxiv_id}") - - # yield from update_ui_lastest_msg(f"开始解压 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面 - extract_dst = self.extract_tar_file(dst, project_folder, chatbot, history) - # yield from update_ui_lastest_msg(f"解压完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面 - - return extract_dst, arxiv_id - - def download_arxiv_paper(self, url_: str, dst: str, chatbot, history) -> bool: - """下载arxiv论文""" - try: - proxies = get_conf('proxies') - for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]: - r = requests.get(url_tar, proxies=proxies) - if r.status_code == 200: - with open(dst, 'wb+') as f: - f.write(r.content) - return True - return False - except requests.RequestException as e: - # chatbot.append((f"下载失败 {url_}", str(e))) - # yield from update_ui(chatbot=chatbot, history=history) - return False - - def extract_tar_file(self, file_path: str, dest_dir: str, chatbot, history) -> str: - """解压arxiv论文""" - try: - with tarfile.open(file_path, 'r:gz') as tar: - tar.extractall(path=dest_dir) - return dest_dir - except tarfile.ReadError as e: - chatbot.append((f"解压失败 {file_path}", str(e))) - yield from update_ui(chatbot=chatbot, history=history) - raise e - - def find_main_tex_file(self, tex_files: list) -> str: - """查找主TEX文件""" - for tex_file in tex_files: - with open(tex_file, 'r', encoding='utf-8', errors='ignore') as f: - content = f.read() - if r'\documentclass' in content: - return tex_file - return max(tex_files, key=lambda x: os.path.getsize(x)) - - def read_file_with_encoding(self, file_path: str) -> Optional[str]: - """使用多种编码尝试读取文件""" - for encoding in self.supported_encodings: - try: - with open(file_path, 'r', encoding=encoding) as f: - return f.read() - except UnicodeDecodeError: - continue - return None - - def process_tex_content(self, content: str, base_path: str, processed_files=None) -> str: - """处理TEX内容,包括递归处理包含的文件""" - if processed_files is None: - processed_files = 
set() - - include_patterns = [ - r'\\input{([^}]+)}', - r'\\include{([^}]+)}', - r'\\subfile{([^}]+)}', - r'\\input\s+([^\s{]+)', - ] - - for pattern in include_patterns: - matches = re.finditer(pattern, content) - for match in matches: - include_file = match.group(1) - if not include_file.endswith('.tex'): - include_file += '.tex' - - include_path = os.path.join(base_path, include_file) - include_path = os.path.normpath(include_path) - - if include_path in processed_files: - continue - processed_files.add(include_path) - - if os.path.exists(include_path): - included_content = self.read_file_with_encoding(include_path) - if included_content: - included_content = self.process_tex_content( - included_content, - os.path.dirname(include_path), - processed_files - ) - content = content.replace(match.group(0), included_content) - - return content - - def merge_tex_files(self, folder_path: str, chatbot, history) -> Optional[str]: - """ - Step 2: 合并TEX文件 - 返回: 合并后的内容 - """ - try: - tex_files = [] - for root, _, files in os.walk(folder_path): - tex_files.extend([os.path.join(root, f) for f in files if f.endswith('.tex')]) - - if not tex_files: - chatbot.append(("", "未找到任何TEX文件")) - yield from update_ui(chatbot=chatbot, history=history) - return None - - main_tex_file = self.find_main_tex_file(tex_files) - chatbot.append(("", f"找到主TEX文件:{os.path.basename(main_tex_file)}")) - yield from update_ui(chatbot=chatbot, history=history) - - tex_content = self.read_file_with_encoding(main_tex_file) - if tex_content is None: - chatbot.append(("", "无法读取TEX文件,可能是编码问题")) - yield from update_ui(chatbot=chatbot, history=history) - return None - - full_content = self.process_tex_content( - tex_content, - os.path.dirname(main_tex_file) - ) - - cleaned_content = self.clean_tex_content(full_content) - - chatbot.append(("", - f"成功处理所有TEX文件:\n" - f"- 原始内容大小:{len(full_content)}字符\n" - f"- 清理后内容大小:{len(cleaned_content)}字符" - )) - yield from update_ui(chatbot=chatbot, history=history) - - # 添加标题和摘要提取 - title = "" - abstract = "" - if tex_content: - # 提取标题 - title_match = re.search(r'\\title{([^}]*)}', tex_content) - if title_match: - title = title_match.group(1) - - # 提取摘要 - abstract_match = re.search(r'\\begin{abstract}(.*?)\\end{abstract}', - tex_content, re.DOTALL) - if abstract_match: - abstract = abstract_match.group(1) - - # 按token限制分段 - def split_by_token_limit(text: str, token_limit: int = 1024) -> List[str]: - segments = [] - current_segment = [] - current_tokens = 0 - - for line in text.split('\n'): - line_tokens = len(line.split()) - if current_tokens + line_tokens > token_limit: - segments.append('\n'.join(current_segment)) - current_segment = [line] - current_tokens = line_tokens - else: - current_segment.append(line) - current_tokens += line_tokens - - if current_segment: - segments.append('\n'.join(current_segment)) - - return segments - - text_segments = split_by_token_limit(cleaned_content) - - return { - 'title': title, - 'abstract': abstract, - 'segments': text_segments - } - - except Exception as e: - chatbot.append(("", f"处理TEX文件时发生错误:{str(e)}")) - yield from update_ui(chatbot=chatbot, history=history) - return None - - @staticmethod - def clean_tex_content(content: str) -> str: - """清理TEX内容""" - content = re.sub(r'(?m)%.*$', '', content) # 移除注释 - content = re.sub(r'\\cite{[^}]*}', '', content) # 移除引用 - content = re.sub(r'\\label{[^}]*}', '', content) # 移除标签 - content = re.sub(r'\s+', ' ', content) # 规范化空白 - return content.strip() - -if __name__ == "__main__": - # 测试 arxiv_download 函数 - processor = 
ArxivPaperProcessor() - chatbot = [] - history = [] - - # 测试不同格式的输入 - test_inputs = [ - "https://arxiv.org/abs/2402.14207", # 标准格式 - "https://arxiv.org/pdf/2402.14207.pdf", # PDF链接格式 - "2402.14207", # 纯ID格式 - "2402.14207v1", # 带版本号的ID格式 - "https://invalid.url", # 无效URL测试 - ] - - for input_url in test_inputs: - print(f"\n测试输入: {input_url}") - try: - project_folder, arxiv_id = processor.arxiv_download(input_url, chatbot, history) - if project_folder and arxiv_id: - print(f"下载成功:") - print(f"- 项目文件夹: {project_folder}") - print(f"- Arxiv ID: {arxiv_id}") - print(f"- 文件夹是否存在: {os.path.exists(project_folder)}") - else: - print("下载失败: 返回值为 None") - except Exception as e: - print(f"发生错误: {str(e)}") diff --git a/crazy_functions/rag_essay_fns/rag_handler.py b/crazy_functions/rag_essay_fns/rag_handler.py deleted file mode 100644 index 0a8d4742..00000000 --- a/crazy_functions/rag_essay_fns/rag_handler.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import Dict, List, Optional -from lightrag import LightRAG, QueryParam -from lightrag.utils import EmbeddingFunc -import numpy as np -import os -from toolbox import get_conf -import openai - -class RagHandler: - def __init__(self): - # 初始化工作目录 - self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache') - if not os.path.exists(self.working_dir): - os.makedirs(self.working_dir) - - # 初始化 LightRAG - self.rag = LightRAG( - working_dir=self.working_dir, - llm_model_func=self._llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=1536, # OpenAI embedding 维度 - max_token_size=8192, - func=self._embedding_func, - ), - ) - - async def _llm_model_func(self, prompt: str, system_prompt: str = None, - history_messages: List = None, **kwargs) -> str: - """LLM 模型函数""" - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - if history_messages: - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - response = await openai.ChatCompletion.acreate( - model="gpt-3.5-turbo", - messages=messages, - temperature=kwargs.get("temperature", 0), - max_tokens=kwargs.get("max_tokens", 1000) - ) - return response.choices[0].message.content - - async def _embedding_func(self, texts: List[str]) -> np.ndarray: - """Embedding 函数""" - response = await openai.Embedding.acreate( - model="text-embedding-ada-002", - input=texts - ) - embeddings = [item["embedding"] for item in response["data"]] - return np.array(embeddings) - - def process_paper_content(self, paper_content: Dict) -> None: - """处理论文内容,构建知识图谱""" - # 处理标题和摘要 - content_list = [] - if paper_content['title']: - content_list.append(f"Title: {paper_content['title']}") - if paper_content['abstract']: - content_list.append(f"Abstract: {paper_content['abstract']}") - - # 添加分段内容 - content_list.extend(paper_content['segments']) - - # 插入到 RAG 系统 - self.rag.insert(content_list) - - def query(self, question: str, mode: str = "hybrid") -> str: - """查询论文内容 - mode: 查询模式,可选 naive/local/global/hybrid - """ - try: - response = self.rag.query( - question, - param=QueryParam( - mode=mode, - top_k=5, # 返回相关度最高的5个结果 - max_token_for_text_unit=2048, # 每个文本单元的最大token数 - response_type="detailed" # 返回详细回答 - ) - ) - return response - except Exception as e: - return f"查询出错: {str(e)}" - - -class RagHandler: - def __init__(self): - # 初始化工作目录 - self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache') - if not os.path.exists(self.working_dir): - os.makedirs(self.working_dir) - - # 初始化 LightRAG - self.rag = LightRAG( - 
working_dir=self.working_dir, - llm_model_func=self._llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=1536, # OpenAI embedding 维度 - max_token_size=8192, - func=self._embedding_func, - ), - ) - - async def _llm_model_func(self, prompt: str, system_prompt: str = None, - history_messages: List = None, **kwargs) -> str: - """LLM 模型函数""" - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - if history_messages: - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - response = await openai.ChatCompletion.acreate( - model="gpt-3.5-turbo", - messages=messages, - temperature=kwargs.get("temperature", 0), - max_tokens=kwargs.get("max_tokens", 1000) - ) - return response.choices[0].message.content - - async def _embedding_func(self, texts: List[str]) -> np.ndarray: - """Embedding 函数""" - response = await openai.Embedding.acreate( - model="text-embedding-ada-002", - input=texts - ) - embeddings = [item["embedding"] for item in response["data"]] - return np.array(embeddings) - - def process_paper_content(self, paper_content: Dict) -> None: - """处理论文内容,构建知识图谱""" - # 处理标题和摘要 - content_list = [] - if paper_content['title']: - content_list.append(f"Title: {paper_content['title']}") - if paper_content['abstract']: - content_list.append(f"Abstract: {paper_content['abstract']}") - - # 添加分段内容 - content_list.extend(paper_content['segments']) - - # 插入到 RAG 系统 - self.rag.insert(content_list) - - def query(self, question: str, mode: str = "hybrid") -> str: - """查询论文内容 - mode: 查询模式,可选 naive/local/global/hybrid - """ - try: - response = self.rag.query( - question, - param=QueryParam( - mode=mode, - top_k=5, # 返回相关度最高的5个结果 - max_token_for_text_unit=2048, # 每个文本单元的最大token数 - response_type="detailed" # 返回详细回答 - ) - ) - return response - except Exception as e: - return f"查询出错: {str(e)}" \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/__init__.py b/crazy_functions/rag_fns/arxiv_fns/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py new file mode 100644 index 00000000..6d0a8646 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py @@ -0,0 +1,111 @@ +import logging +import requests +import tarfile +from pathlib import Path +from typing import Optional, Dict + +class ArxivDownloader: + """用于下载arXiv论文源码的下载器""" + + def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None): + """ + 初始化下载器 + + Args: + root_dir: 保存下载文件的根目录 + proxies: 代理服务器设置,例如 {"http": "http://proxy:port", "https": "https://proxy:port"} + """ + self.root_dir = Path(root_dir) + self.root_dir.mkdir(exist_ok=True) + self.proxies = proxies + + # 配置日志 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + + def _download_and_extract(self, arxiv_id: str) -> str: + """ + 下载并解压arxiv论文源码 + + Args: + arxiv_id: arXiv论文ID,例如"2103.00020" + + Returns: + str: 解压后的文件目录路径 + + Raises: + RuntimeError: 当下载失败时抛出 + """ + paper_dir = self.root_dir / arxiv_id + tar_path = paper_dir / f"{arxiv_id}.tar.gz" + + # 检查缓存 + if paper_dir.exists() and any(paper_dir.iterdir()): + logging.info(f"Using cached version for {arxiv_id}") + return str(paper_dir) + + paper_dir.mkdir(exist_ok=True) + + urls = [ + f"https://arxiv.org/src/{arxiv_id}", + f"https://arxiv.org/e-print/{arxiv_id}" + ] + + for url in urls: + try: + logging.info(f"Downloading 
from {url}") + response = requests.get(url, proxies=self.proxies) + if response.status_code == 200: + tar_path.write_bytes(response.content) + with tarfile.open(tar_path, 'r:gz') as tar: + tar.extractall(path=paper_dir) + return str(paper_dir) + except Exception as e: + logging.warning(f"Download failed for {url}: {e}") + continue + + raise RuntimeError(f"Failed to download paper {arxiv_id}") + + def download_paper(self, arxiv_id: str) -> str: + """ + 下载指定的arXiv论文 + + Args: + arxiv_id: arXiv论文ID + + Returns: + str: 论文文件所在的目录路径 + """ + return self._download_and_extract(arxiv_id) + +def main(): + """测试下载功能""" + # 配置代理(如果需要) + proxies = { + "http": "http://your-proxy:port", + "https": "https://your-proxy:port" + } + + # 创建下载器实例(如果不需要代理,可以不传入proxies参数) + downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None) + + # 测试下载一篇论文(这里使用一个示例ID) + try: + paper_id = "2103.00020" # 这是一个示例ID + paper_dir = downloader.download_paper(paper_id) + print(f"Successfully downloaded paper to: {paper_dir}") + + # 检查下载的文件 + paper_path = Path(paper_dir) + if paper_path.exists(): + print("Downloaded files:") + for file in paper_path.rglob("*"): + if file.is_file(): + print(f"- {file.relative_to(paper_path)}") + except Exception as e: + print(f"Error downloading paper: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py new file mode 100644 index 00000000..544a8c1d --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +@dataclass +class ArxivFragment: + """Arxiv论文片段数据类""" + file_path: str # 文件路径 + content: str # 内容 + segment_index: int # 片段索引 + total_segments: int # 总片段数 + rel_path: str # 相对路径 + segment_type: str # 片段类型(text/math/table/figure等) + title: str # 论文标题 + abstract: str # 论文摘要 + section: str # 所属章节 + is_appendix: bool # 是否是附录 + importance: float = 1.0 # 重要性得分 + + @staticmethod + def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment': + """ + 合并两个片段的静态方法 + + Args: + seg1: 第一个片段 + seg2: 第二个片段 + + Returns: + ArxivFragment: 合并后的片段 + """ + # 合并内容 + merged_content = f"{seg1.content}\n{seg2.content}" + + # 确定合并后的类型 + def _merge_segment_type(type1: str, type2: str) -> str: + if type1 == type2: + return type1 + if type1 == 'text': + return type2 + if type2 == 'text': + return type1 + return 'mixed' + + return ArxivFragment( + file_path=seg1.file_path, + content=merged_content, + segment_index=seg1.segment_index, + total_segments=seg1.total_segments - 1, + rel_path=seg1.rel_path, + segment_type=_merge_segment_type(seg1.segment_type, seg2.segment_type), + title=seg1.title, + abstract=seg1.abstract, + section=seg1.section, + is_appendix=seg1.is_appendix, + importance=max(seg1.importance, seg2.importance) + ) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py new file mode 100644 index 00000000..901bb064 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py @@ -0,0 +1,449 @@ +import os +import re +import time +import aiohttp +import asyncio +import requests +import tarfile +import logging +from pathlib import Path +from typing import Generator, List, Tuple, Optional, Dict, Set +from concurrent.futures import ThreadPoolExecutor, as_completed +from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor +from 
crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment + + + +def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"): + """ + 将所有fragments保存为单个结构化markdown文件 + + Args: + fragments: fragment列表 + output_dir: 输出目录 + """ + from datetime import datetime + from pathlib import Path + import re + + # 创建输出目录 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # 生成文件名 + filename = f"fragments_{timestamp}.md" + file_path = output_path / filename + + current_section = "" + section_count = {} # 用于跟踪每个章节的片段数量 + + with open(file_path, "w", encoding="utf-8") as f: + # 写入文档头部 + f.write("# Document Fragments Analysis\n\n") + f.write("## Overview\n") + f.write(f"- Total Fragments: {len(fragments)}\n") + f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + + # 如果有标题和摘要,添加到开头 + if fragments and (fragments[0].title or fragments[0].abstract): + f.write("\n## Paper Information\n") + if fragments[0].title: + f.write(f"### Title\n{fragments[0].title}\n") + if fragments[0].abstract: + f.write(f"\n### Abstract\n{fragments[0].abstract}\n") + + # 生成目录 + f.write("\n## Table of Contents\n") + + # 首先收集所有章节信息 + sections = {} + for fragment in fragments: + section = fragment.section or "Uncategorized" + if section not in sections: + sections[section] = [] + sections[section].append(fragment) + + # 写入目录 + for section, section_fragments in sections.items(): + clean_section = section.strip() + if not clean_section: + clean_section = "Uncategorized" + f.write( + f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n") + + # 写入正文内容 + f.write("\n## Content\n") + + # 按章节组织内容 + for section, section_fragments in sections.items(): + clean_section = section.strip() or "Uncategorized" + f.write(f"\n### {clean_section}\n") + + # 写入每个fragment + for i, fragment in enumerate(section_fragments, 1): + f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n") + + # 元数据 + f.write("**Metadata:**\n") + f.write(f"- Type: {fragment.segment_type}\n") + f.write(f"- Length: {len(fragment.content)} chars\n") + f.write(f"- Importance: {fragment.importance:.2f}\n") + f.write(f"- Is Appendix: {fragment.is_appendix}\n") + f.write(f"- File: {fragment.rel_path}\n") + + # 内容 + f.write("\n**Content:**\n") + f.write("```tex\n") + f.write(fragment.content) + f.write("\n```\n") + + # 添加分隔线 + if i < len(section_fragments): + f.write("\n---\n") + + # 添加统计信息 + f.write("\n## Statistics\n") + f.write("\n### Fragment Type Distribution\n") + type_stats = {} + for fragment in fragments: + type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1 + + for ftype, count in type_stats.items(): + percentage = (count / len(fragments)) * 100 + f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n") + + # 长度分布 + f.write("\n### Length Distribution\n") + lengths = [len(f.content) for f in fragments] + f.write(f"- Minimum: {min(lengths)} chars\n") + f.write(f"- Maximum: {max(lengths)} chars\n") + f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n") + + print(f"Fragments saved to: {file_path}") + return file_path + + + +class ArxivSplitter: + """Arxiv论文智能分割器""" + + def __init__(self, + char_range: Tuple[int, int], + root_dir: str = "gpt_log/arxiv_cache", + proxies: Optional[Dict[str, str]] = None, + cache_ttl: int = 7 * 24 * 60 * 60): + """ + 初始化分割器 + + Args: + char_range: 字符数范围(最小值, 最大值) + root_dir: 缓存根目录 + proxies: 代理设置 + cache_ttl: 缓存过期时间(秒) + """ 
+ self.min_chars, self.max_chars = char_range + self.root_dir = Path(root_dir) + self.root_dir.mkdir(parents=True, exist_ok=True) + self.proxies = proxies or {} + self.cache_ttl = cache_ttl + + # 动态计算最优线程数 + import multiprocessing + cpu_count = multiprocessing.cpu_count() + # 根据CPU核心数动态设置,但设置上限防止过度并发 + self.max_workers = min(32, cpu_count * 2) + + # 初始化TeX处理器 + self.tex_processor = TexProcessor(char_range) + + # 配置日志 + self._setup_logging() + + + + def _setup_logging(self): + """配置日志""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + self.logger = logging.getLogger(__name__) + + def _normalize_arxiv_id(self, input_str: str) -> str: + """规范化ArXiv ID""" + if 'arxiv.org/' in input_str.lower(): + # 处理URL格式 + if '/pdf/' in input_str: + arxiv_id = input_str.split('/pdf/')[-1] + else: + arxiv_id = input_str.split('/abs/')[-1] + # 移除版本号和其他后缀 + return arxiv_id.split('v')[0].strip() + return input_str.split('v')[0].strip() + + + def _check_cache(self, paper_dir: Path) -> bool: + """ + 检查缓存是否有效,包括文件完整性检查 + + Args: + paper_dir: 论文目录路径 + + Returns: + bool: 如果缓存有效返回True,否则返回False + """ + if not paper_dir.exists(): + return False + + # 检查目录中是否存在必要文件 + has_tex_files = False + has_main_tex = False + + for file_path in paper_dir.rglob("*"): + if file_path.suffix == '.tex': + has_tex_files = True + content = self.tex_processor.read_file(str(file_path)) + if content and r'\documentclass' in content: + has_main_tex = True + break + + if not (has_tex_files and has_main_tex): + return False + + # 检查缓存时间 + cache_time = paper_dir.stat().st_mtime + if (time.time() - cache_time) < self.cache_ttl: + self.logger.info(f"Using valid cache for {paper_dir.name}") + return True + + return False + + async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool: + """ + 异步下载论文,包含重试机制和临时文件处理 + + Args: + arxiv_id: ArXiv论文ID + paper_dir: 目标目录路径 + + Returns: + bool: 下载成功返回True,否则返回False + """ + from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader + temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz" + final_tar_path = paper_dir / f"{arxiv_id}.tar.gz" + + # 确保目录存在 + paper_dir.mkdir(parents=True, exist_ok=True) + + # 尝试使用 ArxivDownloader 下载 + try: + downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies) + downloaded_dir = downloader.download_paper(arxiv_id) + if downloaded_dir: + self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}") + return True + except Exception as e: + self.logger.warning(f"ArxivDownloader failed: {str(e)}. 
Falling back to direct download.") + + # 如果 ArxivDownloader 失败,使用原有的下载方式作为备选 + urls = [ + f"https://arxiv.org/src/{arxiv_id}", + f"https://arxiv.org/e-print/{arxiv_id}" + ] + + max_retries = 3 + retry_delay = 1 # 初始重试延迟(秒) + + for url in urls: + for attempt in range(max_retries): + try: + self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})") + async with aiohttp.ClientSession() as session: + async with session.get(url, proxy=self.proxies.get('http')) as response: + if response.status == 200: + content = await response.read() + + # 写入临时文件 + temp_tar_path.write_bytes(content) + + try: + # 验证tar文件完整性并解压 + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir) + + # 下载成功后移动临时文件到最终位置 + temp_tar_path.rename(final_tar_path) + return True + + except Exception as e: + self.logger.warning(f"Invalid tar file: {str(e)}") + if temp_tar_path.exists(): + temp_tar_path.unlink() + + except Exception as e: + self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + continue + + return False + + def _process_tar_file(self, tar_path: Path, extract_path: Path): + """处理tar文件的同步操作""" + with tarfile.open(tar_path, 'r:gz') as tar: + tar.testall() # 验证文件完整性 + tar.extractall(path=extract_path) # 解压文件 + + def _process_single_tex(self, file_path: str) -> List[ArxivFragment]: + """处理单个TeX文件""" + try: + content = self.tex_processor.read_file(file_path) + if not content: + return [] + + # 提取元数据 + is_main = r'\documentclass' in content + title, abstract = "", "" + if is_main: + title, abstract = self.tex_processor.extract_metadata(content) + + # 分割内容 + segments = self.tex_processor.split_content(content) + fragments = [] + + for i, (segment_content, section, is_appendix) in enumerate(segments): + if not segment_content.strip(): + continue + + segment_type = self.tex_processor.detect_segment_type(segment_content) + importance = self.tex_processor.calculate_importance( + segment_content, segment_type, is_main + ) + fragments.append(ArxivFragment( + file_path=file_path, + content=segment_content, + segment_index=i, + total_segments=len(segments), + rel_path=str(Path(file_path).relative_to(self.root_dir)), + segment_type=segment_type, + title=title, + abstract=abstract, + section=section, + is_appendix=is_appendix, + importance=importance + )) + + return fragments + + except Exception as e: + self.logger.error(f"Error processing {file_path}: {str(e)}") + return [] + + async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]: + """处理ArXiv论文""" + try: + arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url) + paper_dir = self.root_dir / arxiv_id + + # 检查缓存 + if not self._check_cache(paper_dir): + paper_dir.mkdir(exist_ok=True) + if not await self.download_paper(arxiv_id, paper_dir): + raise RuntimeError(f"Failed to download paper {arxiv_id}") + + # 查找主TeX文件 + main_tex = self.tex_processor.find_main_tex_file(str(paper_dir)) + if not main_tex: + raise RuntimeError(f"No main TeX file found in {paper_dir}") + + # 获取所有相关TeX文件 + tex_files = self.tex_processor.resolve_includes(main_tex) + if not tex_files: + raise RuntimeError(f"No valid TeX files found for {arxiv_id}") + + # 并行处理所有TeX文件 + fragments = [] + chunk_size = max(1, len(tex_files) // self.max_workers) # 计算每个线程处理的文件数 + loop = asyncio.get_event_loop() + + async def process_chunk(chunk_files): + chunk_fragments = [] + for file_path in chunk_files: + try: + result = await 
loop.run_in_executor(None, self._process_single_tex, file_path) + chunk_fragments.extend(result) + except Exception as e: + self.logger.error(f"Error processing {file_path}: {str(e)}") + return chunk_fragments + + # 将文件分成多个块 + file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)] + # 异步处理每个块 + chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks]) + for result in chunk_results: + fragments.extend(result) + # 重新计算片段索引并排序 + fragments.sort(key=lambda x: (x.rel_path, x.segment_index)) + total_fragments = len(fragments) + + for i, fragment in enumerate(fragments): + fragment.segment_index = i + fragment.total_segments = total_fragments + # 在返回之前添加过滤 + fragments = self.tex_processor.filter_fragments(fragments) + return fragments + + except Exception as e: + self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}") + raise + + +async def test_arxiv_splitter(): + """测试ArXiv分割器的功能""" + + # 测试配置 + test_cases = [ + { + "arxiv_id": "2411.03663", + "expected_title": "Large Language Models and Simple Scripts", + "min_fragments": 10, + }, + # { + # "arxiv_id": "1805.10988", + # "expected_title": "RAG vs Fine-tuning", + # "min_fragments": 15, + # } + ] + + # 创建分割器实例 + splitter = ArxivSplitter( + char_range=(800, 1800), + root_dir="test_cache" + ) + + + for case in test_cases: + print(f"\nTesting paper: {case['arxiv_id']}") + try: + fragments = await splitter.process(case['arxiv_id']) + + # 保存fragments + output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log") + print(f"Output saved to: {output_dir}") + # 内容检查 + for fragment in fragments: + # 长度检查 + + print((fragment.content)) + print(len(fragment.content)) + # 类型检查 + + + except Exception as e: + print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}") + raise + + +if __name__ == "__main__": + asyncio.run(test_arxiv_splitter()) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/latex_patterns.py b/crazy_functions/rag_fns/arxiv_fns/latex_patterns.py new file mode 100644 index 00000000..0d5b93f9 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/latex_patterns.py @@ -0,0 +1,395 @@ +from dataclasses import dataclass, field + +@dataclass +class LaTeXPatterns: + """LaTeX模式存储类,用于集中管理所有LaTeX相关的正则表达式模式""" + special_envs = { + 'math': [ + # 基础数学环境 + r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}', + r'\$\$.*?\$\$', + r'\$[^$]+\$', + # 矩阵环境 + r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}', + # 数组环境 + r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}', + # 其他数学环境 + r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}' + ], + + 'table': [ + # 基础表格环境 + r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}', + # 复杂表格环境 + r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}', + # 自定义表格环境 + r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}', + # 表格注释环境 + r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}' + ], + + 'figure': [ + # 图片环境 + r'\\begin{figure\*?}.*?\\end{figure\*?}', + r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}', + # 图片插入命令 + r'\\includegraphics(\[.*?\])?\{.*?\}', + # tikz 图形环境 + r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}', + # 其他图形环境 + r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}' + ], + + 'algorithm': [ + # 算法环境 + r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}', + 
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}', + # 代码块环境 + r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}', + # 伪代码环境 + r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}' + ], + + 'list': [ + # 列表环境 + r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}', + r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}', + # 自定义列表环境 + r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}' + ], + + 'theorem': [ + # 定理类环境 + r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}', + r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}', + # 其他证明环境 + r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}' + ], + + 'box': [ + # 文本框环境 + r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}', + r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}', + # 强调环境 + r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}' + ], + + 'quote': [ + # 引用环境 + r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}', + r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}' + ], + + 'bibliography': [ + # 参考文献环境 + r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}', + r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}' + ], + + 'index': [ + # 索引环境 + r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}', + r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}' + ] + } + # 章节模式 + section_patterns = [ + # 基础章节命令 + r'\\chapter\{([^}]+)\}', + r'\\section\{([^}]+)\}', + r'\\subsection\{([^}]+)\}', + r'\\subsubsection\{([^}]+)\}', + r'\\paragraph\{([^}]+)\}', + r'\\subparagraph\{([^}]+)\}', + + # 带星号的变体(不编号) + r'\\chapter\*\{([^}]+)\}', + r'\\section\*\{([^}]+)\}', + r'\\subsection\*\{([^}]+)\}', + r'\\subsubsection\*\{([^}]+)\}', + r'\\paragraph\*\{([^}]+)\}', + r'\\subparagraph\*\{([^}]+)\}', + + # 特殊章节 + r'\\part\{([^}]+)\}', + r'\\part\*\{([^}]+)\}', + r'\\appendix\{([^}]+)\}', + + # 前言部分 + r'\\frontmatter\{([^}]+)\}', + r'\\mainmatter\{([^}]+)\}', + r'\\backmatter\{([^}]+)\}', + + # 目录相关 + r'\\tableofcontents', + r'\\listoffigures', + r'\\listoftables', + + # 自定义章节命令 + r'\\addchap\{([^}]+)\}', # KOMA-Script类 + r'\\addsec\{([^}]+)\}', # KOMA-Script类 + r'\\minisec\{([^}]+)\}', # KOMA-Script类 + + # 带可选参数的章节命令 + r'\\chapter\[([^]]+)\]\{([^}]+)\}', + r'\\section\[([^]]+)\]\{([^}]+)\}', + r'\\subsection\[([^]]+)\]\{([^}]+)\}' + ] + + # 包含模式 + include_patterns = [ + r'\\(input|include|subfile)\{([^}]+)\}' + ] + + metadata_patterns = { + # 标题相关 + 'title': [ + r'\\title\{([^}]+)\}', + r'\\Title\{([^}]+)\}', + r'\\doctitle\{([^}]+)\}', + r'\\subtitle\{([^}]+)\}', + r'\\chapter\*?\{([^}]+)\}', # 第一章可能作为标题 + r'\\maketitle\s*\\section\*?\{([^}]+)\}' # 第一节可能作为标题 + ], + + # 摘要相关 + 'abstract': [ + r'\\begin{abstract}(.*?)\\end{abstract}', + r'\\abstract\{([^}]+)\}', + r'\\begin{摘要}(.*?)\\end{摘要}', + r'\\begin{Summary}(.*?)\\end{Summary}', + r'\\begin{synopsis}(.*?)\\end{synopsis}', + r'\\begin{abstracten}(.*?)\\end{abstracten}' # 英文摘要 + ], + + # 作者信息 + 'author': [ + r'\\author\{([^}]+)\}', + r'\\Author\{([^}]+)\}', + r'\\authorinfo\{([^}]+)\}', + r'\\authors\{([^}]+)\}', + r'\\author\[([^]]+)\]\{([^}]+)\}', # 带附加信息的作者 + r'\\begin{authors}(.*?)\\end{authors}' + ], + + # 日期相关 + 'date': [ + r'\\date\{([^}]+)\}', + r'\\Date\{([^}]+)\}', + r'\\submitdate\{([^}]+)\}', + r'\\publishdate\{([^}]+)\}', + r'\\revisiondate\{([^}]+)\}' + ], + + # 关键词 + 'keywords': [ + r'\\keywords\{([^}]+)\}', + r'\\Keywords\{([^}]+)\}', + r'\\begin{keywords}(.*?)\\end{keywords}', + r'\\key\{([^}]+)\}', + r'\\begin{关键词}(.*?)\\end{关键词}' + ], + + # 机构/单位 + 'institution': [ + 
r'\\institute\{([^}]+)\}', + r'\\institution\{([^}]+)\}', + r'\\affiliation\{([^}]+)\}', + r'\\organization\{([^}]+)\}', + r'\\department\{([^}]+)\}' + ], + + # 学科/主题 + 'subject': [ + r'\\subject\{([^}]+)\}', + r'\\Subject\{([^}]+)\}', + r'\\field\{([^}]+)\}', + r'\\discipline\{([^}]+)\}' + ], + + # 版本信息 + 'version': [ + r'\\version\{([^}]+)\}', + r'\\revision\{([^}]+)\}', + r'\\release\{([^}]+)\}' + ], + + # 许可证/版权 + 'license': [ + r'\\license\{([^}]+)\}', + r'\\copyright\{([^}]+)\}', + r'\\begin{license}(.*?)\\end{license}' + ], + + # 联系方式 + 'contact': [ + r'\\email\{([^}]+)\}', + r'\\phone\{([^}]+)\}', + r'\\address\{([^}]+)\}', + r'\\contact\{([^}]+)\}' + ], + + # 致谢 + 'acknowledgments': [ + r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}', + r'\\acknowledgments\{([^}]+)\}', + r'\\thanks\{([^}]+)\}', + r'\\begin{致谢}(.*?)\\end{致谢}' + ], + + # 项目/基金 + 'funding': [ + r'\\funding\{([^}]+)\}', + r'\\grant\{([^}]+)\}', + r'\\project\{([^}]+)\}', + r'\\support\{([^}]+)\}' + ], + + # 分类号/编号 + 'classification': [ + r'\\classification\{([^}]+)\}', + r'\\serialnumber\{([^}]+)\}', + r'\\id\{([^}]+)\}', + r'\\doi\{([^}]+)\}' + ], + + # 语言 + 'language': [ + r'\\documentlanguage\{([^}]+)\}', + r'\\lang\{([^}]+)\}', + r'\\language\{([^}]+)\}' + ] +} + latex_only_patterns = { + # 文档类和包引入 + r'\\documentclass(\[.*?\])?\{.*?\}', + r'\\usepackage(\[.*?\])?\{.*?\}', + # 常见的文档设置命令 + r'\\setlength\{.*?\}\{.*?\}', + r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}', + r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}', + r'\\definecolor\{.*?\}\{.*?\}\{.*?\}', + # 页面设置相关 + r'\\pagestyle\{.*?\}', + r'\\thispagestyle\{.*?\}', + # 其他常见的设置命令 + r'\\bibliographystyle\{.*?\}', + r'\\bibliography\{.*?\}', + r'\\setcounter\{.*?\}\{.*?\}', + # 字体和文本设置命令 + r'\\makeFNbottom', + r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}', # 匹配字体大小设置 + r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}', + r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}', + r'\\renewcommand\\footnoterule\{.*?\}', + r'\\color\{.*?\}', + + # 页面和节标题设置 + r'\\setcounter\{secnumdepth\}\{.*?\}', + r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}', + r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*', + r'\\renewcommand\{?\\figurename\}?\{.*?\}', + + # 字体样式设置 + r'\\sectionfont\{.*?\}', + r'\\subsectionfont\{.*?\}', + r'\\subsubsectionfont\{.*?\}', + + # 间距和布局设置 + r'\\setstretch\{.*?\}', + r'\\setlength\{\\skip\\footins\}\{.*?\}', + r'\\setlength\{\\footnotesep\}\{.*?\}', + r'\\setlength\{\\jot\}\{.*?\}', + r'\\hrule\s+width\s+.*?\s+height\s+.*?', + + # makeatletter 和 makeatother + r'\\makeatletter\s*', + r'\\makeatother\s*', + r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}', # 带有上标的脚注 + # r'\\footnotetext\{[^}]*\}', # 普通脚注 + # r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}', # 带有邮箱的脚注 + # r'\\footnotetext\{.*?(?:ESI|DOI).*?\}', # 带有 DOI 或 ESI 引用的脚注 + # 文档结构命令 + r'\\begin\{document\}', + r'\\end\{document\}', + r'\\maketitle', + r'\\printbibliography', + r'\\newpage', + + # 输入文件命令 + r'\\input\{[^}]*\}', + r'\\input\{.*?\.tex\}', # 特别匹配 .tex 后缀的输入 + + # 脚注相关 + # r'\\footnotetext\[\d+\]\{[^}]*\}', # 带编号的脚注 + + # 致谢环境 + r'\\begin\{ack\}', + r'\\end\{ack\}', + r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}', # 匹配整个致谢环境及其内容 + + # 其他文档控制命令 + r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}', + } + math_envs = [ + # 基础数学环境 + (r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'), # 单行公式 + (r'\\begin{align\*?}.*?\\end{align\*?}', 'align'), # 多行对齐公式 + (r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'), # 多行居中公式 + (r'\$\$.*?\$\$', 
'display'), # 行间公式 + (r'\$.*?\$', 'inline'), # 行内公式 + + # 矩阵环境 + (r'\\begin{matrix}.*?\\end{matrix}', 'matrix'), # 基础矩阵 + (r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'), # 圆括号矩阵 + (r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'), # 方括号矩阵 + (r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'), # 竖线矩阵 + (r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'), # 双竖线矩阵 + (r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'), # 小号矩阵 + + # 数组环境 + (r'\\begin{array}.*?\\end{array}', 'array'), # 数组 + (r'\\begin{cases}.*?\\end{cases}', 'cases'), # 分段函数 + + # 多行公式环境 + (r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'), # 多行单个公式 + (r'\\begin{split}.*?\\end{split}', 'split'), # 拆分长公式 + (r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'), # 对齐环境带间距控制 + (r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'), # 完全左对齐 + + # 特殊数学环境 + (r'\\begin{subequations}.*?\\end{subequations}', 'subequations'), # 子公式编号 + (r'\\begin{gathered}.*?\\end{gathered}', 'gathered'), # 居中对齐组 + (r'\\begin{aligned}.*?\\end{aligned}', 'aligned'), # 内部对齐组 + + # 定理类环境 + (r'\\begin{theorem}.*?\\end{theorem}', 'theorem'), # 定理 + (r'\\begin{lemma}.*?\\end{lemma}', 'lemma'), # 引理 + (r'\\begin{proof}.*?\\end{proof}', 'proof'), # 证明 + + # 数学模式中的表格环境 + (r'\\begin{tabular}.*?\\end{tabular}', 'tabular'), # 表格 + (r'\\begin{array}.*?\\end{array}', 'array'), # 数组 + + # 其他专业数学环境 + (r'\\begin{CD}.*?\\end{CD}', 'CD'), # 交换图 + (r'\\begin{boxed}.*?\\end{boxed}', 'boxed'), # 带框公式 + (r'\\begin{empheq}.*?\\end{empheq}', 'empheq'), # 强调公式 + + # 化学方程式环境 (需要加载 mhchem 包) + (r'\\begin{reaction}.*?\\end{reaction}', 'reaction'), # 化学反应式 + (r'\\ce\{.*?\}', 'chemequation'), # 化学方程式 + + # 物理单位环境 (需要加载 siunitx 包) + (r'\\SI\{.*?\}\{.*?\}', 'SI'), # 物理单位 + (r'\\si\{.*?\}', 'si'), # 单位 + + # 补充环境 + (r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'), # breqn包的自动换行公式 + (r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'), # breqn包的显示数学模式 + (r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'), # breqn包的公式组 + ] + + # 示例使用函数 + + # 使用示例 diff --git a/crazy_functions/rag_fns/arxiv_fns/tex_processor.py b/crazy_functions/rag_fns/arxiv_fns/tex_processor.py new file mode 100644 index 00000000..b5fe9c97 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/tex_processor.py @@ -0,0 +1,1099 @@ +import re +import os +import logging +from pathlib import Path +from typing import List, Tuple, Dict, Set, Optional, Callable +from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment +from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns + +class TexProcessor: + """TeX文档处理器类""" + + def __init__(self, char_range: Tuple[int, int]): + """ + 初始化TeX处理器 + + Args: + char_range: 字符数范围(最小值, 最大值) + """ + self.min_chars, self.max_chars = char_range + self.logger = logging.getLogger(__name__) + + # 初始化LaTeX环境和命令模式 + self._init_patterns() + self.latex_only_patterns = LaTeXPatterns.latex_only_patterns + # 初始化合并规则列表,每个规则是(priority, rule_func)元组 + self.merge_rules = [] + # 注册默认规则 + self.register_merge_rule(self._merge_short_segments, priority=90) + self.register_merge_rule(self._merge_clauses, priority=100) + + def is_latex_commands_only(self, content: str) -> bool: + """ + 检查内容是否仅包含LaTeX命令 + + Args: + content: 要检查的内容 + + Returns: + bool: 如果内容仅包含LaTeX命令返回True,否则返回False + """ + # 预处理:移除空白字符 + content = content.strip() + if not content: + return True + + # 移除注释 + content = re.sub(r'(?m)%.*$', '', content) + content = content.strip() + + # 移除所有已知的LaTeX命令模式 + for pattern in self.latex_only_patterns: + content = re.sub(pattern, '', content) + + # 
移除常见的LaTeX控制序列 + content = re.sub(r'\\[a-zA-Z]+(\[.*?\])?(\{.*?\})?', '', content) + + # 移除剩余的空白字符 + content = re.sub(r'\s+', '', content) + + # 检查是否还有实质性内容 + # 如果长度为0或者只包含花括号、方括号等LaTeX标记,则认为是纯LaTeX命令 + remaining_chars = re.sub(r'[\{\}\[\]\(\)\,\\\s]', '', content) + return len(remaining_chars) == 0 + + def has_meaningful_content(self, content: str, min_text_ratio: float = 0.1) -> bool: + """ + 检查内容是否包含足够的有意义文本 + + Args: + content: 要检查的内容 + min_text_ratio: 最小文本比例(默认0.1,表示至少10%是文本) + + Returns: + bool: 如果内容包含足够的有意义文本返回True,否则返回False + """ + # 移除注释和空白字符 + content = re.sub(r'(?m)%.*$', '', content) + content = content.strip() + + # 计算总长度 + total_length = len(content) + if total_length == 0: + return False + + # 移除所有LaTeX命令和环境 + for pattern in self.latex_only_patterns: + content = re.sub(pattern, '', content) + content = re.sub(r'\\[a-zA-Z]+(\[.*?\])?(\{.*?\})?', '', content) + + # 计算剩余文本长度(移除剩余的LaTeX标记) + remaining_text = re.sub(r'[\{\}\[\]\(\)\,\\\s]', '', content) + text_ratio = len(remaining_text) / total_length + + return text_ratio >= min_text_ratio + + def filter_fragments(self, fragments, + min_text_ratio: float = 0.1): + """ + 过滤fragment列表,移除仅包含LaTeX命令的片段,并合并相邻的片段 + + Args: + fragments: ArxivFragment列表 + min_text_ratio: 最小文本比例 + + Returns: + List[ArxivFragment]: 过滤后的fragment列表 + """ + filtered_fragments = [] + total_count = len(fragments) + filtered_count = 0 + + for fragment in fragments: + if self.has_meaningful_content(fragment.content, min_text_ratio): + filtered_fragments.append(fragment) + else: + filtered_count += 1 + self.logger.debug(f"Filtered out latex-only fragment: {fragment.content[:100]}...") + + # 记录过滤统计 + if filtered_count > 0: + self.logger.info(f"Filtered out {filtered_count}/{total_count} latex-only fragments") + + # 重新计算索引 + for i, fragment in enumerate(filtered_fragments): + fragment.segment_index = i + fragment.total_segments = len(filtered_fragments) + + + filtered_fragments = self.merge_segments(filtered_fragments) + + # 重新计算索引 + for i, fragment in enumerate(filtered_fragments): + fragment.segment_index = i + fragment.total_segments = len(filtered_fragments) + + return filtered_fragments + def _is_special_environment(self, content: str) -> bool: + """ + 检查内容是否属于特殊环境 + + Args: + content: 要检查的内容 + + Returns: + bool: 如果内容属于特殊环境返回True,否则返回False + """ + for env_patterns in self.special_envs.values(): + for pattern in env_patterns: + if re.search(pattern, content, re.DOTALL): + return True + return False + + def _init_patterns(self): + """初始化LaTeX模式匹配规则""" + # 特殊环境模式 + self.special_envs = LaTeXPatterns.special_envs + # 章节模式 + self.section_patterns = LaTeXPatterns.section_patterns + # 包含模式 + self.include_patterns = LaTeXPatterns.include_patterns + # 元数据模式 + self.metadata_patterns = LaTeXPatterns.metadata_patterns + + def read_file(self, file_path: str) -> Optional[str]: + """ + 读取TeX文件内容,支持多种编码 + + Args: + file_path: 文件路径 + + Returns: + Optional[str]: 文件内容或None + """ + encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] + for encoding in encodings: + try: + with open(file_path, 'r', encoding=encoding) as f: + return f.read() + except UnicodeDecodeError: + continue + + self.logger.warning(f"Failed to read {file_path} with all encodings") + return None + + def find_main_tex_file(self, directory: str) -> Optional[str]: + """ + 查找主TeX文件 + + Args: + directory: 目录路径 + + Returns: + Optional[str]: 主文件路径或None + """ + tex_files = list(Path(directory).rglob("*.tex")) + if not tex_files: + return None + + # 按优先级查找 + for tex_file in tex_files: + content = 
self.read_file(str(tex_file)) + if content: + if r'\documentclass' in content: + return str(tex_file) + if tex_file.name.lower() == 'main.tex': + return str(tex_file) + + # 返回最大的tex文件 + return str(max(tex_files, key=lambda x: x.stat().st_size)) + + def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]: + """ + 解析TeX文件中的include引用 + + Args: + tex_file: TeX文件路径 + processed: 已处理的文件集合 + + Returns: + List[str]: 相关文件路径列表 + """ + if processed is None: + processed = set() + + if tex_file in processed: + return [] + + processed.add(tex_file) + result = [tex_file] + content = self.read_file(tex_file) + + if not content: + return result + + base_dir = Path(tex_file).parent + for pattern in self.include_patterns: + for match in re.finditer(pattern, content): + included_file = match.group(2) + if not included_file.endswith('.tex'): + included_file += '.tex' + + full_path = str(base_dir / included_file) + if os.path.exists(full_path) and full_path not in processed: + result.extend(self.resolve_includes(full_path, processed)) + + return result + + + def _preprocess_content(self, content: str) -> str: + """预处理TeX内容""" + # 移除注释 + content = re.sub(r'(?m)%.*$', '', content) + # 规范化空白字符 + # content = re.sub(r'\s+', ' ', content) + content = re.sub(r'\n\s*\n', '\n\n', content) + return content.strip() + + def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: + """保护特殊环境内容""" + for env_patterns in self.special_envs.values(): + for pattern in env_patterns: + content = re.sub( + pattern, + lambda m: self._store_protected_block(m.group(0), protected_blocks), + content, + flags=re.DOTALL + ) + return content + + def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str: + """存储保护块""" + placeholder = f"__PROTECTED_{len(protected_blocks)}__" + protected_blocks[placeholder] = content + return placeholder + + def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: + """恢复特殊环境内容""" + for placeholder, original in protected_blocks.items(): + content = content.replace(placeholder, original) + return content + + def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]: + """获取章节信息""" + # 检查是否是附录 + is_appendix = bool(re.search(r'\\appendix', content)) + + # 提取章节标题 + for pattern in self.section_patterns: + match = re.search(pattern, para) + if match: + section_title = match.group(1) + # 清理LaTeX命令 + section_title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', section_title) + return section_title, is_appendix + + return None + + def _split_long_paragraph(self, paragraph: str) -> List[str]: + """分割长段落""" + parts = [] + current_part = [] + current_length = 0 + + sentences = re.split(r'(?<=[.!?。!?])\s+', paragraph) + for sentence in sentences: + sent_length = len(sentence) + + if current_length + sent_length <= self.max_chars: + current_part.append(sentence) + current_length += sent_length + else: + if current_part: + parts.append(' '.join(current_part)) + current_part = [sentence] + current_length = sent_length + + if current_part: + parts.append(' '.join(current_part)) + + return parts + + def extract_metadata(self, content: str) -> Tuple[str, str]: + """ + 提取文档元数据 + + Args: + content: TeX内容 + + Returns: + Tuple[str, str]: (标题, 摘要) + """ + title = "" + abstract = "" + + # 提取标题 + for pattern in self.metadata_patterns['title']: + match = re.search(pattern, content) + if match: + title = match.group(1) + # 清理LaTeX命令 + title = 
re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', title) + break + + # 提取摘要 + for pattern in self.metadata_patterns['abstract']: + match = re.search(pattern, content, re.DOTALL) + if match: + abstract = match.group(1) + # 清理LaTeX命令 + abstract = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', abstract) + break + + return title.strip(), abstract.strip() + + def detect_segment_type(self, content: str) -> str: + """ + 检测片段类型 + + Args: + content: 内容片段 + + Returns: + str: 片段类型 + """ + for env_type, patterns in self.special_envs.items(): + for pattern in patterns: + if re.search(pattern, content, re.DOTALL): + return env_type + return 'text' + + def calculate_importance(self, content: str, segment_type: str, is_main: bool) -> float: + """ + 计算内容重要性得分 + + Args: + content: 内容片段 + segment_type: 片段类型 + is_main: 是否在主文件中 + + Returns: + float: 重要性得分 (0-1) + """ + score = 0.5 # 基础分 + + # 根据片段类型调整得分 + type_weights = { + 'text': 0.5, + 'math': 0.7, + 'table': 0.8, + 'figure': 0.6, + 'algorithm': 0.8 + } + score += type_weights.get(segment_type, 0) + + # 根据位置调整得分 + if is_main: + score += 0.2 + + # 根据内容特征调整得分 + if re.search(r'\\label{', content): + score += 0.1 + if re.search(r'\\cite{', content): + score += 0.1 + if re.search(r'\\ref{', content): + score += 0.1 + + # 规范化得分到0-1范围 + return min(1.0, max(0.0, score)) + + + + def split_content(self, content: str) -> List[Tuple[str, str, bool]]: + """ + 按段落分割TeX内容,对超长段落按换行符分割 + + Args: + content: TeX文档内容 + + Returns: + List[Tuple[str, str, bool]]: [(段落内容, 章节名, 是否附录)] + """ + content = self._preprocess_content(content) + segments = [] + current_section = "未命名章节" + is_appendix = False + + # 保护特殊环境 + protected_blocks = {} + # content = self._protect_special_environments(content, protected_blocks) + + # 按段落分割 + paragraphs = re.split(r'\n\s*\n', content) + + for para in paragraphs: + para = para.strip() + if not para: + continue + + # 恢复特殊环境 + para = self._restore_special_environments(para, protected_blocks) + + # 检查章节变化 + section_info = self._get_section_info(para, content) + if section_info: + current_section, is_appendix = section_info + continue + + # 处理特殊环境 + if self._is_special_environment(para): + # 特殊环境超长时分割 + if len(para) > self.max_chars: + split_parts = self._split_special_environment(para) + segments.extend((part, current_section, is_appendix) for part in split_parts) + else: + segments.append((para, current_section, is_appendix)) + continue + + # 处理普通段落 + if len(para) > self.max_chars: + # 按换行符分割超长段落 + split_parts = [p.strip() for p in para.split('\n') if p.strip()] + segments.extend((part, current_section, is_appendix) for part in split_parts) + else: + segments.append((para, current_section, is_appendix)) + + return segments + + def _is_complete_env(self, content: str) -> bool: + """ + 检查是否是完整的LaTeX环境 + + Args: + content: 要检查的内容 + + Returns: + bool: 是否是完整环境 + """ + try: + # 检查基本数学环境配对 + env_pairs = [ + (r'\\begin{(equation\*?)}', r'\\end{equation\*?}'), + (r'\\begin{(align\*?)}', r'\\end{align\*?}'), + (r'\\begin{(gather\*?)}', r'\\end{gather\*?}'), + (r'\\begin{(multline\*?)}', r'\\end{multline\*?}'), + (r'\$\$', r'\$\$'), # 行间数学 + (r'\$', r'\$'), # 行内数学 + (r'\\[', r'\\]'), # 显示数学 + (r'\\(', r'\\)'), # 行内数学 + (r'\\begin{', r'\\end{') # 通用环境 + ] + + # 检查所有环境配对 + for begin_pattern, end_pattern in env_pairs: + if isinstance(begin_pattern, tuple): + begin_pattern, end_pattern = begin_pattern + begin_count = len(re.findall(begin_pattern, content)) + end_count = len(re.findall(end_pattern, content)) + if begin_count != end_count: + return False + + # 检查括号配对 
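+ # Count-based check only: every bracket character (including escaped \{ \}) is tallied, and each opening count must equal its closing count; nesting order is not verified.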
+ brackets = {'{': '}', '[': ']', '(': ')'} + bracket_count = {k: 0 for k in brackets.keys() | brackets.values()} + + for char in content: + if char in bracket_count: + bracket_count[char] += 1 + + for open_bracket, close_bracket in brackets.items(): + if bracket_count[open_bracket] != bracket_count[close_bracket]: + return False + + return True + + except Exception as e: + self.logger.warning(f"Error checking environment completeness: {str(e)}") + return False + def _split_special_environment(self, content: str) -> List[str]: + """ + 分割特殊环境内容,确保环境的完整性 + + Args: + content: 特殊环境内容 + + Returns: + List[str]: 分割后的内容列表 + """ + env_type = self.detect_segment_type(content) + + # 如果内容已经在允许的长度范围内,且是完整的环境,直接返回 + try: + if len(content) <= self.max_chars: + if self._is_complete_env(content): + return [content] + except Exception as e: + self.logger.warning(f"Error checking environment in split_special_environment: {str(e)}") + + # 根据不同环境类型选择不同的分割策略 + if env_type == 'math': + return self._split_math_content(content) + elif env_type == 'table': + return self._split_table_content(content) + else: + # 对于其他类型的环境 + parts = [] + current_part = "" + + # 按行分割并尝试保持环境完整性 + lines = content.split('\n') + for line in lines: + line_with_newline = line + '\n' + + # 检查是否添加当前行会超出长度限制 + if len(current_part) + len(line_with_newline) <= self.max_chars: + current_part += line_with_newline + else: + # 如果当前部分不为空,进行处理 + if current_part: + try: + # 尝试找到一个完整的环境结束点 + if self._is_complete_env(current_part): + parts.append(current_part) + current_part = line_with_newline + else: + # 如果当前部分不是完整环境,继续添加 + if len(current_part) + len(line_with_newline) <= self.max_chars * 1.5: # 允许一定程度的超出 + current_part += line_with_newline + else: + # 如果实在太长,强制分割 + parts.append(current_part) + current_part = line_with_newline + except Exception as e: + self.logger.warning(f"Error processing environment part: {str(e)}") + parts.append(current_part) + current_part = line_with_newline + else: + # 如果当前行本身就超过长度限制 + parts.append(line_with_newline) + + # 处理最后剩余的部分 + if current_part: + parts.append(current_part) + + # 清理并返回非空片段 + return [p.strip() for p in parts if p.strip()] + def _split_math_content(self, content: str) -> List[str]: + """ + 分割数学公式内容,确保公式环境的完整性 + + Args: + content: 数学公式内容 + + Returns: + List[str]: 分割后的公式列表 + """ + # 首先识别完整的数学环境 + math_envs = LaTeXPatterns.math_envs + + # 提取所有完整的数学环境 + parts = [] + last_end = 0 + math_blocks = [] + + for pattern, env_type in math_envs: + for match in re.finditer(pattern, content, re.DOTALL): + math_blocks.append((match.start(), match.end(), match.group(0))) + + # 按照位置排序 + math_blocks.sort(key=lambda x: x[0]) + + # 保持数学环境的完整性 + if not math_blocks: + # 如果没有识别到完整的数学环境,作为单个块处理 + return [content] if len(content) <= self.max_chars else self._basic_content_split(content) + + current_part = "" + for start, end, block in math_blocks: + # 添加数学环境之前的文本 + if start > last_end: + text_before = content[last_end:start] + if text_before.strip(): + current_part += text_before + + # 处理数学环境 + if len(block) > self.max_chars: + # 如果当前部分已经有内容,先保存 + if current_part: + parts.append(current_part) + current_part = "" + # 将过长的数学环境作为独立部分 + parts.append(block) + else: + # 如果添加当前数学环境会导致超出长度限制 + if current_part and len(current_part) + len(block) > self.max_chars: + parts.append(current_part) + current_part = block + else: + current_part += block + + last_end = end + + # 处理最后的文本部分 + if last_end < len(content): + remaining = content[last_end:] + if remaining.strip(): + if current_part and len(current_part) + len(remaining) > self.max_chars: + 
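+ # The trailing text would push the current part over max_chars, so emit the accumulated part and start a new one from the remainder.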
parts.append(current_part) + current_part = remaining + else: + current_part += remaining + + if current_part: + parts.append(current_part) + + return parts + + + def _split_table_content(self, content: str) -> List[str]: + """ + 分割表格内容 + + Args: + content: 表格内容 + + Returns: + List[str]: 分割后的表格部分列表 + """ + # 在表格行之间分割 + rows = re.split(r'(\\\\|\\hline)', content) + result = [] + current_part = "" + header = self._extract_table_header(content) + + for row in rows: + if len(current_part + row) <= self.max_chars: + current_part += row + else: + if current_part: + # 确保每个部分都是完整的表格结构 + result.append(self._wrap_table_content(current_part, header)) + current_part = header + row if header else row + + if current_part: + result.append(self._wrap_table_content(current_part, header)) + + return result + + def _extract_table_header(self, content: str) -> str: + """ + 提取表格头部 + + Args: + content: 表格内容 + + Returns: + str: 表格头部 + """ + # 提取表格环境声明和列格式 + header_match = re.match(r'(\\begin{(?:table|tabular|longtable)\*?}.*?\\hline)', content, re.DOTALL) + return header_match.group(1) if header_match else "" + + def _wrap_table_content(self, content: str, header: str) -> str: + """ + 包装表格内容为完整结构 + + Args: + content: 表格内容 + header: 表格头部 + + Returns: + str: 完整的表格结构 + """ + # 确保表格有正确的开始和结束标签 + env_match = re.search(r'\\begin{(table|tabular|longtable)\*?}', header or content) + if env_match: + env_type = env_match.group(1) + if not content.startswith('\\begin'): + content = f"{header}\n{content}" if header else content + if not content.endswith(f'\\end{{{env_type}}}'): + content = f"{content}\n\\end{{{env_type}}}" + return content + + def _basic_content_split(self, content: str) -> List[str]: + """ + 基本的内容分割策略 + + Args: + content: 要分割的内容 + + Returns: + List[str]: 分割后的内容列表 + """ + parts = [] + while content: + if len(content) <= self.max_chars: + parts.append(content) + break + + # 尝试在最后一个完整行处分割 + split_pos = content[:self.max_chars].rfind('\n') + if split_pos == -1: # 如果找不到换行符,则在最后一个空格处分割 + split_pos = content[:self.max_chars].rfind(' ') + if split_pos == -1: # 如果仍然找不到分割点,则强制分割 + split_pos = self.max_chars + + parts.append(content[:split_pos]) + content = content[split_pos:].strip() + + return parts + + def _ensure_segment_lengths(self, segments: List[Tuple[str, str, bool]]) -> List[Tuple[str, str, bool]]: + """ + 确保所有片段都在指定的长度范围内 + + Args: + segments: 原始片段列表 + + Returns: + List[Tuple[str, str, bool]]: 处理后的片段列表 + """ + result = [] + for content, section, is_appendix in segments: + if len(content) <= self.max_chars: + result.append((content, section, is_appendix)) + else: + # 根据内容类型选择合适的分割方法 + if self._is_special_environment(content): + split_parts = self._split_special_environment(content) + else: + split_parts = self._split_long_paragraph(content) + + result.extend((part, section, is_appendix) for part in split_parts) + + return result + + def register_merge_rule(self, rule_func: Callable[[List['ArxivFragment']], List['ArxivFragment']], + priority: int = 0) -> None: + """ + 注册新的合并规则 + + Args: + rule_func: 合并规则函数,接收fragment列表返回处理后的列表 + priority: 规则优先级,数字越大优先级越高 + """ + self.merge_rules.append((priority, rule_func)) + # 按优先级排序,保证高优先级规则先执行 + self.merge_rules.sort(reverse=True, key=lambda x: x[0]) + + + def _merge_segments(self, seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment': + """ + 合并两个片段的通用方法 + + Args: + seg1: 第一个片段 + seg2: 第二个片段 + + Returns: + ArxivFragment: 合并后的片段 + """ + return ArxivFragment( + file_path=seg1.file_path, + content=f"{seg1.content}\n{seg2.content}", + 
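+ # The merged fragment inherits seg1's position and metadata; total_segments drops by one and importance keeps the higher of the two scores.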
segment_index=seg1.segment_index, + total_segments=seg1.total_segments - 1, + rel_path=seg1.rel_path, + segment_type=self._merge_segment_type(seg1.segment_type, seg2.segment_type), + title=seg1.title, + abstract=seg1.abstract, + section=seg1.section, + is_appendix=seg1.is_appendix, + importance=max(seg1.importance, seg2.importance) + ) + + def _merge_segment_type(self, type1: str, type2: str) -> str: + """ + 确定合并后片段的类型 + + Args: + type1: 第一个片段的类型 + type2: 第二个片段的类型 + + Returns: + str: 合并后的类型 + """ + # 如果类型相同,保持不变 + if type1 == type2: + return type1 + # 如果其中之一是文本,返回非文本的类型 + if type1 == 'text': + return type2 + if type2 == 'text': + return type1 + # 如果是不同的特殊类型,返回 mixed + return 'mixed' + + def _merge_short_segments(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: + """ + 合并短片段规则 + + Args: + fragments: 片段列表 + + Returns: + List[ArxivFragment]: 处理后的片段列表 + """ + if not fragments: + return fragments + + # 持续合并直到没有可以合并的片段 + need_merge = True + current_fragments = fragments + max_iterations = len(fragments) * 2 # 设置最大迭代次数防止意外情况 + iteration_count = 0 + + while need_merge and iteration_count < max_iterations: + need_merge = False + iteration_count += 1 + result = [] + i = 0 + + while i < len(current_fragments): + current = current_fragments[i] + current_len = len(current.content) + + # 如果当前片段长度足够或是最后一个片段 + if current_len >= self.min_chars or i == len(current_fragments) - 1: + result.append(current) + i += 1 + continue + + # 查找最适合合并的相邻片段 + best_target_idx = -1 + min_combined_length = float('inf') + + # 检查前后片段,选择合并后总长度最小的 + for idx in [i - 1, i + 1]: + if 0 <= idx < len(current_fragments): + target = current_fragments[idx] + target_len = len(target.content) + combined_len = current_len + target_len + + # 更新最佳合并目标 + if combined_len < min_combined_length and ( + target_len < self.min_chars or # 目标也是短片段 + current_len < target_len # 或当前片段更短 + ): + min_combined_length = combined_len + best_target_idx = idx + + # 执行合并 + if best_target_idx != -1: + if best_target_idx < i: # 与前一个片段合并 + result.pop() # 移除之前添加的片段 + merged = self._merge_segments(current_fragments[best_target_idx], current) + result.append(merged) + else: # 与后一个片段合并 + merged = self._merge_segments(current, current_fragments[best_target_idx]) + result.append(merged) + i += 1 # 跳过下一个片段 + need_merge = True # 标记发生了合并,需要继续检查 + i += 1 + else: + # 如果没找到合适的合并目标,保留当前片段 + result.append(current) + i += 1 + + # 更新当前片段列表 + current_fragments = result + + # 检查是否还需要继续合并 + if not need_merge: + # 最后检查一遍是否还有短片段 + has_short = any(len(f.content) < self.min_chars for f in result) + need_merge = has_short and len(result) > 1 + + # 如果达到最大迭代次数,记录警告 + if iteration_count >= max_iterations: + self.logger.warning(f"Reached maximum iterations ({max_iterations}) in merge_short_segments") + + return current_fragments + + def _merge_where_clauses(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: + """ + 合并 where 子句规则 + + Args: + fragments: 片段列表 + + Returns: + List[ArxivFragment]: 处理后的片段列表 + """ + if not fragments: + return fragments + + result = [] + i = 0 + while i < len(fragments): + current = fragments[i] + + # 检查是否是 where 子句 + if current.content.strip().lower().startswith('where'): + if result: # 确保有前一个片段可以合并 + merged = self._merge_segments(result.pop(), current) + result.append(merged) + else: + result.append(current) + else: + result.append(current) + i += 1 + + return result + + def _merge_clauses(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: + """ + 合并从句和连接词规则,确保句子的完整性 + + 处理以下情况: + 1. where/which/that等从句 + 2. 
连接词(such that, so that等) + 3. 条件句(if, when等) + 4. 其他常见的数学论文连接词 + + Args: + fragments: 片段列表 + + Returns: + List[ArxivFragment]: 处理后的片段列表 + """ + if not fragments: + return fragments + + # 需要合并的从句和连接词模式 + clause_patterns = [ + # 从句引导词 + r'^(?:where|which|that|whose|when)\b', + # 数学中的连接词 + r'^(?:such\s+that|so\s+that|in\s+which|for\s+which)\b', + # 条件引导词 + r'^(?:if|unless|provided|assuming)\b', + # 其他常见数学连接词 + r'^(?:therefore|thus|hence|consequently|furthermore|moreover)\b', + # 并列连接词 + r'^(?:and|or|but|while|whereas)\b', + # 因果关系词 + r'^(?:because|since|as)\b', + # 时序关系词 + r'^(?:after|before|until|whenever)\b', + # 让步关系词 + r'^(?:although|though|even\s+if|even\s+though)\b', + # 比较关系词 + r'^(?:than|as\s+[.\w]+\s+as)\b', + # 目的关系词 + r'^(?:in\s+order\s+to|so\s+as\s+to)\b', + # 条件关系词组 + r'^(?:on\s+condition\s+that|given\s+that|suppose\s+that)\b', + # 常见数学术语 + r'^(?:denoted\s+by|defined\s+as|written\s+as|expressed\s+as)\b' + ] + # 编译正则表达式模式 + clause_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in clause_patterns] + + def is_clause_start(text: str) -> bool: + """检查文本是否以从句或连接词开始""" + text = text.strip() + return any(pattern.search(text) for pattern in clause_patterns) + + def is_sentence_complete(text: str) -> bool: + """检查句子是否完整(基于简单的标点符号检查)""" + # 检查常见的句子结束符号 + end_markers = ['.', '。', '!', '?', '!', '?'] + # 排除可能的小数点和缩写 + text = text.strip() + if not text: + return False + last_char = text[-1] + if last_char in end_markers: + # 确保不是小数点 + if last_char == '.' and re.search(r'\d\.$', text): + return False + return True + return False + + def should_merge(prev: ArxivFragment, curr: ArxivFragment) -> bool: + """判断两个片段是否应该合并""" + # 检查当前片段是否以从句开始 + if is_clause_start(curr.content): + return True + + # 检查前一个片段是否句子完整 + if not is_sentence_complete(prev.content): + # 如果前一个片段以数学公式结束,检查当前片段是否是其补充说明 + if re.search(r'[\$\)]\\?$', prev.content.strip()): + return True + + # 检查是否存在被截断的括号对 + brackets = { + '(': ')', '[': ']', '{': '}', + r'\{': r'\}', r'\[': r'\]', r'\(': r'\)' + } + for open_b, close_b in brackets.items(): + open_count = prev.content.count(open_b) + close_count = prev.content.count(close_b) + if open_count > close_count: + return True + + return False + + result = [] + i = 0 + while i < len(fragments): + current = fragments[i] + if "which means that the graph convolution adds up all atom features" in current.content: + print("find here") + if not result: + result.append(current) + i += 1 + continue + + prev = result[-1] + if should_merge(prev, current): + # 合并片段,确保不超过最大长度限制 + merged_content = f"{prev.content}\n{current.content}" + if len(current.content) <= self.min_chars: + merged = self._merge_segments(prev, current) + result.pop() # 移除前一个片段 + result.append(merged) # 添加合并后的片段 + else: + # 如果合并后超过长度限制,保持分开 + result.append(current) + else: + result.append(current) + i += 1 + + return result + + # 在TexProcessor类中更新merge_segments方法 + def merge_segments(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: + """ + 按注册的规则合并片段 + + Args: + fragments: 要合并的片段列表 + + Returns: + List[ArxivFragment]: 合并后的片段列表 + """ + result = fragments + + # 首先处理从句和连接词 + result = self._merge_clauses(result) + + # 然后执行其他合并规则 + for _, rule_func in self.merge_rules: + if rule_func != self._merge_where_clauses: # 跳过旧的where从句处理 + result = rule_func(result) + + return result + diff --git a/crazy_functions/rag_fns/light_rag/core/storage.py b/crazy_functions/rag_fns/light_rag/core/storage.py index eca33386..42ecb838 100644 --- a/crazy_functions/rag_fns/light_rag/core/storage.py +++ 
b/crazy_functions/rag_fns/light_rag/core/storage.py @@ -13,18 +13,19 @@ from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker T = TypeVar('T') -@dataclass + +@dataclass class StorageBase: """Base class for all storage implementations""" namespace: str working_dir: str - + async def index_done_callback(self): """Hook called after indexing operations""" pass - + async def query_done_callback(self): - """Hook called after query operations""" + """Hook called after query operations""" pass @@ -32,37 +33,37 @@ class StorageBase: class JsonKVStorage(StorageBase, Generic[T]): """ Key-Value storage using JSON files - + Attributes: namespace (str): Storage namespace working_dir (str): Working directory for storage files _file_name (str): JSON file path _data (Dict[str, T]): In-memory storage """ - + def __post_init__(self): """Initialize storage file and load data""" - self._file_name = os.path.join(self.working_dir, f"kv_{self.namespace}.json") + self._file_name = os.path.join(self.working_dir, f"kv_store_{self.namespace}.json") self._data: Dict[str, T] = {} self.load() - + def load(self): """Load data from JSON file""" if os.path.exists(self._file_name): with open(self._file_name, 'r', encoding='utf-8') as f: self._data = json.load(f) logger.info(f"Loaded {len(self._data)} items from {self._file_name}") - + async def save(self): """Save data to JSON file""" os.makedirs(os.path.dirname(self._file_name), exist_ok=True) with open(self._file_name, 'w', encoding='utf-8') as f: json.dump(self._data, f, ensure_ascii=False, indent=2) - + async def get_by_id(self, id: str) -> Optional[T]: """Get item by ID""" return self._data.get(id) - + async def get_by_ids(self, ids: List[str], fields: Optional[Set[str]] = None) -> List[Optional[T]]: """Get multiple items by IDs with optional field filtering""" if fields is None: @@ -70,16 +71,16 @@ class JsonKVStorage(StorageBase, Generic[T]): return [{k: v for k, v in self._data[id].items() if k in fields} if id in self._data else None for id in ids] - + async def filter_keys(self, keys: List[str]) -> Set[str]: """Return keys that don't exist in storage""" return set(k for k in keys if k not in self._data) - + async def upsert(self, data: Dict[str, T]): """Insert or update items""" self._data.update(data) await self.save() - + async def drop(self): """Clear all data""" self._data = {} @@ -95,148 +96,225 @@ class JsonKVStorage(StorageBase, Generic[T]): await self.save() + @dataclass class VectorStorage(StorageBase): """ - Vector storage using LlamaIndex - + Vector storage using LlamaIndexRagWorker + Attributes: - namespace (str): Storage namespace + namespace (str): Storage namespace (e.g., 'entities', 'relationships', 'chunks') working_dir (str): Working directory for storage files llm_kwargs (dict): LLM configuration embedding_func (OpenAiEmbeddingModel): Embedding function - meta_fields (Set[str]): Additional fields to store - cosine_better_than_threshold (float): Similarity threshold + meta_fields (Set[str]): Additional metadata fields to store """ llm_kwargs: dict embedding_func: OpenAiEmbeddingModel meta_fields: Set[str] = field(default_factory=set) - cosine_better_than_threshold: float = 0.2 - + def __post_init__(self): """Initialize LlamaIndex worker""" - checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}") + # 使用正确的文件命名格式 + self._vector_file = os.path.join(self.working_dir, f"vdb_{self.namespace}.json") + + # 设置检查点目录 + checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}_checkpoint") + 
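+ # Create the checkpoint directory up front; LlamaIndexRagWorker is constructed with auto_load_checkpoint=True and restores any previous state from it.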
os.makedirs(checkpoint_dir, exist_ok=True) + + # 初始化向量存储 self.vector_store = LlamaIndexRagWorker( user_name=self.namespace, llm_kwargs=self.llm_kwargs, checkpoint_dir=checkpoint_dir, - auto_load_checkpoint=True # 自动加载检查点 + auto_load_checkpoint=True ) - - async def query(self, query: str, top_k: int = 5) -> List[dict]: + logger.info(f"Initialized vector storage for {self.namespace}") + + async def query(self, query: str, top_k: int = 5, metadata_filters: Optional[Dict[str, Any]] = None) -> List[dict]: """ - Query vectors by similarity - + Query vectors by similarity with optional metadata filtering + Args: query: Query text - top_k: Maximum number of results - + top_k: Maximum number of results to return + metadata_filters: Optional metadata filters + Returns: List of similar documents with scores """ - nodes = self.vector_store.retrieve_from_store_with_query(query) - results = [{ - "id": node.node_id, - "text": node.text, - "score": node.score, - **{k: getattr(node, k) for k in self.meta_fields if hasattr(node, k)} - } for node in nodes[:top_k]] - return [r for r in results if r.get('score', 0) > self.cosine_better_than_threshold] - + try: + if metadata_filters: + nodes = self.vector_store.retrieve_with_metadata_filter(query, metadata_filters, top_k) + else: + nodes = self.vector_store.retrieve_from_store_with_query(query)[:top_k] + + results = [] + for node in nodes: + result = { + "id": node.node_id, + "text": node.text, + "score": node.score if hasattr(node, 'score') else 0.0, + } + # Add metadata fields if they exist and are in meta_fields + if hasattr(node, 'metadata'): + result.update({ + k: node.metadata[k] + for k in self.meta_fields + if k in node.metadata + }) + results.append(result) + + return results + + except Exception as e: + logger.error(f"Error in vector query: {e}") + raise + async def upsert(self, data: Dict[str, dict]): """ Insert or update vectors - + Args: - data: Dictionary of documents to insert/update + data: Dictionary of documents to insert/update with format: + {id: {"content": text, "metadata": dict}} """ - for id, item in data.items(): - content = item["content"] - metadata = {k: item[k] for k in self.meta_fields if k in item} - self.vector_store.add_text_with_metadata(content, metadata=metadata) - + try: + for doc_id, item in data.items(): + content = item["content"] + # 提取元数据 + metadata = { + k: item[k] + for k in self.meta_fields + if k in item + } + # 添加文档ID到元数据 + metadata["doc_id"] = doc_id + + # 添加到向量存储 + self.vector_store.add_text_with_metadata(content, metadata) + + # 导出向量数据到json文件 + self.vector_store.export_nodes( + self._vector_file, + format="json", + include_embeddings=True + ) + + except Exception as e: + logger.error(f"Error in vector upsert: {e}") + raise + + async def save(self): + """Save vector store to checkpoint and export data""" + try: + # 保存检查点 + self.vector_store.save_to_checkpoint() + + # 导出向量数据 + self.vector_store.export_nodes( + self._vector_file, + format="json", + include_embeddings=True + ) + except Exception as e: + logger.error(f"Error saving vector storage: {e}") + raise + async def index_done_callback(self): """Save after indexing""" - self.vector_store.save_to_checkpoint() + await self.save() + + def get_statistics(self) -> Dict[str, Any]: + """Get vector store statistics""" + return self.vector_store.get_statistics() @dataclass class NetworkStorage(StorageBase): """ Graph storage using NetworkX - + Attributes: namespace (str): Storage namespace working_dir (str): Working directory for storage files """ - + def 
__post_init__(self): """Initialize graph and storage file""" self._file_name = os.path.join(self.working_dir, f"graph_{self.namespace}.graphml") self._graph = self._load_graph() or nx.Graph() - + logger.info(f"Initialized graph storage for {self.namespace}") + def _load_graph(self) -> Optional[nx.Graph]: """Load graph from GraphML file""" if os.path.exists(self._file_name): try: - return nx.read_graphml(self._file_name) + graph = nx.read_graphml(self._file_name) + logger.info(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges") + return graph except Exception as e: logger.error(f"Error loading graph from {self._file_name}: {e}") return None return None - + async def save_graph(self): """Save graph to GraphML file""" - os.makedirs(os.path.dirname(self._file_name), exist_ok=True) - logger.info(f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges") - nx.write_graphml(self._graph, self._file_name) - + try: + os.makedirs(os.path.dirname(self._file_name), exist_ok=True) + logger.info( + f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges") + nx.write_graphml(self._graph, self._file_name) + except Exception as e: + logger.error(f"Error saving graph: {e}") + raise + async def has_node(self, node_id: str) -> bool: """Check if node exists""" return self._graph.has_node(node_id) - + async def has_edge(self, source_id: str, target_id: str) -> bool: """Check if edge exists""" return self._graph.has_edge(source_id, target_id) - + async def get_node(self, node_id: str) -> Optional[dict]: """Get node attributes""" if not self._graph.has_node(node_id): return None return dict(self._graph.nodes[node_id]) - + async def get_edge(self, source_id: str, target_id: str) -> Optional[dict]: """Get edge attributes""" if not self._graph.has_edge(source_id, target_id): return None return dict(self._graph.edges[source_id, target_id]) - + async def node_degree(self, node_id: str) -> int: """Get node degree""" return self._graph.degree(node_id) - + async def edge_degree(self, source_id: str, target_id: str) -> int: """Get sum of degrees of edge endpoints""" return self._graph.degree(source_id) + self._graph.degree(target_id) - + async def get_node_edges(self, source_id: str) -> Optional[List[Tuple[str, str]]]: """Get all edges connected to node""" if not self._graph.has_node(source_id): return None return list(self._graph.edges(source_id)) - + async def upsert_node(self, node_id: str, node_data: Dict[str, str]): """Insert or update node""" - # Clean and normalize node data cleaned_data = {k: html.escape(str(v).upper().strip()) for k, v in node_data.items()} self._graph.add_node(node_id, **cleaned_data) - + await self.save_graph() + async def upsert_edge(self, source_id: str, target_id: str, edge_data: Dict[str, str]): """Insert or update edge""" - # Clean and normalize edge data cleaned_data = {k: html.escape(str(v).strip()) for k, v in edge_data.items()} self._graph.add_edge(source_id, target_id, **cleaned_data) - + await self.save_graph() + async def index_done_callback(self): """Save after indexing""" await self.save_graph() @@ -245,47 +323,47 @@ class NetworkStorage(StorageBase): """Get the largest connected component of the graph""" if not self._graph: return nx.Graph() - + components = list(nx.connected_components(self._graph)) if not components: return nx.Graph() - + largest_component = max(components, key=len) return self._graph.subgraph(largest_component).copy() - - async def 
embed_nodes(self, algorithm: str, **kwargs) -> Tuple[np.ndarray, List[str]]: - """ - Embed nodes using specified algorithm - - Args: - algorithm: Node embedding algorithm name - **kwargs: Additional algorithm parameters - - Returns: - Tuple of (node embeddings, node IDs) - """ + + async def embed_nodes( + self, + algorithm: str = "node2vec", + dimensions: int = 128, + walk_length: int = 30, + num_walks: int = 200, + workers: int = 4, + window: int = 10, + min_count: int = 1, + **kwargs + ) -> Tuple[np.ndarray, List[str]]: + """Generate node embeddings using specified algorithm""" if algorithm == "node2vec": from node2vec import Node2Vec - - # Create node2vec model - node2vec = Node2Vec( + + # Create and train node2vec model + n2v = Node2Vec( self._graph, - dimensions=kwargs.get('dimensions', 128), - walk_length=kwargs.get('walk_length', 30), - num_walks=kwargs.get('num_walks', 200), - workers=kwargs.get('workers', 4) + dimensions=dimensions, + walk_length=walk_length, + num_walks=num_walks, + workers=workers ) - - # Train model - model = node2vec.fit( - window=kwargs.get('window', 10), - min_count=kwargs.get('min_count', 1) + + model = n2v.fit( + window=window, + min_count=min_count ) - - # Get embeddings + + # Get embeddings for all nodes node_ids = list(self._graph.nodes()) embeddings = np.array([model.wv[node] for node in node_ids]) - + return embeddings, node_ids - else: - raise ValueError(f"Unsupported embedding algorithm: {algorithm}") + + raise ValueError(f"Unsupported embedding algorithm: {algorithm}") \ No newline at end of file diff --git a/crazy_functions/rag_fns/light_rag/example.py b/crazy_functions/rag_fns/light_rag/example.py index 0ee55fe1..00fa323f 100644 --- a/crazy_functions/rag_fns/light_rag/example.py +++ b/crazy_functions/rag_fns/light_rag/example.py @@ -23,25 +23,29 @@ class ExtractionExample: def __init__(self): """Initialize RAG system components""" # 设置工作目录 - self.working_dir = f"private_upload/default_user/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" + self.working_dir = f"crazy_functions/rag_fns/LightRAG/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" os.makedirs(self.working_dir, exist_ok=True) logger.info(f"Working directory: {self.working_dir}") # 初始化embedding - self.llm_kwargs = {'api_key': os.getenv("one_api_key"), 'client_ip': '127.0.0.1', - 'embed_model': 'text-embedding-3-small', 'llm_model': 'one-api-Qwen2.5-72B-Instruct', - 'max_length': 4096, 'most_recent_uploaded': None, 'temperature': 1, 'top_p': 1} + self.llm_kwargs = { + 'api_key': os.getenv("one_api_key"), + 'client_ip': '127.0.0.1', + 'embed_model': 'text-embedding-3-small', + 'llm_model': 'one-api-Qwen2.5-72B-Instruct', + 'max_length': 4096, + 'most_recent_uploaded': None, + 'temperature': 1, + 'top_p': 1 + } self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs) # 初始化提示模板和抽取器 self.prompt_templates = PromptTemplates() self.extractor = EntityRelationExtractor( prompt_templates=self.prompt_templates, - required_prompts = { - 'entity_extraction' - }, + required_prompts={'entity_extraction'}, entity_extract_max_gleaning=1 - ) # 初始化存储系统 @@ -63,18 +67,33 @@ class ExtractionExample: working_dir=self.working_dir ) - # 向量存储 - 用于相似度检索 - self.vector_store = VectorStorage( - namespace="vectors", + # 向量存储 - 用于实体、关系和文本块的向量表示 + self.entities_vdb = VectorStorage( + namespace="entities", working_dir=self.working_dir, llm_kwargs=self.llm_kwargs, embedding_func=self.embedding_func, meta_fields={"entity_name", "entity_type"} ) + self.relationships_vdb = VectorStorage( + 
namespace="relationships", + working_dir=self.working_dir, + llm_kwargs=self.llm_kwargs, + embedding_func=self.embedding_func, + meta_fields={"src_id", "tgt_id"} + ) + + self.chunks_vdb = VectorStorage( + namespace="chunks", + working_dir=self.working_dir, + llm_kwargs=self.llm_kwargs, + embedding_func=self.embedding_func + ) + # 图存储 - 用于实体关系 self.graph_store = NetworkStorage( - namespace="graph", + namespace="chunk_entity_relation", working_dir=self.working_dir ) @@ -152,7 +171,7 @@ class ExtractionExample: try: # 向量存储 logger.info("Adding chunks to vector store...") - await self.vector_store.upsert(chunks) + await self.chunks_vdb.upsert(chunks) # 初始化对话历史 self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()} @@ -178,14 +197,32 @@ class ExtractionExample: # 获取结果 nodes, edges = self.extractor.get_results() - # 存储到图数据库 - logger.info("Storing extracted information in graph database...") + # 存储实体到向量数据库和图数据库 for node_name, node_instances in nodes.items(): for node in node_instances: + # 存储到向量数据库 + await self.entities_vdb.upsert({ + f"entity_{node_name}": { + "content": f"{node_name}: {node['description']}", + "entity_name": node_name, + "entity_type": node['entity_type'] + } + }) + # 存储到图数据库 await self.graph_store.upsert_node(node_name, node) + # 存储关系到向量数据库和图数据库 for (src, tgt), edge_instances in edges.items(): for edge in edge_instances: + # 存储到向量数据库 + await self.relationships_vdb.upsert({ + f"rel_{src}_{tgt}": { + "content": f"{edge['description']} | {edge['keywords']}", + "src_id": src, + "tgt_id": tgt + } + }) + # 存储到图数据库 await self.graph_store.upsert_edge(src, tgt, edge) return nodes, edges @@ -197,26 +234,39 @@ class ExtractionExample: async def query_knowledge_base(self, query: str, top_k: int = 5): """Query the knowledge base using various methods""" try: - # 向量相似度搜索 - vector_results = await self.vector_store.query(query, top_k=top_k) + # 向量相似度搜索 - 文本块 + chunk_results = await self.chunks_vdb.query(query, top_k=top_k) + + # 向量相似度搜索 - 实体 + entity_results = await self.entities_vdb.query(query, top_k=top_k) # 获取相关文本块 - chunk_ids = [r["id"] for r in vector_results] + chunk_ids = [r["id"] for r in chunk_results] chunks = await self.text_chunks.get_by_ids(chunk_ids) - # 获取相关实体 - # 假设query中包含实体名称 - relevant_nodes = [] - for word in query.split(): - if await self.graph_store.has_node(word.upper()): - node_data = await self.graph_store.get_node(word.upper()) - if node_data: - relevant_nodes.append(node_data) + # 获取实体相关的图结构信息 + relevant_edges = [] + for entity in entity_results: + if "entity_name" in entity: + entity_name = entity["entity_name"] + if await self.graph_store.has_node(entity_name): + edges = await self.graph_store.get_node_edges(entity_name) + if edges: + edge_data = [] + for edge in edges: + edge_info = await self.graph_store.get_edge(edge[0], edge[1]) + if edge_info: + edge_data.append({ + "source": edge[0], + "target": edge[1], + "data": edge_info + }) + relevant_edges.extend(edge_data) return { - "vector_results": vector_results, - "text_chunks": chunks, - "relevant_entities": relevant_nodes + "chunks": chunks, + "entities": entity_results, + "relationships": relevant_edges } except Exception as e: @@ -228,30 +278,27 @@ class ExtractionExample: os.makedirs(export_dir, exist_ok=True) try: - # 导出向量存储 - self.vector_store.vector_store.export_nodes( - os.path.join(export_dir, "vector_nodes.json"), - include_embeddings=True - ) - - # 导出图数据统计 - graph_stats = { - "total_nodes": len(list(self.graph_store._graph.nodes())), - "total_edges": 
len(list(self.graph_store._graph.edges())), - "node_degrees": dict(self.graph_store._graph.degree()), - "largest_component_size": len(self.graph_store.get_largest_connected_component()) - } - - with open(os.path.join(export_dir, "graph_stats.json"), "w") as f: - json.dump(graph_stats, f, indent=2) - - # 导出存储统计 + # 导出统计信息 storage_stats = { - "chunks": len(self.text_chunks._data), - "docs": len(self.full_docs._data), - "vector_store": self.vector_store.vector_store.get_statistics() + "chunks": { + "total": len(self.text_chunks._data), + "vector_stats": self.chunks_vdb.get_statistics() + }, + "entities": { + "vector_stats": self.entities_vdb.get_statistics() + }, + "relationships": { + "vector_stats": self.relationships_vdb.get_statistics() + }, + "graph": { + "total_nodes": len(list(self.graph_store._graph.nodes())), + "total_edges": len(list(self.graph_store._graph.edges())), + "node_degrees": dict(self.graph_store._graph.degree()), + "largest_component_size": len(self.graph_store.get_largest_connected_component()) + } } + # 导出统计 with open(os.path.join(export_dir, "storage_stats.json"), "w") as f: json.dump(storage_stats, f, indent=2) @@ -299,19 +346,6 @@ async def main(): the company's commitment to innovation and sustainability. The new iPhone features groundbreaking AI capabilities. """, - - # "business_news": """ - # Microsoft and OpenAI expanded their partnership today. - # Satya Nadella emphasized the importance of AI development while - # Sam Altman discussed the future of large language models. The collaboration - # aims to accelerate AI research and deployment. - # """, - # - # "science_paper": """ - # Researchers at DeepMind published a breakthrough paper on quantum computing. - # The team demonstrated novel approaches to quantum error correction. - # Dr. Sarah Johnson led the research, collaborating with Google's quantum lab. - # """ } try: