这个提交包含在:
lbykkkk
2024-12-01 17:35:57 +08:00
父节点 cf51d4b205
当前提交 b3aef6b393
共有 13 个文件被更改,包括 398 次插入234 次删除

查看文件

@@ -1,21 +1,19 @@
import os
import re
import time
import aiohttp
import asyncio
import requests
import tarfile
import logging
from pathlib import Path
import re
import tarfile
import time
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Set
from typing import Generator, List, Tuple, Optional, Dict, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
import aiohttp
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
@@ -31,7 +29,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
"""
from datetime import datetime
from pathlib import Path
import re
# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -103,9 +100,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
# Content
f.write("\n**Content:**\n")
# f.write("```tex\n")
# f.write(fragment.content)
# f.write("\n```\n")
f.write("\n")
f.write(fragment.content)
f.write("\n")
@@ -140,6 +134,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
print(f"Fragments saved to: {file_path}")
return file_path
# 定义各种引用命令的模式
CITATION_PATTERNS = [
# 基本的 \cite{} 格式
@@ -199,8 +194,6 @@ class ArxivSplitter:
# 配置日志
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
@@ -221,7 +214,6 @@ class ArxivSplitter:
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
检查缓存是否有效,包括文件完整性检查
@@ -545,6 +537,7 @@ class ArxivSplitter:
self.logger.error(f"Error finding citation contexts: {str(e)}")
return contexts
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
"""
Process ArXiv paper and convert to list of SectionFragments.
@@ -573,8 +566,6 @@ class ArxivSplitter:
# 读取主 TeX 文件内容
main_tex_content = read_tex_file(main_tex)
# Get all related TeX files and references
tex_files = self.tex_processor.resolve_includes(main_tex)
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
@@ -742,7 +733,6 @@ class ArxivSplitter:
return content.strip()
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
@@ -765,14 +755,13 @@ async def test_arxiv_splitter():
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# 保存fragments
output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_dir}")
# # 内容检查
# for fragment in fragments:
@@ -780,7 +769,7 @@ async def test_arxiv_splitter():
#
# print((fragment.content))
# print(len(fragment.content))
# 类型检查
# 类型检查
except Exception as e:
@@ -789,4 +778,4 @@ async def test_arxiv_splitter():
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())
asyncio.run(test_arxiv_splitter())