镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-08 07:26:48 +00:00
up
这个提交包含在:
@@ -140,11 +140,31 @@ class ArxivRagWorker:
|
|||||||
self.rag_worker.add_text_to_vector_store(overview_text)
|
self.rag_worker.add_text_to_vector_store(overview_text)
|
||||||
logger.info(f"Added paper overview for {overview['arxiv_id']}")
|
logger.info(f"Added paper overview for {overview['arxiv_id']}")
|
||||||
|
|
||||||
# 并行处理其余片段
|
# 创建线程池
|
||||||
tasks = []
|
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||||
for i, fragment in enumerate(fragments):
|
# 使用 asyncio.gather 收集所有任务
|
||||||
tasks.append(self._process_single_fragment(fragment, i))
|
loop = asyncio.get_event_loop()
|
||||||
await asyncio.gather(*tasks)
|
tasks = [
|
||||||
|
loop.run_in_executor(
|
||||||
|
executor,
|
||||||
|
self._process_single_fragment,
|
||||||
|
fragment,
|
||||||
|
i
|
||||||
|
)
|
||||||
|
for i, fragment in enumerate(fragments)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 处理结果和异常
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"Error processing fragment {i}: {result}")
|
||||||
|
else:
|
||||||
|
# 处理成功的结果
|
||||||
|
pass
|
||||||
|
|
||||||
logger.info(f"Processed {len(fragments)} fragments successfully")
|
logger.info(f"Processed {len(fragments)} fragments successfully")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,772 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
import tarfile
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from typing import Generator, List, Tuple, Optional, Dict, Set
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
|
||||||
|
|
||||||
|
|
||||||
|
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
|
||||||
|
"""
|
||||||
|
Save all fragments to a single structured markdown file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragments: List of SectionFragment objects
|
||||||
|
output_dir: Output directory path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path: Path to the generated markdown file
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
output_path = Path(output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate filename
|
||||||
|
filename = f"fragments_{timestamp}.md"
|
||||||
|
file_path = output_path / filename
|
||||||
|
|
||||||
|
# Group fragments by section
|
||||||
|
sections = {}
|
||||||
|
for fragment in fragments:
|
||||||
|
section = fragment.current_section or "Uncategorized"
|
||||||
|
if section not in sections:
|
||||||
|
sections[section] = []
|
||||||
|
sections[section].append(fragment)
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
# Write document header
|
||||||
|
f.write("# Document Fragments Analysis\n\n")
|
||||||
|
f.write("## Overview\n")
|
||||||
|
f.write(f"- Total Fragments: {len(fragments)}\n")
|
||||||
|
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||||
|
|
||||||
|
# Add paper information if available
|
||||||
|
if fragments and (fragments[0].title or fragments[0].abstract):
|
||||||
|
f.write("\n## Paper Information\n")
|
||||||
|
if fragments[0].title:
|
||||||
|
f.write(f"### Title\n{fragments[0].title}\n")
|
||||||
|
if fragments[0].abstract:
|
||||||
|
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
|
||||||
|
|
||||||
|
# Write section tree if available
|
||||||
|
if fragments and fragments[0].section_tree:
|
||||||
|
f.write("\n## Section Tree\n")
|
||||||
|
f.write(fragments[0].section_tree)
|
||||||
|
|
||||||
|
# Generate table of contents
|
||||||
|
f.write("\n## Table of Contents\n")
|
||||||
|
for section in sections:
|
||||||
|
clean_section = section.strip() or "Uncategorized"
|
||||||
|
fragment_count = len(sections[section])
|
||||||
|
f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) "
|
||||||
|
f"({fragment_count} fragments)\n")
|
||||||
|
|
||||||
|
# Write content sections
|
||||||
|
f.write("\n## Content\n")
|
||||||
|
for section, section_fragments in sections.items():
|
||||||
|
clean_section = section.strip() or "Uncategorized"
|
||||||
|
f.write(f"\n### {clean_section}\n")
|
||||||
|
|
||||||
|
# Write each fragment
|
||||||
|
for i, fragment in enumerate(section_fragments, 1):
|
||||||
|
f.write(f"\n#### Fragment {i}\n")
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
f.write("**Metadata:**\n")
|
||||||
|
metadata = [
|
||||||
|
f"- Section: {fragment.current_section}",
|
||||||
|
f"- Length: {len(fragment.content)} chars",
|
||||||
|
f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
|
||||||
|
]
|
||||||
|
f.write("\n".join(filter(None, metadata)) + "\n")
|
||||||
|
|
||||||
|
# Content
|
||||||
|
f.write("\n**Content:**\n")
|
||||||
|
f.write("```tex\n")
|
||||||
|
f.write(fragment.content)
|
||||||
|
f.write("\n```\n")
|
||||||
|
|
||||||
|
# Bibliography if exists
|
||||||
|
if fragment.bibliography:
|
||||||
|
f.write("\n**Bibliography:**\n")
|
||||||
|
f.write("```bibtex\n")
|
||||||
|
f.write(fragment.bibliography)
|
||||||
|
f.write("\n```\n")
|
||||||
|
|
||||||
|
# Add separator
|
||||||
|
if i < len(section_fragments):
|
||||||
|
f.write("\n---\n")
|
||||||
|
|
||||||
|
# Add statistics
|
||||||
|
f.write("\n## Statistics\n")
|
||||||
|
|
||||||
|
# Length distribution
|
||||||
|
lengths = [len(f.content) for f in fragments]
|
||||||
|
f.write("\n### Length Distribution\n")
|
||||||
|
f.write(f"- Minimum: {min(lengths)} chars\n")
|
||||||
|
f.write(f"- Maximum: {max(lengths)} chars\n")
|
||||||
|
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
|
||||||
|
|
||||||
|
# Section distribution
|
||||||
|
f.write("\n### Section Distribution\n")
|
||||||
|
for section, section_fragments in sections.items():
|
||||||
|
percentage = (len(section_fragments) / len(fragments)) * 100
|
||||||
|
f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")
|
||||||
|
|
||||||
|
print(f"Fragments saved to: {file_path}")
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
# 定义各种引用命令的模式
|
||||||
|
CITATION_PATTERNS = [
|
||||||
|
# 基本的 \cite{} 格式
|
||||||
|
r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# natbib 格式
|
||||||
|
r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# biblatex 格式
|
||||||
|
r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
|
||||||
|
# 自定义 [cite:...] 格式
|
||||||
|
r'\[cite:([^\]]+)\]',
|
||||||
|
]
|
||||||
|
|
||||||
|
# 编译所有模式
|
||||||
|
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivSplitter:
|
||||||
|
"""Arxiv论文智能分割器"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
root_dir: str = "gpt_log/arxiv_cache",
|
||||||
|
proxies: Optional[Dict[str, str]] = None,
|
||||||
|
cache_ttl: int = 7 * 24 * 60 * 60):
|
||||||
|
"""
|
||||||
|
初始化分割器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_range: 字符数范围(最小值, 最大值)
|
||||||
|
root_dir: 缓存根目录
|
||||||
|
proxies: 代理设置
|
||||||
|
cache_ttl: 缓存过期时间(秒)
|
||||||
|
"""
|
||||||
|
self.root_dir = Path(root_dir)
|
||||||
|
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.proxies = proxies or {}
|
||||||
|
self.cache_ttl = cache_ttl
|
||||||
|
|
||||||
|
# 动态计算最优线程数
|
||||||
|
import multiprocessing
|
||||||
|
cpu_count = multiprocessing.cpu_count()
|
||||||
|
# 根据CPU核心数动态设置,但设置上限防止过度并发
|
||||||
|
self.document_structure = DocumentStructure()
|
||||||
|
self.document_parser = EssayStructureParser()
|
||||||
|
|
||||||
|
self.max_workers = min(32, cpu_count * 2)
|
||||||
|
|
||||||
|
# 初始化TeX处理器
|
||||||
|
self.tex_processor = TexUtils()
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
self._setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_logging(self):
|
||||||
|
"""配置日志"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||||
|
"""规范化ArXiv ID"""
|
||||||
|
if 'arxiv.org/' in input_str.lower():
|
||||||
|
# 处理URL格式
|
||||||
|
if '/pdf/' in input_str:
|
||||||
|
arxiv_id = input_str.split('/pdf/')[-1]
|
||||||
|
else:
|
||||||
|
arxiv_id = input_str.split('/abs/')[-1]
|
||||||
|
# 移除版本号和其他后缀
|
||||||
|
return arxiv_id.split('v')[0].strip()
|
||||||
|
return input_str.split('v')[0].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _check_cache(self, paper_dir: Path) -> bool:
|
||||||
|
"""
|
||||||
|
检查缓存是否有效,包括文件完整性检查
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_dir: 论文目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果缓存有效返回True,否则返回False
|
||||||
|
"""
|
||||||
|
if not paper_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查目录中是否存在必要文件
|
||||||
|
has_tex_files = False
|
||||||
|
has_main_tex = False
|
||||||
|
|
||||||
|
for file_path in paper_dir.rglob("*"):
|
||||||
|
if file_path.suffix == '.tex':
|
||||||
|
has_tex_files = True
|
||||||
|
content = self.tex_processor.read_file(str(file_path))
|
||||||
|
if content and r'\documentclass' in content:
|
||||||
|
has_main_tex = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not (has_tex_files and has_main_tex):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查缓存时间
|
||||||
|
cache_time = paper_dir.stat().st_mtime
|
||||||
|
if (time.time() - cache_time) < self.cache_ttl:
|
||||||
|
self.logger.info(f"Using valid cache for {paper_dir.name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
|
||||||
|
"""
|
||||||
|
异步下载论文,包含重试机制和临时文件处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id: ArXiv论文ID
|
||||||
|
paper_dir: 目标目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 下载成功返回True,否则返回False
|
||||||
|
"""
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
|
||||||
|
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
|
||||||
|
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
paper_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 尝试使用 ArxivDownloader 下载
|
||||||
|
try:
|
||||||
|
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
|
||||||
|
downloaded_dir = downloader.download_paper(arxiv_id)
|
||||||
|
if downloaded_dir:
|
||||||
|
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
|
||||||
|
|
||||||
|
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
|
||||||
|
urls = [
|
||||||
|
f"https://arxiv.org/src/{arxiv_id}",
|
||||||
|
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
retry_delay = 1 # 初始重试延迟(秒)
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, proxy=self.proxies.get('http')) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# 写入临时文件
|
||||||
|
temp_tar_path.write_bytes(content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证tar文件完整性并解压
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
|
||||||
|
|
||||||
|
# 下载成功后移动临时文件到最终位置
|
||||||
|
temp_tar_path.rename(final_tar_path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Invalid tar file: {str(e)}")
|
||||||
|
if temp_tar_path.exists():
|
||||||
|
temp_tar_path.unlink()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
continue
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _process_tar_file(self, tar_path: Path, extract_path: Path):
|
||||||
|
"""处理tar文件的同步操作"""
|
||||||
|
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||||
|
tar.testall() # 验证文件完整性
|
||||||
|
tar.extractall(path=extract_path) # 解压文件
|
||||||
|
|
||||||
|
def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
|
||||||
|
"""
|
||||||
|
Process citations in document structure and add referenced literature for each section
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_structure: DocumentStructure object
|
||||||
|
ref_bib: String containing references separated by newlines
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated DocumentStructure object
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create a copy to avoid modifying the original
|
||||||
|
doc = deepcopy(doc_structure)
|
||||||
|
|
||||||
|
# Parse references into a mapping
|
||||||
|
ref_map = self._parse_references(ref_bib)
|
||||||
|
if not ref_map:
|
||||||
|
self.logger.warning("No valid references found in ref_bib")
|
||||||
|
return doc
|
||||||
|
|
||||||
|
# Process all sections recursively
|
||||||
|
self._process_section_references(doc.toc, ref_map)
|
||||||
|
|
||||||
|
return doc
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing references: {str(e)}")
|
||||||
|
return doc_structure # Return original if processing fails
|
||||||
|
|
||||||
|
def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Recursively process sections to add references
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sections: List of Section objects
|
||||||
|
ref_map: Mapping of citation keys to full references
|
||||||
|
"""
|
||||||
|
for section in sections:
|
||||||
|
if section.content:
|
||||||
|
# Find citations in current section
|
||||||
|
cited_refs = self.find_citations(section.content)
|
||||||
|
|
||||||
|
if cited_refs:
|
||||||
|
# Get full references for citations
|
||||||
|
full_refs = []
|
||||||
|
for ref_key in cited_refs:
|
||||||
|
ref_text = ref_map.get(ref_key)
|
||||||
|
if ref_text:
|
||||||
|
full_refs.append(ref_text)
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Reference not found for citation key: {ref_key}")
|
||||||
|
|
||||||
|
# Add references to section content
|
||||||
|
if full_refs:
|
||||||
|
section.bibliography = "\n\n".join(full_refs)
|
||||||
|
|
||||||
|
# Process subsections recursively
|
||||||
|
if section.subsections:
|
||||||
|
self._process_section_references(section.subsections, ref_map)
|
||||||
|
|
||||||
|
def _parse_references(self, ref_bib: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parse reference string into a mapping of citation keys to full references
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ref_bib: Reference string with references separated by newlines
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping citation keys to full reference text
|
||||||
|
"""
|
||||||
|
ref_map = {}
|
||||||
|
current_ref = []
|
||||||
|
current_key = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for line in ref_bib.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# New reference entry
|
||||||
|
if line.startswith('@'):
|
||||||
|
# Save previous reference if exists
|
||||||
|
if current_key and current_ref:
|
||||||
|
ref_map[current_key] = '\n'.join(current_ref)
|
||||||
|
current_ref = []
|
||||||
|
|
||||||
|
# Extract key from new reference
|
||||||
|
key_match = re.search(r'{(.*?),', line)
|
||||||
|
if key_match:
|
||||||
|
current_key = key_match.group(1)
|
||||||
|
current_ref.append(line)
|
||||||
|
else:
|
||||||
|
if current_ref is not None:
|
||||||
|
current_ref.append(line)
|
||||||
|
|
||||||
|
# Save last reference
|
||||||
|
if current_key and current_ref:
|
||||||
|
ref_map[current_key] = '\n'.join(current_ref)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error parsing references: {str(e)}")
|
||||||
|
|
||||||
|
return ref_map
|
||||||
|
|
||||||
|
# 编译一次正则表达式以提高效率
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clean_citation_key(key: str) -> str:
|
||||||
|
"""Clean individual citation key."""
|
||||||
|
return key.strip().strip(',').strip()
|
||||||
|
|
||||||
|
def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
|
||||||
|
"""Extract and clean individual citation keys from a group."""
|
||||||
|
try:
|
||||||
|
# 分割多个引用键(支持逗号和分号分隔)
|
||||||
|
separators = '[,;]'
|
||||||
|
keys = re.split(separators, keys_str)
|
||||||
|
# 清理并过滤空键
|
||||||
|
return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)}
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
def find_citations(self, content: str) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Find citation keys in text content in various formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text content to search for citations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of unique citation keys
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Supported formats include:
|
||||||
|
- \cite{key1,key2}
|
||||||
|
- \cite[p. 1]{key}
|
||||||
|
- \citep{key}
|
||||||
|
- \citet{key}
|
||||||
|
- [cite:key1, key2]
|
||||||
|
- And many other variants
|
||||||
|
"""
|
||||||
|
citations = set()
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return citations
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 对每个编译好的模式进行搜索
|
||||||
|
for pattern in COMPILED_PATTERNS:
|
||||||
|
matches = pattern.finditer(content)
|
||||||
|
for match in matches:
|
||||||
|
# 获取捕获组中的引用键
|
||||||
|
keys_str = match.group(1)
|
||||||
|
if keys_str:
|
||||||
|
# 提取并添加所有引用键
|
||||||
|
new_keys = self._extract_keys_from_group(keys_str)
|
||||||
|
citations.update(new_keys)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error finding citations: {str(e)}")
|
||||||
|
|
||||||
|
# 移除明显无效的键
|
||||||
|
citations = {key for key in citations
|
||||||
|
if key and not key.startswith(('\\', '{', '}', '[', ']'))}
|
||||||
|
|
||||||
|
return citations
|
||||||
|
|
||||||
|
def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
|
||||||
|
"""
|
||||||
|
Find citations and their surrounding context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text content to search for citations
|
||||||
|
context_chars: Number of characters of context to include before/after
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping citation keys to lists of context strings
|
||||||
|
"""
|
||||||
|
contexts = {}
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
try:
|
||||||
|
for pattern in COMPILED_PATTERNS:
|
||||||
|
matches = pattern.finditer(content)
|
||||||
|
for match in matches:
|
||||||
|
# 获取匹配的位置
|
||||||
|
start = max(0, match.start() - context_chars)
|
||||||
|
end = min(len(content), match.end() + context_chars)
|
||||||
|
|
||||||
|
# 获取上下文
|
||||||
|
context = content[start:end]
|
||||||
|
|
||||||
|
# 获取并处理引用键
|
||||||
|
keys_str = match.group(1)
|
||||||
|
keys = self._extract_keys_from_group(keys_str)
|
||||||
|
|
||||||
|
# 为每个键添加上下文
|
||||||
|
for key in keys:
|
||||||
|
if key not in contexts:
|
||||||
|
contexts[key] = []
|
||||||
|
contexts[key].append(context)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error finding citation contexts: {str(e)}")
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
|
||||||
|
"""
|
||||||
|
Process ArXiv paper and convert to list of SectionFragments.
|
||||||
|
Each fragment represents the smallest section unit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arxiv_id_or_url: ArXiv paper ID or URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[SectionFragment]: List of processed paper fragments
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||||
|
paper_dir = self.root_dir / arxiv_id
|
||||||
|
|
||||||
|
# Check if paper directory exists, if not, try to download
|
||||||
|
if not paper_dir.exists():
|
||||||
|
self.logger.info(f"Downloading paper {arxiv_id}")
|
||||||
|
await self.download_paper(arxiv_id, paper_dir)
|
||||||
|
|
||||||
|
# Find main TeX file
|
||||||
|
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
|
||||||
|
if not main_tex:
|
||||||
|
raise RuntimeError(f"No main TeX file found in {paper_dir}")
|
||||||
|
|
||||||
|
# Get all related TeX files and references
|
||||||
|
tex_files = self.tex_processor.resolve_includes(main_tex)
|
||||||
|
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
|
||||||
|
|
||||||
|
if not tex_files:
|
||||||
|
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
|
||||||
|
|
||||||
|
# Reset document structure for new processing
|
||||||
|
self.document_structure = DocumentStructure()
|
||||||
|
|
||||||
|
# Process each TeX file
|
||||||
|
for file_path in tex_files:
|
||||||
|
self.logger.info(f"Processing TeX file: {file_path}")
|
||||||
|
tex_content = read_tex_file(file_path)
|
||||||
|
if tex_content:
|
||||||
|
additional_doc = self.document_parser.parse(tex_content)
|
||||||
|
self.document_structure = self.document_structure.merge(additional_doc)
|
||||||
|
|
||||||
|
# Process references if available
|
||||||
|
if ref_bib:
|
||||||
|
self.document_structure = self.process_references(self.document_structure, ref_bib)
|
||||||
|
self.logger.info("Successfully processed references")
|
||||||
|
else:
|
||||||
|
self.logger.info("No references found to process")
|
||||||
|
|
||||||
|
# Generate table of contents once
|
||||||
|
section_tree = self.document_structure.generate_toc_tree()
|
||||||
|
|
||||||
|
# Convert DocumentStructure to SectionFragments
|
||||||
|
fragments = self._convert_to_fragments(
|
||||||
|
doc_structure=self.document_structure,
|
||||||
|
arxiv_id=arxiv_id,
|
||||||
|
section_tree=section_tree
|
||||||
|
)
|
||||||
|
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _convert_to_fragments(self,
|
||||||
|
doc_structure: DocumentStructure,
|
||||||
|
arxiv_id: str,
|
||||||
|
section_tree: str) -> List[SectionFragment]:
|
||||||
|
"""
|
||||||
|
Convert DocumentStructure to list of SectionFragments.
|
||||||
|
Creates a fragment for each leaf section in the document hierarchy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_structure: Source DocumentStructure
|
||||||
|
arxiv_id: ArXiv paper ID
|
||||||
|
section_tree: Pre-generated table of contents tree
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[SectionFragment]: List of paper fragments
|
||||||
|
"""
|
||||||
|
fragments = []
|
||||||
|
|
||||||
|
# Create a base template for all fragments to avoid repetitive assignments
|
||||||
|
base_fragment_template = {
|
||||||
|
'title': doc_structure.title,
|
||||||
|
'abstract': doc_structure.abstract,
|
||||||
|
'section_tree': section_tree,
|
||||||
|
'arxiv_id': arxiv_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_leaf_sections(section: Section, path: List[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Recursively find all leaf sections and create fragments.
|
||||||
|
A leaf section is one that has content but no subsections, or has neither.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
section: Current section being processed
|
||||||
|
path: List of section titles forming the path to current section
|
||||||
|
"""
|
||||||
|
if path is None:
|
||||||
|
path = []
|
||||||
|
|
||||||
|
current_path = path + [section.title]
|
||||||
|
|
||||||
|
if not section.subsections:
|
||||||
|
# This is a leaf section, create a fragment if it has content
|
||||||
|
if section.content or section.bibliography:
|
||||||
|
fragment = SectionFragment(
|
||||||
|
**base_fragment_template,
|
||||||
|
current_section="/".join(current_path),
|
||||||
|
content=self._clean_content(section.content),
|
||||||
|
bibliography=section.bibliography
|
||||||
|
)
|
||||||
|
if self._validate_fragment(fragment):
|
||||||
|
fragments.append(fragment)
|
||||||
|
else:
|
||||||
|
# Process each subsection
|
||||||
|
for subsection in section.subsections:
|
||||||
|
get_leaf_sections(subsection, current_path)
|
||||||
|
|
||||||
|
# Process all top-level sections
|
||||||
|
for section in doc_structure.toc:
|
||||||
|
get_leaf_sections(section)
|
||||||
|
|
||||||
|
# Add a fragment for the abstract if it exists
|
||||||
|
if doc_structure.abstract:
|
||||||
|
abstract_fragment = SectionFragment(
|
||||||
|
**base_fragment_template,
|
||||||
|
current_section="Abstract",
|
||||||
|
content=self._clean_content(doc_structure.abstract)
|
||||||
|
)
|
||||||
|
if self._validate_fragment(abstract_fragment):
|
||||||
|
fragments.insert(0, abstract_fragment)
|
||||||
|
|
||||||
|
self.logger.info(f"Created {len(fragments)} fragments")
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
def _validate_fragment(self, fragment: SectionFragment) -> bool:
|
||||||
|
"""
|
||||||
|
Validate if the fragment has all required fields with meaningful content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fragment: SectionFragment to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if fragment is valid, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return all([
|
||||||
|
fragment.title.strip(),
|
||||||
|
fragment.section_tree.strip(),
|
||||||
|
fragment.current_section.strip(),
|
||||||
|
fragment.content.strip() or fragment.bibliography.strip()
|
||||||
|
])
|
||||||
|
except AttributeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clean_content(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean and normalize content text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw content text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Cleaned content text
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove excessive whitespace
|
||||||
|
content = re.sub(r'\s+', ' ', content)
|
||||||
|
|
||||||
|
# Remove remaining LaTeX artifacts
|
||||||
|
content = re.sub(r'\\item\s*', '• ', content) # Convert \item to bullet points
|
||||||
|
content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands
|
||||||
|
|
||||||
|
# Clean special characters
|
||||||
|
content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines
|
||||||
|
content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines
|
||||||
|
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_arxiv_splitter():
|
||||||
|
"""测试ArXiv分割器的功能"""
|
||||||
|
|
||||||
|
# 测试配置
|
||||||
|
test_cases = [
|
||||||
|
{
|
||||||
|
"arxiv_id": "2411.03663",
|
||||||
|
"expected_title": "Large Language Models and Simple Scripts",
|
||||||
|
"min_fragments": 10,
|
||||||
|
},
|
||||||
|
# {
|
||||||
|
# "arxiv_id": "1805.10988",
|
||||||
|
# "expected_title": "RAG vs Fine-tuning",
|
||||||
|
# "min_fragments": 15,
|
||||||
|
# }
|
||||||
|
]
|
||||||
|
|
||||||
|
# 创建分割器实例
|
||||||
|
splitter = ArxivSplitter(
|
||||||
|
root_dir="test_cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
print(f"\nTesting paper: {case['arxiv_id']}")
|
||||||
|
try:
|
||||||
|
fragments = await splitter.process(case['arxiv_id'])
|
||||||
|
|
||||||
|
# 保存fragments
|
||||||
|
output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
|
||||||
|
print(f"Output saved to: {output_dir}")
|
||||||
|
# 内容检查
|
||||||
|
for fragment in fragments:
|
||||||
|
# 长度检查
|
||||||
|
|
||||||
|
print((fragment.content))
|
||||||
|
print(len(fragment.content))
|
||||||
|
# 类型检查
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_arxiv_splitter())
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Set, Dict, Pattern, Optional
|
from typing import Set, Dict, Pattern, Optional, List, Tuple
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
@@ -8,179 +8,259 @@ from functools import lru_cache
|
|||||||
|
|
||||||
class EnvType(Enum):
|
class EnvType(Enum):
|
||||||
"""Environment classification types."""
|
"""Environment classification types."""
|
||||||
PRESERVE = "preserve"
|
PRESERVE = "preserve" # Preserve complete environment including commands
|
||||||
REMOVE = "remove"
|
REMOVE = "remove" # Remove environment completely
|
||||||
EXTRACT = "extract"
|
EXTRACT = "extract" # Extract and clean content
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LatexConfig:
|
class LatexConfig:
|
||||||
"""Configuration for LaTeX processing."""
|
"""Configuration for LaTeX processing."""
|
||||||
preserve_envs: Set[str] = field(default_factory=lambda: {
|
preserve_envs: Set[str] = field(default_factory=lambda: {
|
||||||
# Math environments
|
# Math environments - preserve complete content
|
||||||
'equation', 'equation*', 'align', 'align*', 'displaymath',
|
'equation', 'equation*', 'align', 'align*', 'displaymath',
|
||||||
'math', 'eqnarray', 'gather', 'gather*', 'multline', 'multline*',
|
'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*',
|
||||||
# Tables and figures
|
'multline', 'multline*', 'flalign', 'flalign*',
|
||||||
|
'alignat', 'alignat*', 'cases', 'split', 'aligned',
|
||||||
|
# Tables and figures - preserve structure and content
|
||||||
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
|
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
|
||||||
'figure', 'figure*', 'subfigure',
|
'figure', 'figure*', 'subfigure', 'wrapfigure',
|
||||||
|
'minipage', 'tabbing', 'verbatim', 'longtable',
|
||||||
|
'sidewaystable', 'sidewaysfigure', 'floatrow',
|
||||||
|
# Arrays and matrices
|
||||||
|
'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix',
|
||||||
|
'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*',
|
||||||
# Algorithms and code
|
# Algorithms and code
|
||||||
'algorithm', 'algorithmic', 'lstlisting',
|
'algorithm', 'algorithmic', 'lstlisting', 'verbatim',
|
||||||
|
'minted', 'listing', 'algorithmic*', 'algorithm2e',
|
||||||
# Theorems and proofs
|
# Theorems and proofs
|
||||||
'theorem', 'proof', 'definition', 'lemma', 'corollary',
|
'theorem', 'proof', 'definition', 'lemma', 'corollary',
|
||||||
'proposition', 'example', 'remark'
|
'proposition', 'example', 'remark', 'note', 'claim',
|
||||||
|
'axiom', 'property', 'assumption', 'conjecture', 'observation',
|
||||||
|
# Bibliography
|
||||||
|
'thebibliography', 'bibliography', 'references'
|
||||||
|
})
|
||||||
|
|
||||||
|
# 引用类命令的特殊处理配置
|
||||||
|
citation_commands: Set[str] = field(default_factory=lambda: {
|
||||||
|
# Basic citations
|
||||||
|
'cite', 'citep', 'citet', 'citeyear', 'citeauthor',
|
||||||
|
'citeyearpar', 'citetext', 'citenum',
|
||||||
|
# Natbib citations
|
||||||
|
'citefullauthor', 'citealp', 'citealt', 'citename',
|
||||||
|
'citepalias', 'citetalias', 'citetext',
|
||||||
|
# Cross-references
|
||||||
|
'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref',
|
||||||
|
'Cref', 'vref', 'Vref', 'fref', 'pref',
|
||||||
|
# Hyperref
|
||||||
|
'hyperref', 'href', 'url',
|
||||||
|
# Labels
|
||||||
|
'label', 'tag'
|
||||||
})
|
})
|
||||||
|
|
||||||
preserve_commands: Set[str] = field(default_factory=lambda: {
|
preserve_commands: Set[str] = field(default_factory=lambda: {
|
||||||
# Citations and references
|
|
||||||
'caption', 'label', 'ref', 'cite', 'citep', 'citet', 'eqref',
|
|
||||||
# Text formatting
|
# Text formatting
|
||||||
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
|
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
|
||||||
'section', 'subsection', 'subsubsection', 'paragraph',
|
'section', 'subsection', 'subsubsection', 'paragraph', 'part',
|
||||||
# Math operators
|
'chapter', 'title', 'author', 'date', 'thanks',
|
||||||
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf'
|
# Math operators and symbols
|
||||||
|
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf',
|
||||||
|
'partial', 'nabla', 'implies', 'iff', 'therefore',
|
||||||
|
'exists', 'forall', 'in', 'subset', 'subseteq',
|
||||||
|
# Greek letters and math symbols
|
||||||
|
'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
|
||||||
|
'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu',
|
||||||
|
'nu', 'xi', 'pi', 'rho', 'sigma', 'tau',
|
||||||
|
'upsilon', 'phi', 'chi', 'psi', 'omega',
|
||||||
|
'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
|
||||||
|
'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
|
||||||
|
# Math commands
|
||||||
|
'left', 'right', 'big', 'Big', 'bigg', 'Bigg',
|
||||||
|
'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb',
|
||||||
|
'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop',
|
||||||
|
'operatorname', 'overline', 'underline', 'overbrace',
|
||||||
|
'underbrace', 'overset', 'underset', 'stackrel',
|
||||||
|
# Spacing and alignment
|
||||||
|
'quad', 'qquad', 'hspace', 'vspace', 'medskip',
|
||||||
|
'bigskip', 'smallskip', 'hfill', 'vfill', 'centering',
|
||||||
|
'raggedright', 'raggedleft'
|
||||||
})
|
})
|
||||||
|
|
||||||
remove_commands: Set[str] = field(default_factory=lambda: {
|
remove_commands: Set[str] = field(default_factory=lambda: {
|
||||||
# Document setup
|
# Document setup
|
||||||
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
|
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
|
||||||
'bibliography', 'bibliographystyle', 'frontmatter', 'mainmatter',
|
'bibliographystyle', 'frontmatter', 'mainmatter',
|
||||||
'newtheorem', 'theoremstyle', 'proof', 'proofname', 'qed',
|
'newtheorem', 'theoremstyle', 'proofname',
|
||||||
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
|
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
|
||||||
'newenvironment',
|
'newenvironment',
|
||||||
# Layout and spacing
|
# Layout and spacing
|
||||||
'pagestyle', 'thispagestyle', 'vspace', 'hspace', 'vfill', 'hfill',
|
'pagestyle', 'thispagestyle', 'newpage', 'clearpage',
|
||||||
'newpage', 'clearpage', 'pagebreak', 'linebreak', 'newline',
|
'pagebreak', 'linebreak', 'newline', 'setlength',
|
||||||
'setlength', 'setcounter', 'addtocounter', 'renewcommand',
|
'setcounter', 'addtocounter', 'makeatletter',
|
||||||
'newcommand', 'makeatletter', 'makeatother', 'pagenumbering',
|
'makeatother', 'pagenumbering'
|
||||||
# Margins and columns
|
|
||||||
'marginpar', 'marginparsep', 'columnsep', 'columnseprule',
|
|
||||||
'twocolumn', 'onecolumn', 'minipage', 'parbox'
|
|
||||||
})
|
})
|
||||||
|
|
||||||
latex_chars: Dict[str, str] = field(default_factory=lambda: {
|
latex_chars: Dict[str, str] = field(default_factory=lambda: {
|
||||||
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
|
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
|
||||||
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
|
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
|
||||||
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
|
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
|
||||||
'\\textasciitilde': '~', '\\textasciicircum': '^',
|
'\\textasciitilde': '~', '\\textasciicircum': '^'
|
||||||
'\\quad': ' ', '\\qquad': ' ', '\\,': '', '\\;': '', '\\:': '',
|
|
||||||
'\\!': '', '\\space': ' ', '\\noindent': ''
|
|
||||||
})
|
})
|
||||||
|
|
||||||
inline_math_delimiters: Set[str] = field(default_factory=lambda: {
|
# 保留原始格式的特殊命令模式
|
||||||
'$', '\\(', '\\)', '\\[', '\\]'
|
special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [
|
||||||
})
|
(r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'),
|
||||||
|
(r'\\ref\*?{([^}]+)}', r'\\ref{\1}'),
|
||||||
|
(r'\\label{([^}]+)}', r'\\label{\1}'),
|
||||||
|
(r'\\eqref{([^}]+)}', r'\\eqref{\1}'),
|
||||||
|
(r'\\autoref{([^}]+)}', r'\\autoref{\1}'),
|
||||||
|
(r'\\url{([^}]+)}', r'\\url{\1}'),
|
||||||
|
(r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
class LatexCleaner:
|
class LatexCleaner:
|
||||||
"""Efficient and modular LaTeX text cleaner."""
|
"""Enhanced LaTeX text cleaner that preserves mathematical content and citations."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[LatexConfig] = None):
|
def __init__(self, config: Optional[LatexConfig] = None):
|
||||||
self.config = config or LatexConfig()
|
self.config = config or LatexConfig()
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
# 初始化正则表达式缓存
|
||||||
|
self._regex_cache = {}
|
||||||
|
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
def _get_env_pattern(self, env_name: str) -> Pattern:
|
def _get_env_pattern(self, env_name: str) -> Pattern:
|
||||||
|
"""Get cached regex pattern for environment matching."""
|
||||||
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
|
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
|
||||||
|
|
||||||
def _get_env_type(self, env_name: str) -> EnvType:
|
def _get_env_type(self, env_name: str) -> EnvType:
|
||||||
"""Determine environment processing type."""
|
"""Determine environment processing type."""
|
||||||
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
|
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
|
||||||
return EnvType.PRESERVE
|
return EnvType.PRESERVE
|
||||||
elif env_name in {'verbatim', 'comment'}:
|
elif env_name in {'comment'}:
|
||||||
return EnvType.REMOVE
|
return EnvType.REMOVE
|
||||||
return EnvType.EXTRACT
|
return EnvType.EXTRACT
|
||||||
|
|
||||||
|
def _preserve_special_commands(self, text: str) -> str:
|
||||||
|
"""Preserve special commands like citations and references with their complete structure."""
|
||||||
|
for pattern, replacement in self.config.special_command_patterns:
|
||||||
|
if pattern not in self._regex_cache:
|
||||||
|
self._regex_cache[pattern] = re.compile(pattern)
|
||||||
|
|
||||||
|
def replace_func(match):
|
||||||
|
# 保持原始命令格式
|
||||||
|
return match.group(0)
|
||||||
|
|
||||||
|
text = self._regex_cache[pattern].sub(replace_func, text)
|
||||||
|
return text
|
||||||
|
|
||||||
def _process_environment(self, match: re.Match) -> str:
|
def _process_environment(self, match: re.Match) -> str:
|
||||||
|
"""Process LaTeX environments while preserving complete content for special environments."""
|
||||||
try:
|
try:
|
||||||
env_name = match.group(1)
|
env_name = match.group(1)
|
||||||
content = match.group(2)
|
content = match.group(2)
|
||||||
env_type = self._get_env_type(env_name)
|
env_type = self._get_env_type(env_name)
|
||||||
|
|
||||||
if env_type == EnvType.PRESERVE:
|
if env_type == EnvType.PRESERVE:
|
||||||
# Preserve math content without markers for inline math
|
# 完整保留环境内容
|
||||||
if env_name in {'math', 'displaymath'}:
|
complete_env = match.group(0)
|
||||||
return f" {content} "
|
return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n"
|
||||||
return f" [BEGIN_{env_name}] {content} [END_{env_name}] "
|
|
||||||
elif env_type == EnvType.REMOVE:
|
elif env_type == EnvType.REMOVE:
|
||||||
return ' '
|
return ' '
|
||||||
# Process nested environments recursively
|
else:
|
||||||
return self._clean_nested_environments(content)
|
# 处理嵌套环境
|
||||||
|
return self._clean_nested_environments(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error processing environment {env_name}: {e}")
|
self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}")
|
||||||
return content
|
return match.group(0)
|
||||||
|
|
||||||
|
def _preserve_inline_math(self, text: str) -> str:
|
||||||
|
"""Preserve complete inline math content."""
|
||||||
|
|
||||||
|
def preserve_math(match):
|
||||||
|
return f" {match.group(0)} "
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
(r'\$[^$]+\$', preserve_math),
|
||||||
|
(r'\\[\(\[].*?\\[\)\]]', preserve_math),
|
||||||
|
(r'\\begin{math}.*?\\end{math}', preserve_math)
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern, handler in patterns:
|
||||||
|
if pattern not in self._regex_cache:
|
||||||
|
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||||
|
text = self._regex_cache[pattern].sub(handler, text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
def _clean_nested_environments(self, text: str) -> str:
|
def _clean_nested_environments(self, text: str) -> str:
|
||||||
"""Process nested environments recursively."""
|
"""Process nested environments recursively."""
|
||||||
return re.sub(
|
pattern = r'\\begin{(\w+)}(.*?)\\end{\1}'
|
||||||
r'\\begin{(\w+)}(.*?)\\end{\1}',
|
if pattern not in self._regex_cache:
|
||||||
self._process_environment,
|
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
|
||||||
text,
|
|
||||||
flags=re.DOTALL
|
return self._regex_cache[pattern].sub(self._process_environment, text)
|
||||||
)
|
|
||||||
|
|
||||||
def _clean_commands(self, text: str) -> str:
|
def _clean_commands(self, text: str) -> str:
|
||||||
"""Clean LaTeX commands while preserving specified content."""
|
"""Clean LaTeX commands while preserving important content."""
|
||||||
# Remove complete commands
|
# 首先处理特殊命令
|
||||||
for cmd in self.config.remove_commands:
|
text = self._preserve_special_commands(text)
|
||||||
text = re.sub(fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*', '', text)
|
|
||||||
|
|
||||||
# Process commands with content
|
# 保留内联数学
|
||||||
def handle_command(match: re.Match) -> str:
|
|
||||||
cmd = match.group(1).rstrip('*') # Handle starred versions
|
|
||||||
content = match.group(2)
|
|
||||||
|
|
||||||
# For these delimiters, return the original math content
|
|
||||||
if cmd in {'[', ']', '(', ')', '$'} or cmd in self.config.inline_math_delimiters:
|
|
||||||
return match.group(0)
|
|
||||||
|
|
||||||
# For preserved commands return content, otherwise return space
|
|
||||||
return match.group(0) if cmd in self.config.preserve_commands else ' '
|
|
||||||
# Handle commands with arguments
|
|
||||||
text = re.sub(r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}', handle_command, text)
|
|
||||||
|
|
||||||
# Handle inline math
|
|
||||||
text = self._preserve_inline_math(text)
|
text = self._preserve_inline_math(text)
|
||||||
|
|
||||||
# Remove remaining standalone commands
|
# 移除指定的命令
|
||||||
return text
|
for cmd in self.config.remove_commands:
|
||||||
|
if cmd not in self._regex_cache:
|
||||||
|
self._regex_cache[cmd] = re.compile(
|
||||||
|
fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*'
|
||||||
|
)
|
||||||
|
text = self._regex_cache[cmd].sub('', text)
|
||||||
|
|
||||||
def _preserve_inline_math(self, text: str) -> str:
|
# 处理带内容的命令
|
||||||
"""Preserve inline math content."""
|
def handle_command(match: re.Match) -> str:
|
||||||
# Handle $...$ math
|
cmd = match.group(1).rstrip('*')
|
||||||
text = re.sub(r'\$(.+?)\$', r' \1 ', text)
|
if cmd in self.config.preserve_commands or cmd in self.config.citation_commands:
|
||||||
# Handle \(...\) math
|
return match.group(0) # 完整保留命令和内容
|
||||||
text = re.sub(r'\\[\(\[](.+?)\\[\)\]]', r' \1 ', text)
|
return ' '
|
||||||
|
|
||||||
|
if 'command_pattern' not in self._regex_cache:
|
||||||
|
self._regex_cache['command_pattern'] = re.compile(
|
||||||
|
r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
text = self._regex_cache['command_pattern'].sub(handle_command, text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _normalize_text(self, text: str) -> str:
|
def _normalize_text(self, text: str) -> str:
|
||||||
"""Normalize special characters and whitespace."""
|
"""Normalize text while preserving special content markers."""
|
||||||
# Replace special characters
|
# 替换特殊字符
|
||||||
for char, replacement in self.config.latex_chars.items():
|
for char, replacement in self.config.latex_chars.items():
|
||||||
text = text.replace(char, replacement)
|
text = text.replace(char, replacement)
|
||||||
|
|
||||||
# Clean up whitespace
|
# 清理空白字符,同时保留环境标记
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r'\s+', ' ', text)
|
||||||
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r' [BEGIN_\1] ', text)
|
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text)
|
||||||
text = re.sub(r'\s*\[END_(\w+)\]\s*', r' [END_\1] ', text)
|
text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text)
|
||||||
|
|
||||||
# Remove empty brackets and braces
|
# 保持块级环境之间的分隔
|
||||||
text = re.sub(r'{\s*}|\[\s*\]|\(\s*\)', '', text)
|
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
def clean_text(self, text: str) -> str:
|
def clean_text(self, text: str) -> str:
|
||||||
"""Clean LaTeX text while preserving meaningful content."""
|
"""Clean LaTeX text while preserving mathematical content, citations, and special environments."""
|
||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Remove comments not inside environments
|
# 移除注释
|
||||||
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
|
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
|
||||||
|
|
||||||
# Process environments and their nested contents
|
# 处理环境
|
||||||
text = self._clean_nested_environments(text)
|
text = self._clean_nested_environments(text)
|
||||||
|
|
||||||
# Clean commands and normalize
|
# 清理命令并规范化
|
||||||
text = self._clean_commands(text)
|
text = self._clean_commands(text)
|
||||||
text = self._normalize_text(text)
|
text = self._normalize_text(text)
|
||||||
|
|
||||||
@@ -188,30 +268,39 @@ class LatexCleaner:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error cleaning text: {e}")
|
self.logger.error(f"Error cleaning text: {e}")
|
||||||
raise
|
return text # 发生错误时返回原始文本
|
||||||
|
|
||||||
|
|
||||||
def clean_latex_commands(text: str) -> str:
|
def clean_latex_commands(text: str) -> str:
|
||||||
"""Convenience function for quick text cleaning with default config."""
|
"""Convenience function for quick text cleaning with default config."""
|
||||||
config = LatexConfig(
|
cleaner = LatexCleaner()
|
||||||
preserve_envs={'equation', 'theorem'},
|
return cleaner.clean_text(text)
|
||||||
preserve_commands={'textbf', 'emph', "label"},
|
|
||||||
latex_chars={'~': ' ', '\\&': '&'}
|
|
||||||
)
|
|
||||||
return LatexCleaner(config).clean_text(text)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
# Example usage:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Basic usage with inline math
|
text = r"""
|
||||||
text = clean_latex_commands(r"""
|
\documentclass{article}
|
||||||
|
\begin{document}
|
||||||
|
|
||||||
|
\section{Introduction}
|
||||||
|
This is a reference to \cite{smith2020} and equation \eqref{eq:main}.
|
||||||
|
|
||||||
|
\begin{equation}\label{eq:main}
|
||||||
|
E = mc^2 \times \sum_{i=1}^{n} x_i
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
See Figure \ref{fig:example} for details.
|
||||||
|
|
||||||
|
\begin{figure}
|
||||||
|
\includegraphics{image.png}
|
||||||
|
\caption{Example figure\label
|
||||||
\textbf{Important} result: $E=mc^2$ and
|
\textbf{Important} result: $E=mc^2$ and
|
||||||
\begin{equation}
|
\begin{equation}
|
||||||
F = ma
|
F = ma
|
||||||
\end{equation}
|
\end{equation}
|
||||||
\label{sec:intro}
|
\label{sec:intro}
|
||||||
""")
|
"""
|
||||||
print(text)
|
|
||||||
|
|
||||||
# Custom configuration
|
# Custom configuration
|
||||||
config = LatexConfig(
|
config = LatexConfig(
|
||||||
@@ -236,5 +325,5 @@ if __name__ == "__main__":
|
|||||||
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||||
content = read_tex_file(file_path)
|
content = read_tex_file(file_path)
|
||||||
cleaner = LatexCleaner(config)
|
cleaner = LatexCleaner(config)
|
||||||
text = cleaner.clean_text(content)
|
text = cleaner.clean_text(text)
|
||||||
print(text)
|
print(text)
|
||||||
@@ -0,0 +1,412 @@
|
|||||||
|
import re
|
||||||
|
from typing import List, Dict, Tuple, Optional, Set
|
||||||
|
from enum import Enum
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectionLevel(Enum):
|
||||||
|
CHAPTER = 0
|
||||||
|
SECTION = 1
|
||||||
|
SUBSECTION = 2
|
||||||
|
SUBSUBSECTION = 3
|
||||||
|
PARAGRAPH = 4
|
||||||
|
SUBPARAGRAPH = 5
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value < other.value
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value <= other.value
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value > other.value
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
if not isinstance(other, SectionLevel):
|
||||||
|
return NotImplemented
|
||||||
|
return self.value >= other.value
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Section:
|
||||||
|
level: SectionLevel
|
||||||
|
title: str
|
||||||
|
content: str = ''
|
||||||
|
bibliography: str = ''
|
||||||
|
subsections: List['Section'] = field(default_factory=list)
|
||||||
|
def merge(self, other: 'Section') -> 'Section':
|
||||||
|
"""Merge this section with another section."""
|
||||||
|
if self.title != other.title or self.level != other.level:
|
||||||
|
raise ValueError("Can only merge sections with same title and level")
|
||||||
|
|
||||||
|
merged = deepcopy(self)
|
||||||
|
merged.content = self._merge_content(self.content, other.content)
|
||||||
|
|
||||||
|
# Create subsections lookup for efficient merging
|
||||||
|
subsections_map = {s.title: s for s in merged.subsections}
|
||||||
|
|
||||||
|
for other_subsection in other.subsections:
|
||||||
|
if other_subsection.title in subsections_map:
|
||||||
|
# Merge existing subsection
|
||||||
|
idx = next(i for i, s in enumerate(merged.subsections)
|
||||||
|
if s.title == other_subsection.title)
|
||||||
|
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
|
||||||
|
else:
|
||||||
|
# Add new subsection
|
||||||
|
merged.subsections.append(deepcopy(other_subsection))
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_content(content1: str, content2: str) -> str:
|
||||||
|
"""Merge content strings intelligently."""
|
||||||
|
if not content1:
|
||||||
|
return content2
|
||||||
|
if not content2:
|
||||||
|
return content1
|
||||||
|
# Combine non-empty contents with a separator
|
||||||
|
return f"{content1}\n\n{content2}"
|
||||||
|
@dataclass
|
||||||
|
class LatexEnvironment:
|
||||||
|
"""表示LaTeX环境的数据类"""
|
||||||
|
name: str
|
||||||
|
start: int
|
||||||
|
end: int
|
||||||
|
content: str
|
||||||
|
raw: str
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedSectionExtractor:
|
||||||
|
"""Enhanced section extractor with comprehensive content handling and hierarchy management."""
|
||||||
|
|
||||||
|
def __init__(self, preserve_environments: bool = True):
|
||||||
|
"""
|
||||||
|
初始化Section提取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preserve_environments: 是否保留特定环境(如equation, figure等)的原始LaTeX代码
|
||||||
|
"""
|
||||||
|
self.preserve_environments = preserve_environments
|
||||||
|
|
||||||
|
# Section级别定义
|
||||||
|
self.section_levels = {
|
||||||
|
'chapter': SectionLevel.CHAPTER,
|
||||||
|
'section': SectionLevel.SECTION,
|
||||||
|
'subsection': SectionLevel.SUBSECTION,
|
||||||
|
'subsubsection': SectionLevel.SUBSUBSECTION,
|
||||||
|
'paragraph': SectionLevel.PARAGRAPH,
|
||||||
|
'subparagraph': SectionLevel.SUBPARAGRAPH
|
||||||
|
}
|
||||||
|
|
||||||
|
# 需要保留的环境类型
|
||||||
|
self.important_environments = {
|
||||||
|
'equation', 'equation*', 'align', 'align*',
|
||||||
|
'figure', 'table', 'algorithm', 'algorithmic',
|
||||||
|
'definition', 'theorem', 'lemma', 'proof',
|
||||||
|
'itemize', 'enumerate', 'description'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 改进的section pattern
|
||||||
|
self.section_pattern = (
|
||||||
|
r'\\(?P<type>chapter|section|subsection|subsubsection|paragraph|subparagraph)'
|
||||||
|
r'\*?' # Optional star
|
||||||
|
r'(?:\[(?P<short>.*?)\])?' # Optional short title
|
||||||
|
r'{(?P<title>(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support
|
||||||
|
)
|
||||||
|
|
||||||
|
# 环境匹配模式
|
||||||
|
self.environment_pattern = (
|
||||||
|
r'\\begin{(?P<env_name>[^}]+)}'
|
||||||
|
r'(?P<env_content>.*?)'
|
||||||
|
r'\\end{(?P=env_name)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_environments(self, content: str) -> List[LatexEnvironment]:
|
||||||
|
"""
|
||||||
|
查找文档中的所有LaTeX环境。
|
||||||
|
支持嵌套环境的处理。
|
||||||
|
"""
|
||||||
|
environments = []
|
||||||
|
stack = []
|
||||||
|
|
||||||
|
# 使用正则表达式查找所有begin和end标记
|
||||||
|
begin_pattern = r'\\begin{([^}]+)}'
|
||||||
|
end_pattern = r'\\end{([^}]+)}'
|
||||||
|
|
||||||
|
# 组合模式来同时匹配begin和end
|
||||||
|
tokens = []
|
||||||
|
for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content):
|
||||||
|
if match.group(1): # begin标记
|
||||||
|
tokens.append(('begin', match.group(1), match.start()))
|
||||||
|
else: # end标记
|
||||||
|
tokens.append(('end', match.group(2), match.start()))
|
||||||
|
|
||||||
|
# 处理环境嵌套
|
||||||
|
for token_type, env_name, pos in tokens:
|
||||||
|
if token_type == 'begin':
|
||||||
|
stack.append((env_name, pos))
|
||||||
|
elif token_type == 'end' and stack:
|
||||||
|
if stack[-1][0] == env_name:
|
||||||
|
start_env_name, start_pos = stack.pop()
|
||||||
|
env_content = content[start_pos:pos]
|
||||||
|
raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')]
|
||||||
|
|
||||||
|
if start_env_name in self.important_environments:
|
||||||
|
environments.append(LatexEnvironment(
|
||||||
|
name=start_env_name,
|
||||||
|
start=start_pos,
|
||||||
|
end=pos + len('\\end{' + env_name + '}'),
|
||||||
|
content=env_content,
|
||||||
|
raw=raw_content
|
||||||
|
))
|
||||||
|
|
||||||
|
return sorted(environments, key=lambda x: x.start)
|
||||||
|
|
||||||
|
def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
保护重要的LaTeX环境,用占位符替换它们。
|
||||||
|
返回处理后的内容和恢复映射。
|
||||||
|
"""
|
||||||
|
environments = self._find_environments(content)
|
||||||
|
replacements = {}
|
||||||
|
|
||||||
|
# 从后向前替换,避免位置改变的问题
|
||||||
|
for env in reversed(environments):
|
||||||
|
if env.name in self.important_environments:
|
||||||
|
placeholder = f'__ENV_{len(replacements)}__'
|
||||||
|
replacements[placeholder] = env.raw
|
||||||
|
content = content[:env.start] + placeholder + content[env.end:]
|
||||||
|
|
||||||
|
return content, replacements
|
||||||
|
|
||||||
|
def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str:
|
||||||
|
"""
|
||||||
|
恢复之前保护的环境。
|
||||||
|
"""
|
||||||
|
for placeholder, original in replacements.items():
|
||||||
|
content = content.replace(placeholder, original)
|
||||||
|
return content
|
||||||
|
|
||||||
|
def extract(self, content: str) -> List[Section]:
|
||||||
|
"""
|
||||||
|
从LaTeX文档中提取sections及其内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: LaTeX文档内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Section]: 提取的section列表,包含层次结构
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 预处理:保护重要环境
|
||||||
|
if self.preserve_environments:
|
||||||
|
content, env_replacements = self._protect_environments(content)
|
||||||
|
|
||||||
|
# 查找所有sections
|
||||||
|
sections = self._find_all_sections(content)
|
||||||
|
if not sections:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 处理sections
|
||||||
|
root_sections = self._process_sections(content, sections)
|
||||||
|
|
||||||
|
# 如果需要,恢复环境
|
||||||
|
if self.preserve_environments:
|
||||||
|
for section in self._traverse_sections(root_sections):
|
||||||
|
section.content = self._restore_environments(section.content, env_replacements)
|
||||||
|
|
||||||
|
return root_sections
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting sections: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _find_all_sections(self, content: str) -> List[dict]:
|
||||||
|
"""查找所有section命令及其位置。"""
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE):
|
||||||
|
section_type = match.group('type').lower()
|
||||||
|
if section_type not in self.section_levels:
|
||||||
|
continue
|
||||||
|
|
||||||
|
section = {
|
||||||
|
'type': section_type,
|
||||||
|
'level': self.section_levels[section_type],
|
||||||
|
'title': self._clean_title(match.group('title')),
|
||||||
|
'start': match.start(),
|
||||||
|
'command_end': match.end(),
|
||||||
|
}
|
||||||
|
sections.append(section)
|
||||||
|
|
||||||
|
return sorted(sections, key=lambda x: x['start'])
|
||||||
|
|
||||||
|
def _process_sections(self, content: str, sections: List[dict]) -> List[Section]:
|
||||||
|
"""处理sections以构建层次结构和提取内容。"""
|
||||||
|
# 计算content范围
|
||||||
|
self._calculate_content_ranges(content, sections)
|
||||||
|
|
||||||
|
# 构建层次结构
|
||||||
|
root_sections = []
|
||||||
|
section_stack = []
|
||||||
|
|
||||||
|
for section_info in sections:
|
||||||
|
new_section = Section(
|
||||||
|
level=section_info['level'],
|
||||||
|
title=section_info['title'],
|
||||||
|
content=self._extract_clean_content(content, section_info),
|
||||||
|
subsections=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调整堆栈以找到正确的父section
|
||||||
|
while section_stack and section_stack[-1].level.value >= new_section.level.value:
|
||||||
|
section_stack.pop()
|
||||||
|
|
||||||
|
if section_stack:
|
||||||
|
section_stack[-1].subsections.append(new_section)
|
||||||
|
else:
|
||||||
|
root_sections.append(new_section)
|
||||||
|
|
||||||
|
section_stack.append(new_section)
|
||||||
|
|
||||||
|
return root_sections
|
||||||
|
|
||||||
|
def _calculate_content_ranges(self, content: str, sections: List[dict]):
|
||||||
|
for i, current in enumerate(sections):
|
||||||
|
content_start = current['command_end']
|
||||||
|
|
||||||
|
# 找到下一个section(无论什么级别)
|
||||||
|
content_end = len(content)
|
||||||
|
for next_section in sections[i + 1:]:
|
||||||
|
content_end = next_section['start']
|
||||||
|
break
|
||||||
|
|
||||||
|
current['content_range'] = (content_start, content_end)
|
||||||
|
|
||||||
|
def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]):
|
||||||
|
"""为每个section计算内容范围。"""
|
||||||
|
for i, current in enumerate(sections):
|
||||||
|
content_start = current['command_end']
|
||||||
|
|
||||||
|
# 找到下一个同级或更高级的section
|
||||||
|
content_end = len(content)
|
||||||
|
for next_section in sections[i + 1:]:
|
||||||
|
if next_section['level'] <= current['level']:
|
||||||
|
content_end = next_section['start']
|
||||||
|
break
|
||||||
|
|
||||||
|
current['content_range'] = (content_start, content_end)
|
||||||
|
|
||||||
|
def _extract_clean_content(self, content: str, section_info: dict) -> str:
|
||||||
|
"""提取并清理section内容。"""
|
||||||
|
start, end = section_info['content_range']
|
||||||
|
raw_content = content[start:end]
|
||||||
|
|
||||||
|
# 清理内容
|
||||||
|
clean_content = self._clean_content(raw_content)
|
||||||
|
return clean_content
|
||||||
|
|
||||||
|
def _clean_content(self, content: str) -> str:
|
||||||
|
"""清理LaTeX内容同时保留重要信息。"""
|
||||||
|
# 移除注释
|
||||||
|
content = re.sub(r'(?<!\\)%.*?\n', '\n', content)
|
||||||
|
|
||||||
|
# LaTeX命令处理规则
|
||||||
|
replacements = [
|
||||||
|
# 保留引用
|
||||||
|
(r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'),
|
||||||
|
# 保留脚注
|
||||||
|
(r'\\footnote{(.*?)}', r'[footnote:\1]'),
|
||||||
|
# 处理引用
|
||||||
|
(r'\\ref{(.*?)}', r'[ref:\1]'),
|
||||||
|
# 保留URL
|
||||||
|
(r'\\url{(.*?)}', r'[url:\1]'),
|
||||||
|
# 保留超链接
|
||||||
|
(r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'),
|
||||||
|
# 处理文本格式命令
|
||||||
|
(r'\\(?:textbf|textit|emph){(.*?)}', r'\1'),
|
||||||
|
# 保留特殊字符
|
||||||
|
(r'\\([&%$#_{}])', r'\1'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 应用所有替换规则
|
||||||
|
for pattern, replacement in replacements:
|
||||||
|
content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||||
|
|
||||||
|
# 清理多余的空白
|
||||||
|
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
def _clean_title(self, title: str) -> str:
|
||||||
|
"""清理section标题。"""
|
||||||
|
# 处理嵌套的花括号
|
||||||
|
while '{' in title:
|
||||||
|
title = re.sub(r'{([^{}]*)}', r'\1', title)
|
||||||
|
|
||||||
|
# 处理LaTeX命令
|
||||||
|
title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title)
|
||||||
|
title = re.sub(r'\\([&%$#_{}])', r'\1', title)
|
||||||
|
|
||||||
|
return title.strip()
|
||||||
|
|
||||||
|
def _traverse_sections(self, sections: List[Section]) -> List[Section]:
|
||||||
|
"""遍历所有sections(包括子sections)。"""
|
||||||
|
result = []
|
||||||
|
for section in sections:
|
||||||
|
result.append(section)
|
||||||
|
result.extend(self._traverse_sections(section.subsections))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_enhanced_extractor():
|
||||||
|
"""使用复杂的测试用例测试提取器。"""
|
||||||
|
test_content = r"""
|
||||||
|
\section{Complex Examples}
|
||||||
|
Here's a complex section with various environments.
|
||||||
|
|
||||||
|
\begin{equation}
|
||||||
|
E = mc^2
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
\subsection{Nested Environments}
|
||||||
|
This subsection has nested environments.
|
||||||
|
|
||||||
|
\begin{figure}
|
||||||
|
\begin{equation*}
|
||||||
|
f(x) = \int_0^x g(t) dt
|
||||||
|
\end{equation*}
|
||||||
|
\caption{A nested equation in a figure}
|
||||||
|
\end{figure}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
extractor = EnhancedSectionExtractor()
|
||||||
|
sections = extractor.extract(test_content)
|
||||||
|
|
||||||
|
def print_section(section, level=0):
|
||||||
|
print("\n" + " " * level + f"[{section.level.name}] {section.title}")
|
||||||
|
if section.content:
|
||||||
|
content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content
|
||||||
|
print(" " * (level + 1) + f"Content: {content_preview}")
|
||||||
|
for subsection in section.subsections:
|
||||||
|
print_section(subsection, level + 1)
|
||||||
|
|
||||||
|
print("\nExtracted Section Structure:")
|
||||||
|
for section in sections:
|
||||||
|
print_section(section)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_enhanced_extractor()
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectionFragment:
|
||||||
|
"""Arxiv论文片段数据类"""
|
||||||
|
title: str # 文件路径
|
||||||
|
abstract: str # 论文摘要
|
||||||
|
section_tree: str # 文章各章节的目录结构
|
||||||
|
arxiv_id: str = "" # 添加 arxiv_id 属性
|
||||||
|
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
|
||||||
|
content: str = '' #当前片段的内容
|
||||||
|
bibliography: str = '' #当前片段的参考文献
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,271 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict, Set, Optional, Callable
|
||||||
|
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
|
||||||
|
|
||||||
|
class TexUtils:
|
||||||
|
"""TeX文档处理器类"""
|
||||||
|
|
||||||
|
def __init__(self, ):
|
||||||
|
"""
|
||||||
|
初始化TeX处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
char_range: 字符数范围(最小值, 最大值)
|
||||||
|
"""
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化LaTeX环境和命令模式
|
||||||
|
self._init_patterns()
|
||||||
|
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _init_patterns(self):
|
||||||
|
"""初始化LaTeX模式匹配规则"""
|
||||||
|
# 特殊环境模式
|
||||||
|
self.special_envs = LaTeXPatterns.special_envs
|
||||||
|
# 章节模式
|
||||||
|
self.section_patterns = LaTeXPatterns.section_patterns
|
||||||
|
# 包含模式
|
||||||
|
self.include_patterns = LaTeXPatterns.include_patterns
|
||||||
|
# 元数据模式
|
||||||
|
self.metadata_patterns = LaTeXPatterns.metadata_patterns
|
||||||
|
|
||||||
|
def read_file(self, file_path: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
读取TeX文件内容,支持多种编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 文件内容或None
|
||||||
|
"""
|
||||||
|
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||||
|
for encoding in encodings:
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding=encoding) as f:
|
||||||
|
return f.read()
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.logger.warning(f"Failed to read {file_path} with all encodings")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_main_tex_file(self, directory: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
查找主TeX文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: 目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 主文件路径或None
|
||||||
|
"""
|
||||||
|
tex_files = list(Path(directory).rglob("*.tex"))
|
||||||
|
if not tex_files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 按优先级查找
|
||||||
|
for tex_file in tex_files:
|
||||||
|
content = self.read_file(str(tex_file))
|
||||||
|
if content:
|
||||||
|
if r'\documentclass' in content:
|
||||||
|
return str(tex_file)
|
||||||
|
if tex_file.name.lower() == 'main.tex':
|
||||||
|
return str(tex_file)
|
||||||
|
|
||||||
|
# 返回最大的tex文件
|
||||||
|
return str(max(tex_files, key=lambda x: x.stat().st_size))
|
||||||
|
|
||||||
|
def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
解析TeX文件中的include引用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tex_file: TeX文件路径
|
||||||
|
processed: 已处理的文件集合
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 相关文件路径列表
|
||||||
|
"""
|
||||||
|
if processed is None:
|
||||||
|
processed = set()
|
||||||
|
|
||||||
|
if tex_file in processed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
processed.add(tex_file)
|
||||||
|
result = [tex_file]
|
||||||
|
content = self.read_file(tex_file)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return result
|
||||||
|
|
||||||
|
base_dir = Path(tex_file).parent
|
||||||
|
for pattern in self.include_patterns:
|
||||||
|
for match in re.finditer(pattern, content):
|
||||||
|
included_file = match.group(2)
|
||||||
|
if not included_file.endswith('.tex'):
|
||||||
|
included_file += '.tex'
|
||||||
|
|
||||||
|
full_path = str(base_dir / included_file)
|
||||||
|
if os.path.exists(full_path) and full_path not in processed:
|
||||||
|
result.extend(self.resolve_includes(full_path, processed))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def resolve_references(self, tex_file: str, path_dir: str = None) -> str:
|
||||||
|
"""
|
||||||
|
解析TeX文件中的参考文献引用,返回所有引用文献的内容,只保留title、author和journal字段。
|
||||||
|
如果在tex_file目录下没找到bib文件,会在path_dir中查找。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tex_file: TeX文件路径
|
||||||
|
path_dir: 额外的参考文献搜索路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔
|
||||||
|
"""
|
||||||
|
all_references = [] # 存储所有参考文献内容
|
||||||
|
content = self.read_file(tex_file)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 扩展参考文献引用的模式
|
||||||
|
bib_patterns = [
|
||||||
|
r'\\bibliography\{([^}]+)\}',
|
||||||
|
r'\\addbibresource\{([^}]+)\}',
|
||||||
|
r'\\bibliographyfile\{([^}]+)\}',
|
||||||
|
r'\\begin\{thebibliography\}',
|
||||||
|
r'\\bibinput\{([^}]+)\}',
|
||||||
|
r'\\newrefsection\{([^}]+)\}'
|
||||||
|
]
|
||||||
|
|
||||||
|
base_dir = Path(tex_file).parent
|
||||||
|
found_in_tex_dir = False
|
||||||
|
|
||||||
|
# 首先在tex文件目录下查找显式引用的bib文件
|
||||||
|
for pattern in bib_patterns:
|
||||||
|
for match in re.finditer(pattern, content):
|
||||||
|
if not match.groups():
|
||||||
|
continue
|
||||||
|
|
||||||
|
bib_files = match.group(1).split(',')
|
||||||
|
for bib_file in bib_files:
|
||||||
|
bib_file = bib_file.strip()
|
||||||
|
if not bib_file.endswith('.bib'):
|
||||||
|
bib_file += '.bib'
|
||||||
|
|
||||||
|
full_path = str(base_dir / bib_file)
|
||||||
|
if os.path.exists(full_path):
|
||||||
|
found_in_tex_dir = True
|
||||||
|
bib_content = self.read_file(full_path)
|
||||||
|
if bib_content:
|
||||||
|
processed_refs = self._process_bib_content(bib_content)
|
||||||
|
all_references.extend(processed_refs)
|
||||||
|
|
||||||
|
# 如果在tex文件目录下没找到bib文件,且提供了额外搜索路径
|
||||||
|
if not found_in_tex_dir and path_dir:
|
||||||
|
search_dir = Path(path_dir)
|
||||||
|
try:
|
||||||
|
for bib_path in search_dir.glob('**/*.bib'):
|
||||||
|
bib_content = self.read_file(str(bib_path))
|
||||||
|
if bib_content:
|
||||||
|
processed_refs = self._process_bib_content(bib_content)
|
||||||
|
all_references.extend(processed_refs)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error searching in path_dir: {e}")
|
||||||
|
|
||||||
|
# 合并所有参考文献内容,用空行分隔
|
||||||
|
return "\n\n".join(all_references)
|
||||||
|
|
||||||
|
def _process_bib_content(self, content: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
处理bib文件内容,提取每个参考文献的特定字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: bib文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 处理后的参考文献列表
|
||||||
|
"""
|
||||||
|
processed_refs = []
|
||||||
|
# 匹配完整的参考文献条目
|
||||||
|
ref_pattern = r'@\w+\{[^@]*\}'
|
||||||
|
# 匹配参考文献类型和键值
|
||||||
|
entry_start_pattern = r'@(\w+)\{([^,]*?),'
|
||||||
|
# 匹配字段
|
||||||
|
field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}'
|
||||||
|
|
||||||
|
# 查找所有参考文献条目
|
||||||
|
for ref_match in re.finditer(ref_pattern, content, re.DOTALL):
|
||||||
|
ref_content = ref_match.group(0)
|
||||||
|
|
||||||
|
# 获取参考文献类型和键值
|
||||||
|
entry_match = re.match(entry_start_pattern, ref_content)
|
||||||
|
if not entry_match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry_type, cite_key = entry_match.groups()
|
||||||
|
|
||||||
|
# 提取需要的字段
|
||||||
|
needed_fields = {'title': None, 'author': None, 'journal': None}
|
||||||
|
for field_match in re.finditer(field_pattern, ref_content):
|
||||||
|
field_name, field_value = field_match.groups()
|
||||||
|
field_name = field_name.lower()
|
||||||
|
if field_name in needed_fields:
|
||||||
|
needed_fields[field_name] = field_value.strip()
|
||||||
|
|
||||||
|
# 构建新的参考文献条目
|
||||||
|
if any(needed_fields.values()): # 如果至少有一个需要的字段
|
||||||
|
ref_lines = [f"@{entry_type}{{{cite_key},"]
|
||||||
|
for field_name, field_value in needed_fields.items():
|
||||||
|
if field_value:
|
||||||
|
ref_lines.append(f" {field_name}={{{field_value}}},")
|
||||||
|
ref_lines[-1] = ref_lines[-1][:-1] # 移除最后一个逗号
|
||||||
|
ref_lines.append("}")
|
||||||
|
|
||||||
|
processed_refs.append("\n".join(ref_lines))
|
||||||
|
|
||||||
|
return processed_refs
|
||||||
|
def _extract_inline_references(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
从tex文件内容中提取直接写在文件中的参考文献
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: tex文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 提取的参考文献内容,如果没有找到则返回空字符串
|
||||||
|
"""
|
||||||
|
# 查找参考文献环境
|
||||||
|
bib_start = r'\\begin\{thebibliography\}'
|
||||||
|
bib_end = r'\\end\{thebibliography\}'
|
||||||
|
|
||||||
|
start_match = re.search(bib_start, content)
|
||||||
|
end_match = re.search(bib_end, content)
|
||||||
|
|
||||||
|
if start_match and end_match:
|
||||||
|
return content[start_match.start():end_match.end()]
|
||||||
|
|
||||||
|
return ""
|
||||||
|
def _preprocess_content(self, content: str) -> str:
|
||||||
|
"""预处理TeX内容"""
|
||||||
|
# 移除注释
|
||||||
|
content = re.sub(r'(?m)%.*$', '', content)
|
||||||
|
# 规范化空白字符
|
||||||
|
# content = re.sub(r'\s+', ' ', content)
|
||||||
|
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
在新工单中引用
屏蔽一个用户