这个提交包含在:
lbykkkk
2024-12-01 22:00:41 +08:00
父节点 b3aef6b393
当前提交 3beb22a347
共有 3 个文件被更改,包括 459 次插入39 次删除

查看文件

@@ -14,9 +14,10 @@ from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructurePars
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path ) -> Path:
"""
Save all fragments to a single structured markdown file.
@@ -37,7 +38,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
# Generate filename
filename = f"paper_latex_content_{timestamp}.md"
file_path = output_path / filename
file_path = output_path/ filename
# Group fragments by section
sections = {}
@@ -733,7 +734,53 @@ class ArxivSplitter:
return content.strip()
async def test_arxiv_splitter():
def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, Path]:
"""
同步处理 ArXiv 文档并返回分割后的片段
Args:
splitter: ArxivSplitter 实例
arxiv_id: ArXiv 文档ID
Returns:
list: 分割后的文档片段列表
"""
try:
# 创建一个异步函数来执行异步操作
async def _process():
return await splitter.process(arxiv_id)
# 使用 asyncio.run() 运行异步函数
fragments = asyncio.run(_process())
# 保存片段到文件
output_dir = save_fragments_to_file(
fragments,
output_dir=splitter.root_dir / "arxiv_fragments"
)
print(f"Output saved to: {output_dir}")
# 创建论文格式化器
formatter = PaperContentFormatter()
# 准备元数据
# 创建格式化选项
metadata = PaperMetadata(
title=fragments[0].title,
authors=fragments[0].authors,
abstract=fragments[0].abstract,
catalogs=fragments[0].catalogs,
arxiv_id=fragments[0].arxiv_id
)
# 格式化内容
formatted_content = formatter.format(fragments, metadata)
return fragments, formatted_content, output_dir
except Exception as e:
print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
raise
def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
@@ -752,25 +799,21 @@ async def test_arxiv_splitter():
# 创建分割器实例
splitter = ArxivSplitter(
root_dir="test_cache"
root_dir="private_upload/default_user"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# fragments = await splitter.process(case['arxiv_id'])
fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
# 保存fragments
output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_dir}")
# # 内容检查
# for fragment in fragments:
# # 长度检查
#
# print((fragment.content))
# print(len(fragment.content))
for fragment in fragments:
# 长度检查
print((fragment.content))
print(len(fragment.content))
# 类型检查
print(output_dir)
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")