From cf51d4b205e417a141b2c80bcb03cc851b50e347 Mon Sep 17 00:00:00 2001 From: lbykkkk Date: Sun, 1 Dec 2024 17:28:23 +0800 Subject: [PATCH] up --- .../rag_fns/arxiv_fns/arxiv_fragment.py | 56 ------------------- .../rag_fns/arxiv_fns/arxiv_splitter.py | 38 ++++++++++--- .../rag_fns/arxiv_fns/essay_structure.py | 1 + .../rag_fns/arxiv_fns/section_fragment.py | 3 +- 4 files changed, 32 insertions(+), 66 deletions(-) delete mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py deleted file mode 100644 index 73d95d9b..00000000 --- a/crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py +++ /dev/null @@ -1,56 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class ArxivFragment: - """Arxiv论文片段数据类""" - file_path: str # 文件路径 - content: str # 内容 - segment_index: int # 片段索引 - total_segments: int # 总片段数 - rel_path: str # 相对路径 - segment_type: str # 片段类型(text/math/table/figure等) - title: str # 论文标题 - abstract: str # 论文摘要 - section: str # 所属章节 - is_appendix: bool # 是否是附录 - importance: float = 1.0 # 重要性得分 - arxiv_id: str = "" # 添加 arxiv_id 属性 - - @staticmethod - def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment': - """ - 合并两个片段的静态方法 - - Args: - seg1: 第一个片段 - seg2: 第二个片段 - - Returns: - ArxivFragment: 合并后的片段 - """ - # 合并内容 - merged_content = f"{seg1.content}\n{seg2.content}" - - # 确定合并后的类型 - def _merge_segment_type(type1: str, type2: str) -> str: - if type1 == type2: - return type1 - if type1 == 'text': - return type2 - if type2 == 'text': - return type1 - return 'mixed' - - return ArxivFragment( - file_path=seg1.file_path, - content=merged_content, - segment_index=seg1.segment_index, - total_segments=seg1.total_segments - 1, - rel_path=seg1.rel_path, - segment_type=_merge_segment_type(seg1.segment_type, seg2.segment_type), - title=seg1.title, - abstract=seg1.abstract, - section=seg1.section, - is_appendix=seg1.is_appendix, - importance=max(seg1.importance, seg2.importance) - ) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py index b96f173f..049fa4a4 100644 --- a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py @@ -15,6 +15,7 @@ from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section +from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path: @@ -38,7 +39,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = " output_path.mkdir(parents=True, exist_ok=True) # Generate filename - filename = f"fragments_{timestamp}.md" + filename = f"paper_latex_content_{timestamp}.md" file_path = output_path / filename # Group fragments by section @@ -61,13 +62,17 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = " f.write("\n## Paper Information\n") if fragments[0].title: f.write(f"### Title\n{fragments[0].title}\n") + if fragments[0].authors: + f.write(f"\n### Authors\n{fragments[0].authors}\n") if fragments[0].abstract: f.write(f"\n### Abstract\n{fragments[0].abstract}\n") # Write section tree if available if fragments and fragments[0].catalogs: f.write("\n## Section Tree\n") + f.write("```\n") # 添加代码块开始标记 f.write(fragments[0].catalogs) + f.write("\n```") # 添加代码块结束标记 # Generate table of contents f.write("\n## Table of Contents\n") @@ -98,9 +103,12 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = " # Content f.write("\n**Content:**\n") - f.write("```tex\n") + # f.write("```tex\n") + # f.write(fragment.content) + # f.write("\n```\n") + f.write("\n") f.write(fragment.content) - f.write("\n```\n") + f.write("\n") # Bibliography if exists if fragment.bibliography: @@ -562,6 +570,11 @@ class ArxivSplitter: if not main_tex: raise RuntimeError(f"No main TeX file found in {paper_dir}") + # 读取主 TeX 文件内容 + main_tex_content = read_tex_file(main_tex) + + + # Get all related TeX files and references tex_files = self.tex_processor.resolve_includes(main_tex) ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir) @@ -572,6 +585,11 @@ class ArxivSplitter: # Reset document structure for new processing self.document_structure = DocumentStructure() + # 提取作者信息 + author_extractor = LatexAuthorExtractor() + authors = author_extractor.extract_authors(main_tex_content) + self.document_structure.authors = authors # 保存到文档结构中 + # Process each TeX file for file_path in tex_files: self.logger.info(f"Processing TeX file: {file_path}") @@ -624,6 +642,7 @@ class ArxivSplitter: # Create a base template for all fragments to avoid repetitive assignments base_fragment_template = { 'title': doc_structure.title, + 'authors': doc_structure.authors, 'abstract': doc_structure.abstract, 'catalogs': section_tree, 'arxiv_id': arxiv_id @@ -723,6 +742,7 @@ class ArxivSplitter: return content.strip() + async def test_arxiv_splitter(): """测试ArXiv分割器的功能""" @@ -754,12 +774,12 @@ async def test_arxiv_splitter(): # 保存fragments output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log") print(f"Output saved to: {output_dir}") - # 内容检查 - for fragment in fragments: - # 长度检查 - - print((fragment.content)) - print(len(fragment.content)) + # # 内容检查 + # for fragment in fragments: + # # 长度检查 + # + # print((fragment.content)) + # print(len(fragment.content)) # 类型检查 diff --git a/crazy_functions/rag_fns/arxiv_fns/essay_structure.py b/crazy_functions/rag_fns/arxiv_fns/essay_structure.py index 9f961624..cc1f1391 100644 --- a/crazy_functions/rag_fns/arxiv_fns/essay_structure.py +++ b/crazy_functions/rag_fns/arxiv_fns/essay_structure.py @@ -31,6 +31,7 @@ def read_tex_file(file_path): @dataclass class DocumentStructure: title: str = '' + authors: str = '' abstract: str = '' toc: List[Section] = field(default_factory=list) metadata: Dict[str, str] = field(default_factory=dict) diff --git a/crazy_functions/rag_fns/arxiv_fns/section_fragment.py b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py index f04cdd55..f933837d 100644 --- a/crazy_functions/rag_fns/arxiv_fns/section_fragment.py +++ b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py @@ -3,7 +3,8 @@ from dataclasses import dataclass @dataclass class SectionFragment: """Arxiv论文片段数据类""" - title: str # 文件路径 + title: str # 论文标题 + authors: str abstract: str # 论文摘要 catalogs: str # 文章各章节的目录结构 arxiv_id: str = "" # 添加 arxiv_id 属性