"""Split arXiv LaTeX sources into section-level fragments for downstream RAG use."""
import asyncio
import gzip
import logging
import re
import tarfile
import time
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Set

import aiohttp

from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata


def _markdown_anchor(heading: str) -> str:
    """Build a GitHub-style Markdown anchor slug for a heading.

    Lower-cases the text, drops every character other than word characters,
    spaces and hyphens, then replaces spaces with hyphens.
    """
    slug = re.sub(r'[^\w\- ]', '', heading.strip().lower())
    return slug.replace(' ', '-')


def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path) -> Path:
    """
    Save all fragments to a single structured markdown file.

    The file contains an overview, paper information taken from the first
    fragment, the section tree, a table of contents, every fragment grouped by
    section, and length/section statistics.

    Args:
        fragments: List of SectionFragment objects (may be empty)
        output_dir: Output directory path (str or Path); created if missing

    Returns:
        Path: Path to the generated markdown file
    """
    from datetime import datetime

    # Take a single timestamp so the file name and the header always agree.
    generated_at = datetime.now()
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    file_path = output_path / f"paper_latex_content_{generated_at.strftime('%Y%m%d_%H%M%S')}.md"

    # Group fragments by section, keeping first-seen order (dicts preserve insertion order).
    sections: Dict[str, List[SectionFragment]] = {}
    for fragment in fragments:
        section = fragment.current_section or "Uncategorized"
        sections.setdefault(section, []).append(fragment)

    with open(file_path, "w", encoding="utf-8") as out:
        # Document header
        out.write("# Document Fragments Analysis\n\n")
        out.write("## Overview\n")
        out.write(f"- Total Fragments: {len(fragments)}\n")
        out.write(f"- Generated Time: {generated_at.strftime('%Y-%m-%d %H:%M:%S')}\n")

        first = fragments[0] if fragments else None

        # Paper information, if available
        if first and (first.title or first.abstract):
            out.write("\n## Paper Information\n")
            if first.title:
                out.write(f"### Title\n{first.title}\n")
            if first.authors:
                out.write(f"\n### Authors\n{first.authors}\n")
            if first.abstract:
                out.write(f"\n### Abstract\n{first.abstract}\n")

        # Section tree, rendered as a code block to keep its indentation
        if first and first.catalogs:
            out.write("\n## Section Tree\n")
            out.write("```\n")
            out.write(first.catalogs)
            out.write("\n```")

        # Table of contents
        out.write("\n## Table of Contents\n")
        for section, section_fragments in sections.items():
            clean_section = section.strip() or "Uncategorized"
            out.write(f"- [{clean_section}](#{_markdown_anchor(clean_section)}) "
                      f"({len(section_fragments)} fragments)\n")

        # Content sections
        out.write("\n## Content\n")
        for section, section_fragments in sections.items():
            clean_section = section.strip() or "Uncategorized"
            out.write(f"\n### {clean_section}\n")

            for i, fragment in enumerate(section_fragments, 1):
                out.write(f"\n#### Fragment {i}\n")

                out.write("**Metadata:**\n")
                metadata = [
                    f"- Section: {fragment.current_section}",
                    f"- Length: {len(fragment.content)} chars",
                    f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
                ]
                out.write("\n".join(filter(None, metadata)) + "\n")

                out.write("\n**Content:**\n")
                out.write("\n")
                out.write(fragment.content)
                out.write("\n")

                if fragment.bibliography:
                    out.write("\n**Bibliography:**\n")
                    out.write("```bibtex\n")
                    out.write(fragment.bibliography)
                    out.write("\n```\n")

                # Separator between fragments of the same section
                if i < len(section_fragments):
                    out.write("\n---\n")

        # Statistics
        out.write("\n## Statistics\n")

        lengths = [len(fragment.content) for fragment in fragments]
        out.write("\n### Length Distribution\n")
        if lengths:
            out.write(f"- Minimum: {min(lengths)} chars\n")
            out.write(f"- Maximum: {max(lengths)} chars\n")
            out.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
        else:
            # min()/max() would raise on an empty sequence.
            out.write("- No fragments\n")

        # `sections` is empty when `fragments` is, so the division below is safe.
        out.write("\n### Section Distribution\n")
        for section, section_fragments in sections.items():
            percentage = (len(section_fragments) / len(fragments)) * 100
            out.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")

    print(f"Fragments saved to: {file_path}")
    return file_path


# Optional star plus up to two bracketed optional arguments, e.g. \citep*[see][p. 3]{key}.
_CITE_OPTIONS = r'(?:\*)?(?:\[[^\]]*\]){0,2}'
# Brace group holding one or more comma/semicolon separated citation keys (captured).
_CITE_KEYS = r'\{([^}]+)\}'

# Patterns for the supported citation commands; group 1 always captures the key list.
CITATION_PATTERNS = [
    # Basic \cite{} form
    r'\\cite' + _CITE_OPTIONS + _CITE_KEYS,
    # natbib forms
    r'\\citep' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citet' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citeauthor' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citeyear' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citeyearpar' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citealt' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citealp' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\citenum' + _CITE_OPTIONS + _CITE_KEYS,
    # biblatex forms
    r'\\textcite' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\parencite' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\autocite' + _CITE_OPTIONS + _CITE_KEYS,
    r'\\footcite' + _CITE_OPTIONS + _CITE_KEYS,
    # Custom [cite:...] form
    r'\[cite:([^\]]+)\]',
]

# Compile all patterns once at import time.
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]


class ArxivSplitter:
    """Intelligent splitter for arXiv papers.

    Downloads (or reuses a local copy of) a paper's LaTeX source, parses it into
    a DocumentStructure, attaches the cited references to each section and
    flattens the result into a list of SectionFragment objects.
    """

    # Trailing version suffix of an arXiv identifier, e.g. the "v2" in "2411.03663v2".
    _VERSION_SUFFIX_RE = re.compile(r'v\d+$')
    # Identifier part of an arxiv.org URL (abs / pdf / html / src / e-print pages).
    _ARXIV_URL_RE = re.compile(r'arxiv\.org/(?:abs|pdf|html|src|e-print)/(.+)', re.IGNORECASE)
    # Citation key of a BibTeX entry header such as "@article{key," or "@book(key,".
    _BIB_ENTRY_KEY_RE = re.compile(r'^@\s*\w+\s*[{(]\s*([^,\s]+)\s*,')
    # Simple one-argument LaTeX command such as \textbf{...}; the argument is kept.
    _SIMPLE_COMMAND_RE = re.compile(r'\\[a-zA-Z]+\{([^}]*)\}')
    # \item (but not longer commands such as \itemsep) plus trailing whitespace.
    _ITEM_RE = re.compile(r'\\item(?![a-zA-Z])\s*')

    def __init__(self,
                 root_dir: str = "gpt_log/arxiv_cache",
                 proxies: Optional[Dict[str, str]] = None,
                 cache_ttl: int = 7 * 24 * 60 * 60):
        """
        Initialize the splitter.

        Args:
            root_dir: Root directory for downloaded sources; each paper goes into
                `root_dir / <arxiv_id>`. Created if it does not exist.
            proxies: Proxy settings. The 'http' entry (falling back to 'https')
                is passed to aiohttp for direct downloads; the whole dict is
                passed to ArxivDownloader.
            cache_ttl: Cache lifetime in seconds, used by _check_cache.
        """
        self.root_dir = Path(root_dir)
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.proxies = proxies or {}
        self.cache_ttl = cache_ttl

        import multiprocessing
        cpu_count = multiprocessing.cpu_count()

        self.document_structure = DocumentStructure()
        self.document_parser = EssayStructureParser()
        # Scale with the number of CPU cores, capped to avoid excessive concurrency.
        self.max_workers = min(32, cpu_count * 2)

        self.tex_processor = TexUtils()

        self._setup_logging()

    def _setup_logging(self):
        """Configure root logging (INFO level) and attach a module logger as self.logger."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _normalize_arxiv_id(self, input_str: str) -> str:
        """
        Normalize an arXiv ID or URL to a bare identifier without version suffix.

        Accepts bare IDs ("2411.03663", "2411.03663v2", "hep-th/9901001v1"),
        "arXiv:"-prefixed IDs and arxiv.org URLs (abs/pdf/html/src/e-print pages,
        optionally ending in ".pdf", a query string or a fragment).

        Only a trailing "v<digits>" is removed, so old-style IDs whose archive name
        contains a "v" (e.g. "solv-int/9901001") are preserved.
        """
        candidate = input_str.strip()
        if 'arxiv.org/' in candidate.lower():
            match = self._ARXIV_URL_RE.search(candidate)
            candidate = match.group(1) if match else candidate.split('/abs/')[-1]
        elif candidate.lower().startswith('arxiv:'):
            candidate = candidate[len('arxiv:'):]

        # Drop query string, fragment, trailing slash and a ".pdf" extension.
        candidate = candidate.split('?', 1)[0].split('#', 1)[0].rstrip('/')
        if candidate.lower().endswith('.pdf'):
            candidate = candidate[:-len('.pdf')]
        return self._VERSION_SUFFIX_RE.sub('', candidate.strip())

    @staticmethod
    def _has_tex_sources(paper_dir: Path) -> bool:
        """Return True if paper_dir is a directory containing at least one .tex file (searched recursively)."""
        return paper_dir.is_dir() and any(path.is_file() for path in paper_dir.rglob('*.tex'))

    def _check_cache(self, paper_dir: Path) -> bool:
        """
        Check whether a cached paper directory is usable and still fresh.

        The cache is valid when some .tex file below paper_dir contains
        `\\documentclass` and the directory's mtime is younger than cache_ttl.
        (Not called from within this module.)

        Args:
            paper_dir: Paper directory path

        Returns:
            bool: True if the cache is valid, otherwise False
        """
        if not paper_dir.exists():
            return False

        # A main .tex file (one containing \documentclass) must be present.
        has_main_tex = False
        for file_path in paper_dir.rglob("*"):
            if file_path.suffix != '.tex':
                continue
            content = self.tex_processor.read_file(str(file_path))
            if content and r'\documentclass' in content:
                has_main_tex = True
                break

        if not has_main_tex:
            return False

        cache_time = paper_dir.stat().st_mtime
        if (time.time() - cache_time) < self.cache_ttl:
            self.logger.info(f"Using valid cache for {paper_dir.name}")
            return True

        return False

    async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
        """
        Download and extract a paper's source into paper_dir.

        ArxivDownloader is tried first. If it raises or returns nothing, the source
        is fetched directly from arxiv.org (/src then /e-print), retrying each
        URL up to three times with exponential backoff.

        Args:
            arxiv_id: ArXiv paper ID
            paper_dir: Target directory (created if missing)

        Returns:
            bool: True if the download succeeded, otherwise False
        """
        from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader

        # Old-style IDs ("hep-th/9901001") contain a slash; keep the archive names flat.
        safe_name = arxiv_id.replace('/', '_')
        temp_tar_path = paper_dir / f"{safe_name}_temp.tar.gz"
        final_tar_path = paper_dir / f"{safe_name}.tar.gz"

        paper_dir.mkdir(parents=True, exist_ok=True)

        # Preferred path: ArxivDownloader.
        try:
            downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
            downloaded_dir = downloader.download_paper(arxiv_id)
            if downloaded_dir:
                self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
                return True
        except Exception as e:
            self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")

        # Fallback: direct download.
        urls = [
            f"https://arxiv.org/src/{arxiv_id}",
            f"https://arxiv.org/e-print/{arxiv_id}"
        ]
        max_retries = 3
        retry_delay = 1  # base delay in seconds, doubled after every failed attempt
        proxy = self.proxies.get('http') or self.proxies.get('https')

        for url in urls:
            for attempt in range(max_retries):
                try:
                    self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
                    if await self._fetch_and_extract(url, proxy, temp_tar_path, final_tar_path, paper_dir):
                        return True
                except Exception as e:
                    self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
                # Back off after any failure (network error, bad status, or bad archive).
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay * (2 ** attempt))

        return False

    async def _fetch_and_extract(self, url: str, proxy: Optional[str], temp_tar_path: Path,
                                 final_tar_path: Path, paper_dir: Path) -> bool:
        """
        Fetch one source URL, then validate and extract it into paper_dir.

        Returns:
            bool: True on success; False for a non-200 response or an archive that
            cannot be extracted. Network errors propagate to the caller.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(url, proxy=proxy) as response:
                if response.status != 200:
                    self.logger.warning(f"Unexpected HTTP status {response.status} from {url}")
                    return False
                content = await response.read()

        temp_tar_path.write_bytes(content)
        try:
            # Validation/extraction is blocking file I/O, so run it off the event loop.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
            # Keep the verified archive; replace() also overwrites an existing file on Windows.
            temp_tar_path.replace(final_tar_path)
            return True
        except Exception as e:
            self.logger.warning(f"Invalid tar file: {str(e)}")
            if temp_tar_path.exists():
                temp_tar_path.unlink()
            return False

    def _process_tar_file(self, tar_path: Path, extract_path: Path):
        """
        Validate and extract a downloaded source archive (synchronous).

        Handles tar archives in any compression tarfile auto-detects, and the
        gzip-compressed single TeX file arXiv serves for single-file submissions
        (written out as main.tex).

        Raises:
            tarfile.TarError, OSError, EOFError, ValueError: if the archive is unreadable or
                truncated, or if a member would be written outside extract_path.
        """
        try:
            archive = tarfile.open(tar_path, 'r:*')
        except tarfile.ReadError:
            # Not a tar archive at all: try the single gzipped .tex case.
            self._extract_gzipped_tex(tar_path, extract_path)
            return

        with archive:
            # getmembers() walks every member header, so most truncated downloads fail here.
            members = archive.getmembers()
            # Archives come from the network: refuse path traversal before extracting.
            self._ensure_safe_members(members, extract_path)
            if hasattr(tarfile, 'data_filter'):
                archive.extractall(path=extract_path, members=members, filter='data')
            else:
                archive.extractall(path=extract_path, members=members)

    @staticmethod
    def _ensure_safe_members(members: List[tarfile.TarInfo], extract_path: Path) -> None:
        """Raise ValueError if any member path or link target resolves outside extract_path."""
        base = Path(extract_path).resolve()
        for member in members:
            targets = [member.name]
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                targets.append(str(Path(member.name).parent / member.linkname))
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                targets.append(member.linkname)
            for target in targets:
                try:
                    (base / target).resolve().relative_to(base)
                except ValueError:
                    raise ValueError(f"Unsafe path in archive: {member.name}") from None

    @staticmethod
    def _extract_gzipped_tex(gz_path: Path, extract_path: Path) -> None:
        """Write a gzip-compressed single TeX source to extract_path/main.tex.

        Raises:
            OSError: if the file is not gzip data.
            ValueError: if the decompressed payload does not look like LaTeX.
        """
        with gzip.open(gz_path, 'rb') as handle:
            data = handle.read()
        text = data.decode('utf-8', errors='ignore')
        if '\\documentclass' not in text and '\\begin{document}' not in text:
            raise ValueError("Downloaded file is neither a tar archive nor a gzipped TeX source")
        (Path(extract_path) / 'main.tex').write_bytes(data)

    def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
        """
        Process citations in document structure and add referenced literature for each section

        Args:
            doc_structure: DocumentStructure object
            ref_bib: String containing BibTeX references separated by newlines

        Returns:
            A deep copy of doc_structure with section bibliographies filled in;
            the original object if no references parse or processing fails.
        """
        try:
            # Work on a copy so the caller's structure is never mutated.
            doc = deepcopy(doc_structure)

            ref_map = self._parse_references(ref_bib)
            if not ref_map:
                self.logger.warning("No valid references found in ref_bib")
                return doc

            self._process_section_references(doc.toc, ref_map)
            return doc

        except Exception as e:
            self.logger.error(f"Error processing references: {str(e)}")
            return doc_structure  # Return original if processing fails

    def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
        """
        Recursively attach the full text of cited references to each section.

        Args:
            sections: List of Section objects (modified in place)
            ref_map: Mapping of citation keys to full references
        """
        for section in sections:
            if section.content:
                cited_keys = self.find_citations(section.content)
                full_refs = []
                # find_citations returns a set; sort so the bibliography order is deterministic.
                for ref_key in sorted(cited_keys):
                    ref_text = ref_map.get(ref_key)
                    if ref_text:
                        full_refs.append(ref_text)
                    else:
                        self.logger.warning(f"Reference not found for citation key: {ref_key}")

                if full_refs:
                    section.bibliography = "\n\n".join(full_refs)

            if section.subsections:
                self._process_section_references(section.subsections, ref_map)

    def _parse_references(self, ref_bib: str) -> Dict[str, str]:
        """
        Parse a BibTeX string into a mapping of citation keys to full entry text.

        A new entry starts at every line beginning with '@'. Entries without a
        citation key (e.g. @string, @preamble, @comment) are dropped.

        Args:
            ref_bib: Reference string with references separated by newlines

        Returns:
            Dict mapping citation keys to full reference text
        """
        ref_map: Dict[str, str] = {}
        current_ref: List[str] = []
        current_key: Optional[str] = None

        try:
            for raw_line in ref_bib.split('\n'):
                line = raw_line.strip()
                if not line:
                    continue

                if line.startswith('@'):
                    # Flush the previous entry.
                    if current_key and current_ref:
                        ref_map[current_key] = '\n'.join(current_ref)
                    key_match = self._BIB_ENTRY_KEY_RE.match(line)
                    # Reset the key for keyless entries so their lines are not
                    # attributed to (and do not overwrite) the previous entry.
                    current_key = key_match.group(1) if key_match else None
                    current_ref = [line]
                else:
                    current_ref.append(line)

            if current_key and current_ref:
                ref_map[current_key] = '\n'.join(current_ref)

        except Exception as e:
            self.logger.error(f"Error parsing references: {str(e)}")

        return ref_map

    @staticmethod
    def _clean_citation_key(key: str) -> str:
        """Clean individual citation key."""
        return key.strip().strip(',').strip()

    @staticmethod
    def _is_valid_citation_key(key: str) -> bool:
        """Reject empty keys and obvious LaTeX residue (keys starting with a backslash, brace or bracket)."""
        return bool(key) and not key.startswith(('\\', '{', '}', '[', ']'))

    def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
        """Extract and clean individual citation keys from a comma/semicolon separated group."""
        try:
            cleaned = (self._clean_citation_key(k) for k in re.split(r'[,;]', keys_str))
            return {k for k in cleaned if k}
        except Exception as e:
            self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
            return set()

    def find_citations(self, content: str) -> Set[str]:
        r"""
        Find citation keys in text content in various formats.

        Args:
            content: Text content to search for citations

        Returns:
            Set of unique citation keys

        Examples:
            Supported formats include:
            - \cite{key1,key2}
            - \cite[p. 1]{key}
            - \citep[see][p. 3]{key}
            - \citet{key}
            - [cite:key1, key2]
            - And the other commands listed in CITATION_PATTERNS
        """
        citations: Set[str] = set()
        if not content:
            return citations

        try:
            for pattern in COMPILED_PATTERNS:
                for match in pattern.finditer(content):
                    keys_str = match.group(1)
                    if keys_str:
                        citations.update(self._extract_keys_from_group(keys_str))
        except Exception as e:
            self.logger.error(f"Error finding citations: {str(e)}")

        return {key for key in citations if self._is_valid_citation_key(key)}

    def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
        """
        Find citations and their surrounding context.

        Args:
            content: Text content to search for citations
            context_chars: Number of characters of context to include before/after

        Returns:
            Dict mapping citation keys to lists of context strings
        """
        contexts: Dict[str, List[str]] = {}
        if not content:
            return contexts

        try:
            for pattern in COMPILED_PATTERNS:
                for match in pattern.finditer(content):
                    start = max(0, match.start() - context_chars)
                    end = min(len(content), match.end() + context_chars)
                    context = content[start:end]

                    # Same key filtering as find_citations, so both views agree.
                    for key in self._extract_keys_from_group(match.group(1)):
                        if self._is_valid_citation_key(key):
                            contexts.setdefault(key, []).append(context)

        except Exception as e:
            self.logger.error(f"Error finding citation contexts: {str(e)}")

        return contexts

    async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
        """
        Process ArXiv paper and convert to list of SectionFragments.
        Each fragment represents the smallest section unit.

        Args:
            arxiv_id_or_url: ArXiv paper ID or URL

        Returns:
            List[SectionFragment]: List of processed paper fragments

        Raises:
            RuntimeError: if no main TeX file or no TeX files can be found.
        """
        try:
            arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
            paper_dir = self.root_dir / arxiv_id

            # Download unless usable sources are already present. A directory with no
            # .tex file (e.g. left over from a failed download) does not count as cached.
            if not self._has_tex_sources(paper_dir):
                self.logger.info(f"Downloading paper {arxiv_id}")
                if not await self.download_paper(arxiv_id, paper_dir):
                    self.logger.warning(f"Could not download source for {arxiv_id}")

            main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
            if not main_tex:
                raise RuntimeError(f"No main TeX file found in {paper_dir}")

            main_tex_content = read_tex_file(main_tex)

            # All related TeX files plus the bibliography text
            tex_files = self.tex_processor.resolve_includes(main_tex)
            ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)

            if not tex_files:
                raise RuntimeError(f"No valid TeX files found for {arxiv_id}")

            # Reset document structure for new processing
            self.document_structure = DocumentStructure()

            # Author information comes from the main file only.
            author_extractor = LatexAuthorExtractor()
            authors = author_extractor.extract_authors(main_tex_content)
            self.document_structure.authors = authors

            for file_path in tex_files:
                self.logger.info(f"Processing TeX file: {file_path}")
                tex_content = read_tex_file(file_path)
                if tex_content:
                    additional_doc = self.document_parser.parse(tex_content)
                    self.document_structure = self.document_structure.merge(additional_doc)

            if ref_bib:
                self.document_structure = self.process_references(self.document_structure, ref_bib)
                self.logger.info("Successfully processed references")
            else:
                self.logger.info("No references found to process")

            # Generate the table of contents once and share it across fragments.
            section_tree = self.document_structure.generate_toc_tree()

            return self._convert_to_fragments(
                doc_structure=self.document_structure,
                arxiv_id=arxiv_id,
                section_tree=section_tree
            )

        except Exception as e:
            self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
            raise

    def _convert_to_fragments(self,
                              doc_structure: DocumentStructure,
                              arxiv_id: str,
                              section_tree: str) -> List[SectionFragment]:
        """
        Convert DocumentStructure to list of SectionFragments.
        Creates a fragment for each leaf section in the document hierarchy,
        plus a leading "Abstract" fragment when an abstract exists.

        Args:
            doc_structure: Source DocumentStructure
            arxiv_id: ArXiv paper ID
            section_tree: Pre-generated table of contents tree

        Returns:
            List[SectionFragment]: List of paper fragments that pass _validate_fragment
        """
        fragments: List[SectionFragment] = []

        # Fields shared by every fragment of this paper
        base_fragment_template = {
            'title': doc_structure.title,
            'authors': doc_structure.authors,
            'abstract': doc_structure.abstract,
            'catalogs': section_tree,
            'arxiv_id': arxiv_id
        }

        def get_leaf_sections(section: Section, path: Optional[List[str]] = None) -> None:
            """
            Recursively find all leaf sections and create fragments.
            A leaf section has no subsections; it yields a fragment only if it has
            content or a bibliography.

            Args:
                section: Current section being processed
                path: List of section titles forming the path to current section
            """
            if path is None:
                path = []

            current_path = path + [section.title]

            if section.subsections:
                for subsection in section.subsections:
                    get_leaf_sections(subsection, current_path)
                return

            if section.content or section.bibliography:
                fragment = SectionFragment(
                    **base_fragment_template,
                    current_section="/".join(current_path),
                    content=self._clean_content(section.content),
                    bibliography=section.bibliography
                )
                if self._validate_fragment(fragment):
                    fragments.append(fragment)

        for section in doc_structure.toc:
            get_leaf_sections(section)

        if doc_structure.abstract:
            abstract_fragment = SectionFragment(
                **base_fragment_template,
                current_section="Abstract",
                content=self._clean_content(doc_structure.abstract)
            )
            if self._validate_fragment(abstract_fragment):
                fragments.insert(0, abstract_fragment)

        self.logger.info(f"Created {len(fragments)} fragments")
        return fragments

    def _validate_fragment(self, fragment: SectionFragment) -> bool:
        """
        Validate if the fragment has all required fields with meaningful content.

        Requires non-blank title, catalogs and current_section, plus non-blank
        content or bibliography. A field that is None (no .strip()) counts as invalid.

        Args:
            fragment: SectionFragment to validate

        Returns:
            bool: True if fragment is valid, False otherwise
        """
        try:
            return all([
                fragment.title.strip(),
                fragment.catalogs.strip(),
                fragment.current_section.strip(),
                fragment.content.strip() or fragment.bibliography.strip()
            ])
        except AttributeError:
            return False

    def _clean_content(self, content: str) -> str:
        """
        Clean and normalize content text.

        Collapses whitespace, turns \\item into bullets, converts LaTeX line
        breaks to newlines and unwraps simple one-argument commands (repeatedly,
        so nested commands such as \\textbf{\\emph{x}} are fully unwrapped).

        Args:
            content: Raw content text

        Returns:
            str: Cleaned content text ("" for empty input)
        """
        if not content:
            return ""

        # Collapse all whitespace (including newlines) to single spaces.
        content = re.sub(r'\s+', ' ', content)

        # \item -> bullet point
        content = self._ITEM_RE.sub('• ', content)

        # LaTeX line breaks ("\\") -> newlines. Done before command stripping so the
        # second backslash of "\\" is never read as the start of a command.
        content = content.replace('\\\\', '\n')

        # Unwrap \cmd{arg} -> arg until nothing changes (each pass strictly shortens the text).
        previous = None
        while previous != content:
            previous = content
            content = self._SIMPLE_COMMAND_RE.sub(r'\1', content)

        content = re.sub(r'\s*\n\s*', '\n', content)
        return content.strip()


def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, List[Path]]:
    """
    Synchronously process an arXiv paper and return its fragments.

    Markdown and HTML exports are best-effort: failures are logged and the
    corresponding path is simply missing from the returned list.
    Uses asyncio.run(), so it must not be called while an event loop is running
    in the current thread.

    Args:
        splitter: ArxivSplitter instance
        arxiv_id: ArXiv paper ID or URL

    Returns:
        tuple: (fragments, formatted_content, output_files)

    Raises:
        RuntimeError: if no fragments could be extracted; anything raised by
            splitter.process or the content formatter is re-raised.
    """
    try:
        from crazy_functions.doc_fns.tex_html_formatter import PaperHtmlFormatter

        output_files: List[Path] = []
        fragments = asyncio.run(splitter.process(arxiv_id))
        if not fragments:
            raise RuntimeError(f"No fragments were extracted from {arxiv_id}")

        file_save_path = splitter.root_dir / "arxiv_fragments"

        # Best-effort Markdown export; do not lose the parsed fragments if it fails.
        try:
            markdown_path = save_fragments_to_file(fragments, output_dir=file_save_path)
            output_files.append(markdown_path)
        except Exception as e:
            splitter.logger.warning(f"Failed to save Markdown fragments for {arxiv_id}: {e}")

        formatter = PaperContentFormatter()

        # Paper-level metadata is identical on every fragment; take it from the first.
        first = fragments[0]
        metadata = PaperMetadata(
            title=first.title,
            authors=first.authors,
            abstract=first.abstract,
            catalogs=first.catalogs,
            arxiv_id=first.arxiv_id
        )

        formatted_content = formatter.format(fragments, metadata)

        # Best-effort HTML export.
        try:
            html_formatter = PaperHtmlFormatter(fragments, file_save_path)
            output_files.append(html_formatter.save_html())
        except Exception as e:
            splitter.logger.warning(f"Failed to save HTML output for {arxiv_id}: {e}")

        return fragments, formatted_content, output_files

    except Exception as e:
        print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
        raise


def test_arxiv_splitter():
    """Manual end-to-end check of ArxivSplitter (needs network access or a local cache)."""
    test_cases = [
        {
            "arxiv_id": "2411.03663",
            "expected_title": "Large Language Models and Simple Scripts",
            "min_fragments": 10,
        },
        # {
        #     "arxiv_id": "1805.10988",
        #     "expected_title": "RAG vs Fine-tuning",
        #     "min_fragments": 15,
        # }
    ]

    splitter = ArxivSplitter(
        root_dir="private_upload/default_user"
    )

    for case in test_cases:
        print(f"\nTesting paper: {case['arxiv_id']}")
        try:
            fragments, formatted_content, output_files = process_arxiv_sync(splitter, case['arxiv_id'])

            for fragment in fragments:
                print(fragment.content)
                print(len(fragment.content))

            print(output_files)

            if len(fragments) < case["min_fragments"]:
                print(f"! Only {len(fragments)} fragments, expected at least {case['min_fragments']}")

        except Exception as e:
            print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
            raise


if __name__ == "__main__":
    test_arxiv_splitter()