Preserve the full section hierarchy path

lbykkkk
2024-11-09 18:19:51 +08:00
Parent c6687646e4
Commit b8617921f4


@@ -0,0 +1,534 @@
import os
import re
import requests
import tarfile
import logging
from dataclasses import dataclass
from typing import Generator, List, Tuple, Optional, Dict, Set
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
@dataclass
class ArxivFragment:
"""Arxiv论文片段数据类"""
file_path: str
content: str
segment_index: int
total_segments: int
rel_path: str
segment_type: str
title: str
abstract: str
    section: str  # full section hierarchy path, e.g. "Introduction" or "Methods-Data Processing"
    section_type: str  # new: fragment type, e.g. "abstract", "section", "subsection"
    section_level: int  # new: section depth; 0 for the abstract, 1 for main sections, 2 for subsections, and so on
is_appendix: bool
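    # Illustrative example (hypothetical values) of how the section fields relate,
    # for a fragment taken from a subsection "Data Processing" under "Methods":
    #   section       = "Methods-Data Processing"
    #   section_type  = "subsection"
    #   section_level = 2
    #   is_appendix   = False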
class SmartArxivSplitter:
def __init__(self,
char_range: Tuple[int, int],
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
max_workers: int = 4):
self.min_chars, self.max_chars = char_range
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.max_workers = max_workers
        # Patterns for special LaTeX environments
self._init_patterns()
        # Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
def _init_patterns(self):
"""初始化LaTeX环境和命令模式"""
self.special_envs = {
'math': [r'\\begin{(equation|align|gather|eqnarray)\*?}.*?\\end{\1\*?}',
r'\$\$.*?\$\$', r'\$[^$]+\$'],
'table': [r'\\begin{(table|tabular)\*?}.*?\\end{\1\*?}'],
'figure': [r'\\begin{figure\*?}.*?\\end{figure\*?}'],
'algorithm': [r'\\begin{(algorithm|algorithmic)}.*?\\end{\1}']
}
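        # Note: the \1 backreference ties each \end{...} to the environment name
        # captured by group 1 in the matching \begin{...}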
self.section_patterns = [
r'\\(sub)*section\{([^}]+)\}',
r'\\chapter\{([^}]+)\}'
]
self.include_patterns = [
r'\\(input|include|subfile)\{([^}]+)\}'
]
def _find_main_tex_file(self, directory: str) -> Optional[str]:
"""查找主TEX文件"""
tex_files = list(Path(directory).rglob("*.tex"))
if not tex_files:
return None
        # Search in the following priority order:
        # 1. a file containing \documentclass
        # 2. a file named main.tex
        # 3. the largest .tex file
for tex_file in tex_files:
try:
content = self._read_file(str(tex_file))
if content and r'\documentclass' in content:
return str(tex_file)
if tex_file.name.lower() == 'main.tex':
return str(tex_file)
except Exception:
continue
return str(max(tex_files, key=lambda x: x.stat().st_size))
def _resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
"""递归解析tex文件中的include/input命令"""
if processed is None:
processed = set()
if tex_file in processed:
return []
processed.add(tex_file)
result = [tex_file]
content = self._read_file(tex_file)
if not content:
return result
base_dir = Path(tex_file).parent
for pattern in self.include_patterns:
for match in re.finditer(pattern, content):
included_file = match.group(2)
if not included_file.endswith('.tex'):
included_file += '.tex'
                # Build the full path of the included file
full_path = str(base_dir / included_file)
if os.path.exists(full_path) and full_path not in processed:
result.extend(self._resolve_includes(full_path, processed))
return result
def _split_into_sentences(self, text: str) -> List[str]:
"""将文本分割成句子"""
return re.split(r'(?<=[.!?。!?])\s+', text)
def _split_long_sentence(self, sentence: str) -> List[str]:
"""智能分割过长的句子"""
if len(sentence) <= self.max_chars:
return [sentence]
result = []
while sentence:
            # Look for a suitable split point within the first max_chars characters
split_pos = self._find_split_position(sentence[:self.max_chars])
if split_pos <= 0:
split_pos = self.max_chars
result.append(sentence[:split_pos])
sentence = sentence[split_pos:].strip()
return result
def _find_split_position(self, text: str) -> int:
"""找到合适的句子分割位置"""
# 优先在标点符号处分割
punctuation_match = re.search(r'[,,;;]\s*', text[::-1])
if punctuation_match:
return len(text) - punctuation_match.end()
        # Otherwise split at whitespace
space_match = re.search(r'\s+', text[::-1])
if space_match:
return len(text) - space_match.end()
return -1
def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""保护特殊环境内容"""
for env_type, patterns in self.special_envs.items():
for pattern in patterns:
content = re.sub(
pattern,
lambda m: self._store_protected_block(m.group(0), protected_blocks),
content,
flags=re.DOTALL
)
return content
def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""存储受保护的内容块"""
placeholder = f"PROTECTED_{len(protected_blocks)}"
protected_blocks[placeholder] = content
return placeholder
def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""恢复特殊环境内容"""
for placeholder, original in protected_blocks.items():
content = content.replace(placeholder, original)
return content
def _is_special_environment(self, text: str) -> bool:
"""判断是否是特殊环境"""
for patterns in self.special_envs.values():
for pattern in patterns:
if re.search(pattern, text, re.DOTALL):
return True
return False
def _preprocess_content(self, content: str) -> str:
"""预处理TEX内容"""
# 移除注释
content = re.sub(r'(?m)%.*$', '', content)
        # Normalize whitespace while preserving blank lines, which mark the
        # paragraph boundaries that _smart_split later splits on
        content = re.sub(r'\n\s*\n', '\n\n', content)
        content = re.sub(r'[ \t]+', ' ', content)
        # Remove commands that are not needed in plain text
content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content)
return content.strip()
def process_paper(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
"""处理单篇arxiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self._download_and_extract(arxiv_id)
            # Locate the main tex file
main_tex = self._find_main_tex_file(paper_dir)
if not main_tex:
raise RuntimeError(f"No main tex file found in {paper_dir}")
            # Collect all related tex files
tex_files = self._resolve_includes(main_tex)
            # Process every tex file in parallel
fragments = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_file = {
executor.submit(self._process_single_tex, file_path): file_path
for file_path in tex_files
}
for future in as_completed(future_to_file):
try:
fragments.extend(future.result())
except Exception as e:
logging.error(f"Error processing file: {e}")
            # Recompute fragment indices across all files
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
total_fragments = len(fragments)
for i, fragment in enumerate(fragments):
fragment.segment_index = i
fragment.total_segments = total_fragments
yield fragment
except Exception as e:
logging.error(f"Error processing paper {arxiv_id_or_url}: {e}")
raise RuntimeError(f"Failed to process paper: {str(e)}")
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化arxiv ID"""
if input_str.startswith('https://arxiv.org/'):
if '/pdf/' in input_str:
return input_str.split('/pdf/')[-1].split('v')[0]
return input_str.split('/abs/')[-1].split('v')[0]
return input_str.split('v')[0]
def _download_and_extract(self, arxiv_id: str) -> str:
"""下载并解压arxiv论文源码"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
        # Check the cache first
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
response = requests.get(url, proxies=self.proxies)
if response.status_code == 200:
tar_path.write_bytes(response.content)
with tarfile.open(tar_path, 'r:gz') as tar:
tar.extractall(path=paper_dir)
return str(paper_dir)
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def _read_file(self, file_path: str) -> Optional[str]:
"""使用多种编码尝试读取文件"""
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
logging.warning(f"Failed to read file {file_path} with all encodings")
return None
def _extract_metadata(self, content: str) -> Tuple[str, str]:
"""提取论文标题和摘要"""
title = ""
abstract = ""
        # Extract the title
title_patterns = [
r'\\title{([^}]*)}',
r'\\Title{([^}]*)}'
]
for pattern in title_patterns:
match = re.search(pattern, content)
if match:
title = match.group(1)
title = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', title)
break
        # Extract the abstract
abstract_patterns = [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract{([^}]*)}'
]
for pattern in abstract_patterns:
match = re.search(pattern, content, re.DOTALL)
if match:
abstract = match.group(1).strip()
abstract = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', abstract)
break
return title.strip(), abstract.strip()
def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, str, int, bool]]:
"""获取段落所属的章节信息,返回(section_path, section_type, level, is_appendix)"""
current_path = []
section_type = "content"
level = 0
is_appendix = False
        # Regex patterns for section commands and their levels
section_patterns = {
r'\\chapter\{([^}]+)\}': 1,
r'\\section\{([^}]+)\}': 1,
r'\\subsection\{([^}]+)\}': 2,
r'\\subsubsection\{([^}]+)\}': 3
}
        # Find all section markers in the content
all_sections = []
for pattern, sec_level in section_patterns.items():
for match in re.finditer(pattern, content):
all_sections.append((match.start(), match.group(1), sec_level))
        # Check whether the paragraph lies inside the abstract environment
        abstract_match = re.search(r'\\begin{abstract}(.*?)\\end{abstract}', content, re.DOTALL)
        if abstract_match and para in abstract_match.group(1):
return "Abstract", "abstract", 0, False
        # Locate the \appendix marker
appendix_pos = content.find(r'\appendix')
        # Determine the section enclosing this paragraph
para_pos = content.find(para)
if para_pos >= 0:
is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
current_sections = []
current_level = 0
            # Walk the section markers in document order
for sec_pos, sec_title, sec_level in sorted(all_sections):
if sec_pos > para_pos:
break
                # When a marker of the same or a higher level appears, drop deeper entries
if sec_level <= current_level:
current_sections = [s for s in current_sections if s[1] < sec_level]
current_sections.append((sec_title, sec_level))
current_level = sec_level
            # Build the section path
if current_sections:
current_path = [s[0] for s in sorted(current_sections, key=lambda x: x[1])]
section_path = "-".join(current_path)
level = max(s[1] for s in current_sections)
section_type = "section" if level == 1 else "subsection"
return section_path, section_type, level, is_appendix
return "Unknown Section", "content", 0, is_appendix
def _smart_split(self, content: str) -> List[Tuple[str, str, str, int, bool]]:
"""智能分割TEX内容,确保在字符范围内并保持语义完整性"""
content = self._preprocess_content(content)
segments = []
current_buffer = []
current_length = 0
current_section_info = ("Unknown Section", "content", 0, False)
        # Protect special environments with placeholders
protected_blocks = {}
content = self._protect_special_environments(content, protected_blocks)
        # Split into paragraphs
paragraphs = re.split(r'\n\s*\n', content)
for para in paragraphs:
para = para.strip()
if not para:
continue
            # Update the section info first, using the paragraph as it appears in
            # `content` (which still contains the placeholders)
            section_info = self._get_section_info(para, content)
            if section_info:
                current_section_info = section_info
            # Restore special environments
            para = self._restore_special_environments(para, protected_blocks)
            # Check whether this is a special environment
if self._is_special_environment(para):
                # Flush the current buffer first
if current_buffer:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = []
current_length = 0
                # Emit the special environment as its own segment
segments.append((para, *current_section_info))
continue
            # Handle an ordinary paragraph sentence by sentence
sentences = self._split_into_sentences(para)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sent_length = len(sentence)
new_length = current_length + sent_length + (1 if current_buffer else 0)
if new_length <= self.max_chars:
current_buffer.append(sentence)
current_length = new_length
else:
                    # If the buffer already meets the minimum length, flush it
if current_length >= self.min_chars:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = [sentence]
current_length = sent_length
else:
                        # Otherwise try to split the overly long sentence
split_sentences = self._split_long_sentence(sentence)
for split_sent in split_sentences:
if current_length + len(split_sent) <= self.max_chars:
current_buffer.append(split_sent)
current_length += len(split_sent) + 1
else:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = [split_sent]
current_length = len(split_sent)
        # Flush whatever is left in the buffer
if current_buffer:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
return segments
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TEX文件"""
try:
content = self._read_file(file_path)
if not content:
return []
            # Extract metadata (only the main file carries the title/abstract)
is_main = r'\documentclass' in content
title = ""
abstract = ""
if is_main:
title, abstract = self._extract_metadata(content)
            # Smart-split the content
segments = self._smart_split(content)
fragments = []
for i, (segment_content, section_path, section_type, level, is_appendix) in enumerate(segments):
if segment_content.strip():
segment_type = 'text'
for env_type, patterns in self.special_envs.items():
if any(re.search(pattern, segment_content, re.DOTALL)
for pattern in patterns):
segment_type = env_type
break
fragments.append(ArxivFragment(
file_path=file_path,
content=segment_content,
segment_index=i,
total_segments=len(segments),
rel_path=os.path.relpath(file_path, str(self.root_dir)),
segment_type=segment_type,
title=title,
abstract=abstract,
section=section_path,
section_type=section_type,
section_level=level,
is_appendix=is_appendix
))
return fragments
except Exception as e:
logging.error(f"Error processing file {file_path}: {e}")
return []
def main():
"""使用示例"""
# 创建分割器实例
splitter = SmartArxivSplitter(
char_range=(1000, 1200),
root_dir="gpt_log/arxiv_cache"
)
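    # If a proxy is needed, requests-style proxies can be passed in, e.g.
    # proxies={"http": "http://127.0.0.1:7890", "https": "http://127.0.0.1:7890"}
    # (the address above is purely illustrative)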
    # Process a paper
for fragment in splitter.process_paper("2411.03663"):
print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}")
print(f"Length: {len(fragment.content)}")
print(f"Section: {fragment.section}")
print(fragment.content)
print("-" * 80)
if __name__ == "__main__":
main()