# File: gpt_academic/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
# Last modified: 2024-12-01 23:26:02 +08:00
# Size: 837 lines, 30 KiB (Python)

import asyncio
import logging
import re
import tarfile
import time
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Set
import aiohttp
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
def save_fragments_to_file(fragments: "List[SectionFragment]", output_dir: Path) -> Path:
    """Save all fragments to a single structured Markdown file.

    The report contains an overview, the paper information taken from the
    first fragment (title / authors / abstract / section tree), a table of
    contents, every fragment grouped by its ``current_section``, and summary
    statistics. An empty ``fragments`` list produces a report with empty
    sections instead of raising.

    Args:
        fragments: Fragment objects to write. Each one is read through its
            ``title``, ``authors``, ``abstract``, ``catalogs``, ``arxiv_id``,
            ``current_section``, ``content`` and ``bibliography`` attributes.
        output_dir: Directory that receives the report (created if missing).

    Returns:
        Path: Path of the generated Markdown file, named
        ``paper_latex_content_<YYYYmmdd_HHMMSS>.md``.
    """
    from datetime import datetime

    now = datetime.now()
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    file_path = output_path / f"paper_latex_content_{now.strftime('%Y%m%d_%H%M%S')}.md"

    # Group fragments by section, keeping first-seen section order.
    sections = {}
    for fragment in fragments:
        sections.setdefault(fragment.current_section or "Uncategorized", []).append(fragment)

    # Paper-level information is identical on every fragment; read it from the first.
    first = fragments[0] if fragments else None

    with open(file_path, "w", encoding="utf-8") as out:
        # Document header
        out.write("# Document Fragments Analysis\n\n")
        out.write("## Overview\n")
        out.write(f"- Total Fragments: {len(fragments)}\n")
        out.write(f"- Generated Time: {now.strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Paper information, if available
        if first is not None and (first.title or first.abstract):
            out.write("\n## Paper Information\n")
            if first.title:
                out.write(f"### Title\n{first.title}\n")
            if first.authors:
                out.write(f"\n### Authors\n{first.authors}\n")
            if first.abstract:
                out.write(f"\n### Abstract\n{first.abstract}\n")

        # Section tree, wrapped in a fenced code block
        if first is not None and first.catalogs:
            out.write("\n## Section Tree\n")
            out.write("```\n")
            out.write(first.catalogs)
            out.write("\n```")

        # Table of contents
        out.write("\n## Table of Contents\n")
        for section, section_fragments in sections.items():
            clean_section = section.strip() or "Uncategorized"
            anchor = clean_section.lower().replace(' ', '-')
            out.write(f"- [{clean_section}](#{anchor}) "
                      f"({len(section_fragments)} fragments)\n")

        # Content, one sub-heading per section
        out.write("\n## Content\n")
        for section, section_fragments in sections.items():
            clean_section = section.strip() or "Uncategorized"
            out.write(f"\n### {clean_section}\n")
            for i, fragment in enumerate(section_fragments, 1):
                out.write(f"\n#### Fragment {i}\n")

                out.write("**Metadata:**\n")
                metadata = [
                    f"- Section: {fragment.current_section}",
                    f"- Length: {len(fragment.content)} chars",
                ]
                if fragment.arxiv_id:
                    metadata.append(f"- ArXiv ID: {fragment.arxiv_id}")
                out.write("\n".join(metadata) + "\n")

                out.write("\n**Content:**\n")
                out.write("\n")
                out.write(fragment.content)
                out.write("\n")

                if fragment.bibliography:
                    out.write("\n**Bibliography:**\n")
                    out.write("```bibtex\n")
                    out.write(fragment.bibliography)
                    out.write("\n```\n")

                # Separator between fragments, but not after the last one
                if i < len(section_fragments):
                    out.write("\n---\n")

        # Statistics
        out.write("\n## Statistics\n")
        lengths = [len(fragment.content) for fragment in fragments]
        out.write("\n### Length Distribution\n")
        if lengths:
            out.write(f"- Minimum: {min(lengths)} chars\n")
            out.write(f"- Maximum: {max(lengths)} chars\n")
            out.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
        else:
            # min()/max() and the average are undefined for an empty list.
            out.write("- No fragments available\n")

        out.write("\n### Section Distribution\n")
        for section, section_fragments in sections.items():
            percentage = (len(section_fragments) / len(fragments)) * 100
            out.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")

    print(f"Fragments saved to: {file_path}")
    return file_path
# Regex patterns for the supported citation commands. Every pattern captures
# the raw key list (e.g. "key1,key2") in group 1.
# Each LaTeX command may be followed by an optional "*" and by any number of
# optional "[...]" arguments: natbib and biblatex both allow a pre-note and a
# post-note, as in \citep[see][p.~5]{key}.
_CITATION_COMMANDS = (
    # basic \cite{}
    "cite",
    # natbib
    "citep",
    "citet",
    "citeauthor",
    "citeyear",
    "citealt",
    "citealp",
    # biblatex
    "textcite",
    "parencite",
    "autocite",
)
CITATION_PATTERNS = [
    r'\\' + command + r'\*?(?:\[[^\]]*\])*\{([^}]+)\}'
    for command in _CITATION_COMMANDS
] + [
    # custom [cite:...] format
    r'\[cite:([^\]]+)\]',
]
# Compile once at import time; the patterns are reused for every section.
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
class ArxivSplitter:
    """Download the LaTeX source of an arXiv paper and split it into section fragments.

    The usual entry point is :meth:`process`: it normalizes the paper ID, makes
    sure the source is available under ``root_dir``, parses every TeX file into
    a ``DocumentStructure`` and converts the leaf sections into
    ``SectionFragment`` objects.
    """

    def __init__(self,
                 root_dir: str = "gpt_log/arxiv_cache",
                 proxies: Optional[Dict[str, str]] = None,
                 cache_ttl: int = 7 * 24 * 60 * 60):
        """Initialize the splitter.

        Args:
            root_dir: Cache root directory; created if it does not exist.
            proxies: Proxy settings. The ``'http'`` entry is used for the
                direct-download fallback; the whole mapping is handed to
                ``ArxivDownloader``.
            cache_ttl: Cache lifetime in seconds (default: seven days).
        """
        self.root_dir = Path(root_dir)
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.proxies = proxies or {}
        self.cache_ttl = cache_ttl

        # Worker count derived from the CPU count, capped to avoid excessive concurrency.
        import multiprocessing
        cpu_count = multiprocessing.cpu_count()
        self.document_structure = DocumentStructure()
        self.document_parser = EssayStructureParser()
        self.max_workers = min(32, cpu_count * 2)

        # TeX helper used to read files and resolve includes / references
        self.tex_processor = TexUtils()

        self._setup_logging()

    def _setup_logging(self):
        """Configure logging and create ``self.logger``.

        ``logging.basicConfig`` only has an effect when the root logger has no
        handlers yet, so an application-level logging setup is left untouched.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _normalize_arxiv_id(self, input_str: str) -> str:
        """Return the bare arXiv ID for an ID or an arxiv.org URL.

        Handles ``abs`` / ``pdf`` / ``src`` / ``e-print`` URLs, an ``arXiv:``
        prefix, a trailing ``.pdf`` and a trailing version suffix (``v2``).
        Only a *trailing* ``v<digits>`` is removed, so IDs that merely contain
        the letter "v" (e.g. ``solv-int/9901001``) are preserved.

        Args:
            input_str: arXiv ID or URL.

        Returns:
            The ID without version suffix, e.g. ``"2411.03663"``.
        """
        candidate = input_str.strip()
        if 'arxiv.org/' in candidate.lower():
            # Drop any query string or fragment before looking at the path.
            candidate = re.split(r'[?#]', candidate, maxsplit=1)[0]
            url_match = re.search(r'arxiv\.org/(?:abs|pdf|src|e-print)/(.+)$',
                                  candidate, flags=re.IGNORECASE)
            if url_match:
                candidate = url_match.group(1)
        candidate = candidate.strip().rstrip('/')
        candidate = re.sub(r'^arxiv:\s*', '', candidate, flags=re.IGNORECASE)
        candidate = re.sub(r'\.pdf$', '', candidate, flags=re.IGNORECASE)
        candidate = re.sub(r'v\d+$', '', candidate)
        return candidate.strip()

    def _check_cache(self, paper_dir: Path) -> bool:
        """Check whether the cached paper directory is usable.

        The cache is valid when the directory exists, contains at least one
        ``.tex`` file with a ``\\documentclass`` declaration, and the
        directory's modification time is younger than ``cache_ttl``.

        Args:
            paper_dir: Paper directory path.

        Returns:
            bool: True if the cache is valid, False otherwise.
        """
        if not paper_dir.exists():
            return False

        has_main_tex = False
        for file_path in paper_dir.rglob("*"):
            if file_path.suffix != '.tex':
                continue
            content = self.tex_processor.read_file(str(file_path))
            if content and r'\documentclass' in content:
                has_main_tex = True
                break
        if not has_main_tex:
            return False

        cache_age = time.time() - paper_dir.stat().st_mtime
        if cache_age < self.cache_ttl:
            self.logger.info(f"Using valid cache for {paper_dir.name}")
            return True
        return False

    async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
        """Download and unpack the paper source.

        ``ArxivDownloader`` is tried first. If it fails or returns nothing, the
        source archive is fetched directly from arxiv.org (up to three attempts
        per URL with a linearly growing delay), validated and extracted into
        ``paper_dir``.

        Args:
            arxiv_id: arXiv paper ID.
            paper_dir: Target directory.

        Returns:
            bool: True when the source was downloaded, False otherwise.
        """
        from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader

        # Old-style IDs ("hep-th/9901001") contain "/", which must not leak
        # into the archive file name.
        safe_name = arxiv_id.replace('/', '_')
        temp_tar_path = paper_dir / f"{safe_name}_temp.tar.gz"
        final_tar_path = paper_dir / f"{safe_name}.tar.gz"

        paper_dir.mkdir(parents=True, exist_ok=True)

        # Preferred path: ArxivDownloader
        try:
            downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
            downloaded_dir = downloader.download_paper(arxiv_id)
            if downloaded_dir:
                self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
                return True
        except Exception as e:
            self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")

        # Fallback: direct download of the source archive
        urls = [
            f"https://arxiv.org/src/{arxiv_id}",
            f"https://arxiv.org/e-print/{arxiv_id}"
        ]
        max_retries = 3
        retry_delay = 1  # base delay in seconds

        for url in urls:
            for attempt in range(max_retries):
                try:
                    self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
                    async with aiohttp.ClientSession() as session:
                        async with session.get(url, proxy=self.proxies.get('http')) as response:
                            if response.status != 200:
                                raise RuntimeError(f"unexpected HTTP status {response.status}")
                            content = await response.read()

                    temp_tar_path.write_bytes(content)
                    try:
                        # Validation and extraction are blocking; run them off the event loop.
                        loop = asyncio.get_running_loop()
                        await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
                    except Exception as e:
                        self.logger.warning(f"Invalid tar file: {str(e)}")
                    else:
                        # replace() also succeeds when an older archive is already present.
                        temp_tar_path.replace(final_tar_path)
                        return True
                except Exception as e:
                    self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
                finally:
                    # Never leave a partial or invalid temp archive behind.
                    temp_tar_path.unlink(missing_ok=True)

                if attempt + 1 < max_retries:
                    await asyncio.sleep(retry_delay * (attempt + 1))  # linear back-off

        return False

    def _process_tar_file(self, tar_path: Path, extract_path: Path):
        """Validate a gzip-compressed tar archive and extract it (blocking).

        All member headers are read before anything is written, so a truncated
        or corrupt archive fails before extraction starts. The archive comes
        from the network, so members that would land outside ``extract_path``
        abort the extraction, and members that are neither regular files nor
        directories (links, devices, ...) are skipped.

        Args:
            tar_path: Path of the ``.tar.gz`` archive.
            extract_path: Directory to extract into.

        Raises:
            tarfile.TarError: If the archive is unreadable or contains an
                unsafe member path.
        """
        root = Path(extract_path).resolve()
        with tarfile.open(tar_path, 'r:gz') as tar:
            safe_members = []
            for member in tar.getmembers():
                if not (member.isfile() or member.isdir()):
                    self.logger.warning(f"Skipping non-regular archive member: {member.name}")
                    continue
                target = (root / member.name).resolve()
                if target != root and root not in target.parents:
                    raise tarfile.TarError(f"Unsafe path in archive: {member.name}")
                safe_members.append(member)

            # Use the stdlib "data" extraction filter where the running Python provides it.
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
            tar.extractall(path=root, members=safe_members, **extract_kwargs)

    def process_references(self, doc_structure: "DocumentStructure", ref_bib: str) -> "DocumentStructure":
        """Attach the cited bibliography entries to every section.

        Args:
            doc_structure: DocumentStructure object.
            ref_bib: BibTeX text containing the reference entries.

        Returns:
            A deep copy of ``doc_structure`` with ``bibliography`` filled in on
            sections that cite known keys. The original object is returned
            unchanged if processing fails.
        """
        try:
            # Work on a copy so the caller's structure is never modified.
            doc = deepcopy(doc_structure)

            ref_map = self._parse_references(ref_bib)
            if not ref_map:
                self.logger.warning("No valid references found in ref_bib")
                return doc

            self._process_section_references(doc.toc, ref_map)
            return doc
        except Exception as e:
            self.logger.error(f"Error processing references: {str(e)}")
            return doc_structure  # Return original if processing fails

    def _process_section_references(self, sections: "List[Section]", ref_map: Dict[str, str]) -> None:
        """Recursively set ``section.bibliography`` from the citations in its content.

        Args:
            sections: List of Section objects (modified in place).
            ref_map: Mapping of citation keys to full reference text.
        """
        for section in sections:
            if section.content:
                cited_refs = self.find_citations(section.content)
                if cited_refs:
                    full_refs = []
                    # Sorted so the generated bibliography is reproducible
                    # (set iteration order varies between runs).
                    for ref_key in sorted(cited_refs):
                        ref_text = ref_map.get(ref_key)
                        if ref_text:
                            full_refs.append(ref_text)
                        else:
                            self.logger.warning(f"Reference not found for citation key: {ref_key}")
                    if full_refs:
                        section.bibliography = "\n\n".join(full_refs)

            if section.subsections:
                self._process_section_references(section.subsections, ref_map)

    def _parse_references(self, ref_bib: str) -> Dict[str, str]:
        """Parse BibTeX text into a mapping of citation key -> full entry text.

        An entry starts at a line beginning with ``@`` and runs until the next
        such line. Lines are stripped and blank lines dropped. Text before the
        first entry, and entries without a key (``@comment{...}``,
        ``@string{...}``), are ignored rather than being merged into a
        neighbouring entry.

        Args:
            ref_bib: BibTeX text.

        Returns:
            Dict mapping citation keys to full reference text. Empty on error.
        """
        ref_map: Dict[str, str] = {}
        current_ref: List[str] = []
        current_key: Optional[str] = None

        try:
            for raw_line in ref_bib.split('\n'):
                line = raw_line.strip()
                if not line:
                    continue

                if line.startswith('@'):
                    # Store the entry collected so far, then start a new one.
                    if current_key and current_ref:
                        ref_map[current_key] = '\n'.join(current_ref)
                    key_match = re.search(r'{(.*?),', line)
                    # Reset the key even when none is found, so the following
                    # lines cannot overwrite the previous entry.
                    current_key = (key_match.group(1).strip() or None) if key_match else None
                    current_ref = [line]
                elif current_key is not None:
                    current_ref.append(line)

            # Store the last entry
            if current_key and current_ref:
                ref_map[current_key] = '\n'.join(current_ref)
        except Exception as e:
            self.logger.error(f"Error parsing references: {str(e)}")
        return ref_map

    @staticmethod
    def _clean_citation_key(key: str) -> str:
        """Strip whitespace and stray commas from a single citation key."""
        return key.strip().strip(',').strip()

    def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
        """Split a key group such as ``"a, b;c"`` into a set of clean keys.

        Keys may be separated by commas or semicolons; empty keys are dropped.
        Returns an empty set if the group cannot be processed.
        """
        try:
            cleaned_keys = (self._clean_citation_key(part) for part in re.split('[,;]', keys_str))
            return {key for key in cleaned_keys if key}
        except Exception as e:
            self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
            return set()

    def find_citations(self, content: str) -> Set[str]:
        r"""Find citation keys in text content in various formats.

        Args:
            content: Text content to search for citations.

        Returns:
            Set of unique citation keys.

        Examples:
            Supported formats include:
            - \cite{key1,key2}
            - \cite[p. 1]{key}
            - \citep{key}
            - \citet{key}
            - [cite:key1, key2]
            - and the other variants listed in ``CITATION_PATTERNS``
        """
        citations = set()
        if not content:
            return citations

        try:
            for pattern in COMPILED_PATTERNS:
                for match in pattern.finditer(content):
                    keys_str = match.group(1)
                    if keys_str:
                        citations.update(self._extract_keys_from_group(keys_str))
        except Exception as e:
            self.logger.error(f"Error finding citations: {str(e)}")

        # Drop keys that are obviously LaTeX residue rather than real keys.
        return {key for key in citations
                if key and not key.startswith(('\\', '{', '}', '[', ']'))}

    def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
        """Find citations together with their surrounding text.

        Args:
            content: Text content to search for citations.
            context_chars: Number of characters of context to include before
                and after each citation.

        Returns:
            Dict mapping each citation key to a list of context strings, one
            per occurrence.
        """
        contexts = {}
        if not content:
            return contexts

        try:
            for pattern in COMPILED_PATTERNS:
                for match in pattern.finditer(content):
                    # Clamp the context window to the content boundaries.
                    start = max(0, match.start() - context_chars)
                    end = min(len(content), match.end() + context_chars)
                    context = content[start:end]

                    for key in self._extract_keys_from_group(match.group(1)):
                        contexts.setdefault(key, []).append(context)
        except Exception as e:
            self.logger.error(f"Error finding citation contexts: {str(e)}")
        return contexts

    async def process(self, arxiv_id_or_url: str) -> "List[SectionFragment]":
        """Process an arXiv paper and convert it to a list of SectionFragments.

        Each fragment represents the smallest section unit.

        Args:
            arxiv_id_or_url: arXiv paper ID or URL.

        Returns:
            List[SectionFragment]: List of processed paper fragments.

        Raises:
            RuntimeError: If no main TeX file or no valid TeX files are found.
        """
        try:
            arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
            paper_dir = self.root_dir / arxiv_id

            # Download unless a complete, unexpired cache exists. Checking only
            # for the directory is not enough: a failed download leaves an
            # empty directory behind, which would block every later attempt.
            if not self._check_cache(paper_dir):
                self.logger.info(f"Downloading paper {arxiv_id}")
                downloaded = await self.download_paper(arxiv_id, paper_dir)
                if not downloaded:
                    self.logger.warning(
                        f"Download failed for {arxiv_id}; using any files already present in {paper_dir}")

            # Find main TeX file
            main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
            if not main_tex:
                raise RuntimeError(f"No main TeX file found in {paper_dir}")

            main_tex_content = read_tex_file(main_tex)

            # Get all related TeX files and references
            tex_files = self.tex_processor.resolve_includes(main_tex)
            ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
            if not tex_files:
                raise RuntimeError(f"No valid TeX files found for {arxiv_id}")

            # Reset document structure for new processing
            self.document_structure = DocumentStructure()

            # Extract author information from the main file
            author_extractor = LatexAuthorExtractor()
            authors = author_extractor.extract_authors(main_tex_content)
            self.document_structure.authors = authors

            # Parse each TeX file and merge it into the overall structure
            for file_path in tex_files:
                self.logger.info(f"Processing TeX file: {file_path}")
                tex_content = read_tex_file(file_path)
                if tex_content:
                    additional_doc = self.document_parser.parse(tex_content)
                    self.document_structure = self.document_structure.merge(additional_doc)

            # Process references if available
            if ref_bib:
                self.document_structure = self.process_references(self.document_structure, ref_bib)
                self.logger.info("Successfully processed references")
            else:
                self.logger.info("No references found to process")

            # Generate table of contents once and share it across fragments
            section_tree = self.document_structure.generate_toc_tree()

            return self._convert_to_fragments(
                doc_structure=self.document_structure,
                arxiv_id=arxiv_id,
                section_tree=section_tree
            )
        except Exception as e:
            self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
            raise

    def _convert_to_fragments(self,
                              doc_structure: "DocumentStructure",
                              arxiv_id: str,
                              section_tree: str) -> "List[SectionFragment]":
        """Convert a DocumentStructure to a list of SectionFragments.

        One fragment is created for each leaf section (a section without
        subsections) that has content or a bibliography; an extra "Abstract"
        fragment is put first when the document has an abstract. Fragments
        that fail :meth:`_validate_fragment` are dropped.

        NOTE(review): content attached directly to a section that also has
        subsections is not emitted as a fragment - confirm this is intended.

        Args:
            doc_structure: Source DocumentStructure.
            arxiv_id: arXiv paper ID.
            section_tree: Pre-generated table of contents tree.

        Returns:
            List[SectionFragment]: List of paper fragments.
        """
        fragments = []

        # Fields shared by every fragment of this paper.
        base_fragment_template = {
            'title': doc_structure.title,
            'authors': doc_structure.authors,
            'abstract': doc_structure.abstract,
            'catalogs': section_tree,
            'arxiv_id': arxiv_id
        }

        def get_leaf_sections(section: "Section", path: Optional[List[str]] = None) -> None:
            """Walk ``section`` recursively and append a fragment per leaf.

            Args:
                section: Current section being processed.
                path: Titles of the ancestor sections, outermost first.
            """
            current_path = (path or []) + [section.title]

            if section.subsections:
                for subsection in section.subsections:
                    get_leaf_sections(subsection, current_path)
                return

            # Leaf section: only worth a fragment if it carries something.
            if not (section.content or section.bibliography):
                return
            fragment = SectionFragment(
                **base_fragment_template,
                current_section="/".join(current_path),
                content=self._clean_content(section.content),
                bibliography=section.bibliography
            )
            if self._validate_fragment(fragment):
                fragments.append(fragment)

        for section in doc_structure.toc:
            get_leaf_sections(section)

        # The abstract becomes the first fragment, if present.
        if doc_structure.abstract:
            abstract_fragment = SectionFragment(
                **base_fragment_template,
                current_section="Abstract",
                content=self._clean_content(doc_structure.abstract)
            )
            if self._validate_fragment(abstract_fragment):
                fragments.insert(0, abstract_fragment)

        self.logger.info(f"Created {len(fragments)} fragments")
        return fragments

    def _validate_fragment(self, fragment: "SectionFragment") -> bool:
        """Check that a fragment has all required fields with meaningful content.

        A fragment is valid when ``title``, ``catalogs`` and
        ``current_section`` are non-blank and at least one of ``content`` /
        ``bibliography`` is non-blank. A field that is not a string (e.g.
        ``None``) makes the fragment invalid.

        Args:
            fragment: SectionFragment to validate.

        Returns:
            bool: True if the fragment is valid, False otherwise.
        """
        try:
            return all([
                fragment.title.strip(),
                fragment.catalogs.strip(),
                fragment.current_section.strip(),
                fragment.content.strip() or fragment.bibliography.strip()
            ])
        except AttributeError:
            return False

    def _clean_content(self, content: str) -> str:
        """Clean and normalize content text.

        Collapses whitespace, drops ``\\item`` markers, unwraps simple
        ``\\command{text}`` constructs to ``text`` and turns LaTeX line breaks
        (``\\\\``) into newlines.

        Args:
            content: Raw content text.

        Returns:
            str: Cleaned content text ("" for empty input).
        """
        if not content:
            return ""

        # Collapse every run of whitespace (including newlines) to one space
        content = re.sub(r'\s+', ' ', content)

        # Remove remaining LaTeX artifacts
        content = re.sub(r'\\item\s*', '', content)  # drop \item markers
        content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content)  # unwrap simple commands

        # Convert LaTeX line breaks to real newlines and tidy around them
        content = content.replace('\\\\', '\n')
        content = re.sub(r'\s*\n\s*', '\n', content)
        return content.strip()
def process_arxiv_sync(splitter: "ArxivSplitter", arxiv_id: str) -> "tuple[List[SectionFragment], str, List[Path]]":
    """Synchronously process an arXiv paper and export the resulting fragments.

    Runs ``splitter.process`` to completion with ``asyncio.run`` (so this
    function must not be called from inside a running event loop), then
    writes a Markdown report and an HTML export next to the cache. Both
    exports are best-effort: a failure is logged and the corresponding file is
    simply missing from the returned list.

    Args:
        splitter: ArxivSplitter instance.
        arxiv_id: arXiv paper ID (or URL).

    Returns:
        A tuple ``(fragments, formatted_content, output_files)``: the list of
        fragments, the text produced by ``PaperContentFormatter.format`` and
        the paths of the export files that were written successfully.

    Raises:
        IndexError: If the paper produced no fragments.
        Exception: Any error raised while processing the paper is reported on
            stdout and re-raised.
    """
    logger = logging.getLogger(__name__)
    try:
        from crazy_functions.doc_fns.tex_html_formatter import PaperHtmlFormatter

        fragments = asyncio.run(splitter.process(arxiv_id))
        if not fragments:
            # Same exception type the unguarded fragments[0] lookup below used
            # to raise, but with a message that explains what went wrong.
            raise IndexError(f"No fragments were produced for {arxiv_id}")

        output_files = []
        file_save_path = splitter.root_dir / "arxiv_fragments"

        # Markdown export (best-effort)
        try:
            md_output_path = save_fragments_to_file(
                fragments,
                output_dir=file_save_path
            )
            output_files.append(md_output_path)
        except Exception as e:
            logger.warning(f"Failed to save Markdown fragments for {arxiv_id}: {e}")

        # Paper-level metadata is identical on every fragment; use the first.
        first = fragments[0]
        metadata = PaperMetadata(
            title=first.title,
            authors=first.authors,
            abstract=first.abstract,
            catalogs=first.catalogs,
            arxiv_id=first.arxiv_id
        )
        formatter = PaperContentFormatter()
        formatted_content = formatter.format(fragments, metadata)

        # HTML export (best-effort)
        try:
            html_formatter = PaperHtmlFormatter(fragments, file_save_path)
            html_output_path = html_formatter.save_html()
            output_files.append(html_output_path)
        except Exception as e:
            logger.warning(f"Failed to save HTML output for {arxiv_id}: {e}")

        return fragments, formatted_content, output_files
    except Exception as e:
        print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
        raise
def test_arxiv_splitter():
    """Manual smoke test for the arXiv splitter.

    Processes each configured paper and prints every fragment's content, its
    length and the list of generated output files. The ``expected_title`` and
    ``min_fragments`` entries are descriptive only; nothing is asserted on
    them. Any failure is reported on stdout and re-raised.
    """
    splitter = ArxivSplitter(
        root_dir="private_upload/default_user"
    )

    # Papers to exercise
    cases = (
        {
            "arxiv_id": "2411.03663",
            "expected_title": "Large Language Models and Simple Scripts",
            "min_fragments": 10,
        },
    )

    for spec in cases:
        paper_id = spec["arxiv_id"]
        print(f"\nTesting paper: {paper_id}")
        try:
            fragments, _formatted_content, output_files = process_arxiv_sync(splitter, paper_id)

            # Dump each fragment followed by its length
            for piece in fragments:
                text = piece.content
                print(text)
                print(len(text))

            # Paths of the exported files
            print(output_files)
        except Exception as exc:
            print(f"✗ Test failed for {paper_id}: {str(exc)}")
            raise
# Script entry point: run the manual smoke test when executed directly.
if __name__ == "__main__":
    test_arxiv_splitter()