This commit is contained in:
lbykkkk
2024-11-16 00:35:31 +08:00
Parent dd902e9519
Commit 21626a44d5
12 files changed, with 2385 insertions and 1169 deletions


@@ -0,0 +1,449 @@
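"""Download arXiv source packages and split their LaTeX content into ArxivFragment objects."""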
import os
import re
import time
import aiohttp
import asyncio
import requests
import tarfile
import logging
from pathlib import Path
from typing import Generator, List, Tuple, Optional, Dict, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment
def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"):
"""
将所有fragments保存为单个结构化markdown文件
Args:
fragments: fragment列表
output_dir: 输出目录
"""
from datetime import datetime
from pathlib import Path
import re
    # Create the output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
    # Build the output filename
filename = f"fragments_{timestamp}.md"
file_path = output_path / filename
with open(file_path, "w", encoding="utf-8") as f:
        # Write the document header
f.write("# Document Fragments Analysis\n\n")
f.write("## Overview\n")
f.write(f"- Total Fragments: {len(fragments)}\n")
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        # If a title and abstract are available, add them at the top
if fragments and (fragments[0].title or fragments[0].abstract):
f.write("\n## Paper Information\n")
if fragments[0].title:
f.write(f"### Title\n{fragments[0].title}\n")
if fragments[0].abstract:
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
        # Generate the table of contents
        f.write("\n## Table of Contents\n")
        # First collect all section information
sections = {}
for fragment in fragments:
section = fragment.section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
        # Write the table of contents
for section, section_fragments in sections.items():
clean_section = section.strip()
if not clean_section:
clean_section = "Uncategorized"
f.write(
f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n")
        # Write the main content
        f.write("\n## Content\n")
        # Organize content by section
for section, section_fragments in sections.items():
clean_section = section.strip() or "Uncategorized"
f.write(f"\n### {clean_section}\n")
            # Write each fragment
for i, fragment in enumerate(section_fragments, 1):
f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n")
                # Metadata
f.write("**Metadata:**\n")
f.write(f"- Type: {fragment.segment_type}\n")
f.write(f"- Length: {len(fragment.content)} chars\n")
f.write(f"- Importance: {fragment.importance:.2f}\n")
f.write(f"- Is Appendix: {fragment.is_appendix}\n")
f.write(f"- File: {fragment.rel_path}\n")
                # Content
f.write("\n**Content:**\n")
f.write("```tex\n")
f.write(fragment.content)
f.write("\n```\n")
                # Add a separator between fragments
if i < len(section_fragments):
f.write("\n---\n")
        # Append statistics
f.write("\n## Statistics\n")
f.write("\n### Fragment Type Distribution\n")
type_stats = {}
for fragment in fragments:
type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1
for ftype, count in type_stats.items():
percentage = (count / len(fragments)) * 100
f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n")
        # Length distribution
f.write("\n### Length Distribution\n")
lengths = [len(f.content) for f in fragments]
f.write(f"- Minimum: {min(lengths)} chars\n")
f.write(f"- Maximum: {max(lengths)} chars\n")
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
print(f"Fragments saved to: {file_path}")
return file_path
class ArxivSplitter:
"""Arxiv论文智能分割器"""
def __init__(self,
char_range: Tuple[int, int],
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
cache_ttl: int = 7 * 24 * 60 * 60):
"""
初始化分割器
Args:
char_range: 字符数范围(最小值, 最大值)
root_dir: 缓存根目录
proxies: 代理设置
cache_ttl: 缓存过期时间(秒)
"""
self.min_chars, self.max_chars = char_range
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.cache_ttl = cache_ttl
        # Dynamically compute a reasonable number of worker threads
        import multiprocessing
        cpu_count = multiprocessing.cpu_count()
        # Scale with the number of CPU cores, but cap it to avoid excessive concurrency
        self.max_workers = min(32, cpu_count * 2)
        # Initialize the TeX processor
self.tex_processor = TexProcessor(char_range)
        # Configure logging
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
# 处理URL格式
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
# 移除版本号和其他后缀
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
检查缓存是否有效,包括文件完整性检查
Args:
paper_dir: 论文目录路径
Returns:
bool: 如果缓存有效返回True,否则返回False
"""
if not paper_dir.exists():
return False
        # Check that the required files exist in the directory
has_tex_files = False
has_main_tex = False
for file_path in paper_dir.rglob("*"):
if file_path.suffix == '.tex':
has_tex_files = True
content = self.tex_processor.read_file(str(file_path))
if content and r'\documentclass' in content:
has_main_tex = True
break
if not (has_tex_files and has_main_tex):
return False
        # Check the cache age (the directory's mtime is used as the cache timestamp)
cache_time = paper_dir.stat().st_mtime
if (time.time() - cache_time) < self.cache_ttl:
self.logger.info(f"Using valid cache for {paper_dir.name}")
return True
return False
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
"""
异步下载论文,包含重试机制和临时文件处理
Args:
arxiv_id: ArXiv论文ID
paper_dir: 目标目录路径
Returns:
bool: 下载成功返回True,否则返回False
"""
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
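        # Download into a temporary file first; it is only renamed to the final path after a successful extraction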
        # Make sure the directory exists
paper_dir.mkdir(parents=True, exist_ok=True)
        # Try downloading with ArxivDownloader first
try:
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
downloaded_dir = downloader.download_paper(arxiv_id)
if downloaded_dir:
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
return True
except Exception as e:
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
        # If ArxivDownloader fails, fall back to a direct download
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
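        # Both URLs point at the paper's source archive; try them in order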
max_retries = 3
        retry_delay = 1  # initial retry delay in seconds
for url in urls:
for attempt in range(max_retries):
try:
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.proxies.get('http')) as response:
if response.status == 200:
content = await response.read()
                                # Write to a temporary file
temp_tar_path.write_bytes(content)
try:
                                    # Verify the tar file's integrity and extract it
                                    loop = asyncio.get_running_loop()
                                    await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
                                    # On success, move the temporary file to its final location
temp_tar_path.rename(final_tar_path)
return True
except Exception as e:
self.logger.warning(f"Invalid tar file: {str(e)}")
if temp_tar_path.exists():
temp_tar_path.unlink()
except Exception as e:
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
                    await asyncio.sleep(retry_delay * (attempt + 1))  # back off, increasing the delay each attempt
continue
return False
def _process_tar_file(self, tar_path: Path, extract_path: Path):
"""处理tar文件的同步操作"""
with tarfile.open(tar_path, 'r:gz') as tar:
tar.testall() # 验证文件完整性
tar.extractall(path=extract_path) # 解压文件
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TeX文件"""
try:
content = self.tex_processor.read_file(file_path)
if not content:
return []
            # Extract metadata
is_main = r'\documentclass' in content
title, abstract = "", ""
if is_main:
title, abstract = self.tex_processor.extract_metadata(content)
            # Split the content
segments = self.tex_processor.split_content(content)
fragments = []
for i, (segment_content, section, is_appendix) in enumerate(segments):
if not segment_content.strip():
continue
segment_type = self.tex_processor.detect_segment_type(segment_content)
importance = self.tex_processor.calculate_importance(
segment_content, segment_type, is_main
)
fragments.append(ArxivFragment(
file_path=file_path,
content=segment_content,
segment_index=i,
total_segments=len(segments),
rel_path=str(Path(file_path).relative_to(self.root_dir)),
segment_type=segment_type,
title=title,
abstract=abstract,
section=section,
is_appendix=is_appendix,
importance=importance
))
return fragments
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return []
async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]:
"""处理ArXiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self.root_dir / arxiv_id
            # Check the cache
if not self._check_cache(paper_dir):
paper_dir.mkdir(exist_ok=True)
if not await self.download_paper(arxiv_id, paper_dir):
raise RuntimeError(f"Failed to download paper {arxiv_id}")
            # Locate the main TeX file
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
if not main_tex:
raise RuntimeError(f"No main TeX file found in {paper_dir}")
            # Collect all related TeX files
tex_files = self.tex_processor.resolve_includes(main_tex)
if not tex_files:
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
            # Process all TeX files in parallel
            fragments = []
            chunk_size = max(1, len(tex_files) // self.max_workers)  # files handled per worker
            loop = asyncio.get_running_loop()
async def process_chunk(chunk_files):
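                # Parse each file via run_in_executor so blocking file I/O does not stall the event loop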
chunk_fragments = []
for file_path in chunk_files:
try:
result = await loop.run_in_executor(None, self._process_single_tex, file_path)
chunk_fragments.extend(result)
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return chunk_fragments
            # Split the files into chunks
file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)]
            # Process each chunk concurrently
chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks])
for result in chunk_results:
fragments.extend(result)
            # Re-sort the fragments and recompute their indices
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
total_fragments = len(fragments)
for i, fragment in enumerate(fragments):
fragment.segment_index = i
fragment.total_segments = total_fragments
            # Filter the fragments before returning
fragments = self.tex_processor.filter_fragments(fragments)
return fragments
except Exception as e:
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
raise
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
test_cases = [
{
"arxiv_id": "2411.03663",
"expected_title": "Large Language Models and Simple Scripts",
"min_fragments": 10,
},
# {
# "arxiv_id": "1805.10988",
# "expected_title": "RAG vs Fine-tuning",
# "min_fragments": 15,
# }
]
    # Create a splitter instance
splitter = ArxivSplitter(
char_range=(800, 1800),
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
            # Save the fragments
            output_file = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
            print(f"Output saved to: {output_file}")
            # Content checks
            for fragment in fragments:
                # Length check
                print(fragment.content)
                print(len(fragment.content))
                # Type check
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
raise
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())
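
# Minimal usage sketch (assumes the gpt_academic package layout so the imports above resolve):
#
#   splitter = ArxivSplitter(char_range=(800, 1800), root_dir="gpt_log/arxiv_cache")
#   fragments = asyncio.run(splitter.process("https://arxiv.org/abs/2411.03663"))
#   save_fragments_to_file(fragments)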