This commit is contained in:
lbykkkk
2024-12-01 17:35:57 +08:00
Parent cf51d4b205
Commit b3aef6b393
13 changed files with 398 additions and 234 deletions

View file

@@ -1,12 +1,14 @@
import logging
import requests
import tarfile
from pathlib import Path
from typing import Optional, Dict
class ArxivDownloader:
"""用于下载arXiv论文源码的下载器"""
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
"""
Initialize the downloader.
@@ -18,13 +20,13 @@ class ArxivDownloader:
self.root_dir = Path(root_dir)
self.root_dir.mkdir(exist_ok=True)
self.proxies = proxies
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def _download_and_extract(self, arxiv_id: str) -> str:
"""
Download and extract the arXiv paper source.
@@ -40,19 +42,19 @@ class ArxivDownloader:
"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# Check the cache
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
@@ -65,9 +67,9 @@ class ArxivDownloader:
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def download_paper(self, arxiv_id: str) -> str:
"""
Download the specified arXiv paper.
@@ -80,6 +82,7 @@ class ArxivDownloader:
"""
return self._download_and_extract(arxiv_id)
def main():
"""测试下载功能"""
# Configure a proxy (if needed)
@@ -87,16 +90,16 @@ def main():
"http": "http://your-proxy:port",
"https": "https://your-proxy:port"
}
# Create a downloader instance; if no proxy is needed, the proxies argument can be omitted
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
# Test downloading one paper; an example ID is used here
try:
paper_id = "2103.00020" # 这是一个示例ID
paper_dir = downloader.download_paper(paper_id)
print(f"Successfully downloaded paper to: {paper_dir}")
# Check the downloaded files
paper_path = Path(paper_dir)
if paper_path.exists():
@@ -107,5 +110,6 @@ def main():
except Exception as e:
print(f"Error downloading paper: {e}")
if __name__ == "__main__":
main()
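Taken together, the class above caches each paper under root_dir and falls back from the /src endpoint to /e-print. A minimal usage sketch (the paper ID is a placeholder; the proxies dict, if used, follows the requests convention):

from pathlib import Path

# Proxies are optional; pass None to connect directly.
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
paper_dir = downloader.download_paper("2103.00020")  # placeholder arXiv ID
for item in sorted(Path(paper_dir).iterdir()):
    print(item.name)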

View file

@@ -1,21 +1,19 @@
import asyncio
import logging
import os
import re
import tarfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from pathlib import Path
from typing import Generator, List, Tuple, Optional, Dict, Set
import aiohttp
import requests
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
@@ -31,7 +29,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
"""
from datetime import datetime
from pathlib import Path
import re
# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -103,9 +100,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
# Content
f.write("\n**Content:**\n")
# f.write("```tex\n")
# f.write(fragment.content)
# f.write("\n```\n")
f.write("\n")
f.write(fragment.content)
f.write("\n")
@@ -140,6 +134,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
print(f"Fragments saved to: {file_path}")
return file_path
# Patterns for the various citation commands
CITATION_PATTERNS = [
# basic \cite{} form
@@ -199,8 +194,6 @@ class ArxivSplitter:
# Configure logging
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
@@ -221,7 +214,6 @@ class ArxivSplitter:
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
Check whether the cache is valid, including file-integrity checks.
@@ -545,6 +537,7 @@ class ArxivSplitter:
self.logger.error(f"Error finding citation contexts: {str(e)}")
return contexts
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
"""
Process ArXiv paper and convert to list of SectionFragments.
@@ -573,8 +566,6 @@ class ArxivSplitter:
# Read the main TeX file content
main_tex_content = read_tex_file(main_tex)
# Get all related TeX files and references
tex_files = self.tex_processor.resolve_includes(main_tex)
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
@@ -742,7 +733,6 @@ class ArxivSplitter:
return content.strip()
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
@@ -765,14 +755,13 @@ async def test_arxiv_splitter():
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# Save the fragments
output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_dir}")
# # Content check
# for fragment in fragments:
@@ -780,7 +769,7 @@ async def test_arxiv_splitter():
#
# print((fragment.content))
# print(len(fragment.content))
# Type check
except Exception as e:
@@ -789,4 +778,4 @@ async def test_arxiv_splitter():
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())
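For reference, a minimal driver mirroring the test above; the constructor is called with root_dir only, as in test_arxiv_splitter, and the arXiv ID is a placeholder:

import asyncio

async def demo():
    splitter = ArxivSplitter(root_dir="demo_cache")
    fragments = await splitter.process("2103.00020")  # placeholder arXiv ID
    for fragment in fragments[:3]:
        print(fragment.current_section, len(fragment.content))

asyncio.run(demo())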

View file

@@ -0,0 +1,177 @@
import re
from typing import Optional
class LatexAuthorExtractor:
def __init__(self):
# Patterns for matching author blocks with balanced braces
self.author_block_patterns = [
# Standard LaTeX patterns with optional arguments
r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
# Conference and journal specific patterns
r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
# Academic publisher specific patterns
r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
]
# Cleaning patterns for LaTeX commands and formatting
self.cleaning_patterns = [
# Text formatting commands - preserve content
(r'\\textbf\{([^}]+)\}', r'\1'),
(r'\\textit\{([^}]+)\}', r'\1'),
(r'\\emph\{([^}]+)\}', r'\1'),
(r'\\texttt\{([^}]+)\}', r'\1'),
(r'\\textrm\{([^}]+)\}', r'\1'),
(r'\\text\{([^}]+)\}', r'\1'),
# Affiliation and footnote markers
(r'\$\^{[^}]+}\$', ''),
(r'\^{[^}]+}', ''),
(r'\\thanks\{[^}]+\}', ''),
(r'\\footnote\{[^}]+\}', ''),
# Email and contact formatting
(r'\\email\{([^}]+)\}', r'\1'),
(r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
# Institution formatting
(r'\\inst\{[^}]+\}', ''),
(r'\\affil\{[^}]+\}', ''),
# Special characters and symbols
(r'\\&', '&'),
(r'\\\\\s*', ' '),
(r'\\,', ' '),
(r'\\;', ' '),
(r'\\quad', ' '),
(r'\\qquad', ' '),
# Math mode content
(r'\$[^$]+\$', ''),
# Common symbols
(r'\\dagger', ''),
(r'\\ddagger', ''),
(r'\\ast', '*'),
(r'\\star', ''),
# Remove remaining LaTeX commands
(r'\\[a-zA-Z]+', ''),
# Clean up remaining special characters
(r'[\\{}]', '')
]
def extract_author_block(self, text: str) -> Optional[str]:
"""
Extract the complete author block from LaTeX text.
Args:
text (str): Input LaTeX text
Returns:
Optional[str]: Extracted author block or None if not found
"""
try:
if not text:
return None
for pattern in self.author_block_patterns:
match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
if match:
return match.group(1).strip()
return None
except (AttributeError, IndexError) as e:
print(f"Error extracting author block: {e}")
return None
def clean_tex_commands(self, text: str) -> str:
"""
Remove LaTeX commands and formatting from text while preserving content.
Args:
text (str): Text containing LaTeX commands
Returns:
str: Cleaned text with commands removed
"""
if not text:
return ""
cleaned_text = text
# Apply cleaning patterns
for pattern, replacement in self.cleaning_patterns:
cleaned_text = re.sub(pattern, replacement, cleaned_text)
# Clean up whitespace
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
cleaned_text = cleaned_text.strip()
return cleaned_text
def extract_authors(self, text: str) -> Optional[str]:
"""
Extract and clean author information from LaTeX text.
Args:
text (str): Input LaTeX text
Returns:
Optional[str]: Cleaned author information or None if extraction fails
"""
try:
if not text:
return None
# Extract author block
author_block = self.extract_author_block(text)
if not author_block:
return None
# Clean LaTeX commands
cleaned_authors = self.clean_tex_commands(author_block)
return cleaned_authors or None
except Exception as e:
print(f"Error processing text: {e}")
return None
def test_author_extractor():
"""Test the LatexAuthorExtractor with sample inputs."""
test_cases = [
# Basic test case
(r"\author{John Doe}", "John Doe"),
# Test with multiple authors
(r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
# Test with affiliations
(r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
]
extractor = LatexAuthorExtractor()
for i, (input_tex, expected) in enumerate(test_cases, 1):
result = extractor.extract_authors(input_tex)
print(f"\nTest case {i}:")
print(f"Input: {input_tex[:50]}...")
print(f"Expected: {expected[:50]}...")
print(f"Got: {result[:50]}...")
print(f"Pass: {bool(result and result.strip() == expected.strip())}")
if __name__ == "__main__":
test_author_extractor()
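The author-block patterns above allow one level of nested braces inside the captured group, so formatted names survive extraction intact. A self-contained check using only the first pattern (the sample input is illustrative):

import re

AUTHOR_PATTERN = r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}'
sample = r"\author[1]{Jane \textbf{Roe} \and John Doe}"
match = re.search(AUTHOR_PATTERN, sample)
print(match.group(1))  # Jane \textbf{Roe} \and John Doe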

View file

@@ -5,14 +5,15 @@ This module provides functionality for parsing and extracting structured informa
including metadata, document structure, and content. It uses modular design and clean architecture principles.
"""
import logging
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,7 @@ def read_tex_file(file_path):
except UnicodeDecodeError:
continue
@dataclass
class DocumentStructure:
title: str = ''
@@ -68,7 +70,7 @@ class DocumentStructure:
if other_section.title in sections_map:
# Merge existing section
idx = next(i for i, s in enumerate(merged.toc)
if s.title == other_section.title)
merged.toc[idx] = merged.toc[idx].merge(other_section)
else:
# Add new section
@@ -149,6 +151,8 @@ class DocumentStructure:
result.extend(_format_section(section, 0) for section in self.toc)
return "".join(result)
class BaseExtractor(ABC):
"""Base class for LaTeX content extractors."""
@@ -157,6 +161,7 @@ class BaseExtractor(ABC):
"""Extract specific content from LaTeX document."""
pass
class TitleExtractor(BaseExtractor):
"""Extracts title from LaTeX document."""
@@ -180,6 +185,7 @@ class TitleExtractor(BaseExtractor):
return clean_latex_commands(title)
return ''
class AbstractExtractor(BaseExtractor):
"""Extracts abstract from LaTeX document."""
@@ -203,6 +209,7 @@ class AbstractExtractor(BaseExtractor):
return clean_latex_commands(abstract)
return ''
class EssayStructureParser:
"""Main class for parsing LaTeX documents."""
@@ -231,6 +238,7 @@ class EssayStructureParser:
content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
return content
def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
"""Print document structure in a readable format."""
print(f"Title: {doc.title}\n")
@@ -250,10 +258,10 @@ def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100
for section in doc.toc:
print_section(section)
# Example usage:
if __name__ == "__main__":
# Test with a file
file_path = 'test_cache/2411.03663/neurips_2024.tex'
main_tex = read_tex_file(file_path)
@@ -278,5 +286,5 @@ if __name__ == "__main__":
additional_doc = parser.parse(tex_content)
main_doc = main_doc.merge(additional_doc)
tree = main_doc.generate_toc_tree()
pretty_print_structure(main_doc)
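A hedged sketch of the parse-and-merge flow exercised in the __main__ block above, assuming EssayStructureParser needs no constructor arguments; the file paths are placeholders:

parser = EssayStructureParser()
main_doc = parser.parse(read_tex_file("main.tex"))       # placeholder path
extra_doc = parser.parse(read_tex_file("appendix.tex"))  # placeholder path
merged = main_doc.merge(extra_doc)  # same-title sections merge; new ones are appended
pretty_print_structure(merged)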

View file

@@ -1,9 +1,9 @@
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from typing import Set, Dict, Pattern, Optional, List, Tuple
class EnvType(Enum):
@@ -326,4 +326,4 @@ if __name__ == "__main__":
content = read_tex_file(file_path)
cleaner = LatexCleaner(config)
text = cleaner.clean_text(content)
print(text)

View file

@@ -1,4 +1,5 @@
from dataclasses import dataclass
@dataclass
class LaTeXPatterns:
@@ -142,124 +143,124 @@ class LaTeXPatterns:
]
metadata_patterns = {
# Title-related
'title': [
r'\\title\{([^}]+)\}',
r'\\Title\{([^}]+)\}',
r'\\doctitle\{([^}]+)\}',
r'\\subtitle\{([^}]+)\}',
r'\\chapter\*?\{([^}]+)\}', # the first chapter may serve as the title
r'\\maketitle\s*\\section\*?\{([^}]+)\}' # the first section may serve as the title
],
# Abstract-related
'abstract': [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract\{([^}]+)\}',
r'\\begin{摘要}(.*?)\\end{摘要}',
r'\\begin{Summary}(.*?)\\end{Summary}',
r'\\begin{synopsis}(.*?)\\end{synopsis}',
r'\\begin{abstracten}(.*?)\\end{abstracten}' # English abstract
],
# Author information
'author': [
r'\\author\{([^}]+)\}',
r'\\Author\{([^}]+)\}',
r'\\authorinfo\{([^}]+)\}',
r'\\authors\{([^}]+)\}',
r'\\author\[([^]]+)\]\{([^}]+)\}', # author with additional info
r'\\begin{authors}(.*?)\\end{authors}'
],
# Date-related
'date': [
r'\\date\{([^}]+)\}',
r'\\Date\{([^}]+)\}',
r'\\submitdate\{([^}]+)\}',
r'\\publishdate\{([^}]+)\}',
r'\\revisiondate\{([^}]+)\}'
],
# Keywords
'keywords': [
r'\\keywords\{([^}]+)\}',
r'\\Keywords\{([^}]+)\}',
r'\\begin{keywords}(.*?)\\end{keywords}',
r'\\key\{([^}]+)\}',
r'\\begin{关键词}(.*?)\\end{关键词}'
],
# Institution/affiliation
'institution': [
r'\\institute\{([^}]+)\}',
r'\\institution\{([^}]+)\}',
r'\\affiliation\{([^}]+)\}',
r'\\organization\{([^}]+)\}',
r'\\department\{([^}]+)\}'
],
# Subject/field
'subject': [
r'\\subject\{([^}]+)\}',
r'\\Subject\{([^}]+)\}',
r'\\field\{([^}]+)\}',
r'\\discipline\{([^}]+)\}'
],
# Version information
'version': [
r'\\version\{([^}]+)\}',
r'\\revision\{([^}]+)\}',
r'\\release\{([^}]+)\}'
],
# License/copyright
'license': [
r'\\license\{([^}]+)\}',
r'\\copyright\{([^}]+)\}',
r'\\begin{license}(.*?)\\end{license}'
],
# Contact information
'contact': [
r'\\email\{([^}]+)\}',
r'\\phone\{([^}]+)\}',
r'\\address\{([^}]+)\}',
r'\\contact\{([^}]+)\}'
],
# Acknowledgments
'acknowledgments': [
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
r'\\acknowledgments\{([^}]+)\}',
r'\\thanks\{([^}]+)\}',
r'\\begin{致谢}(.*?)\\end{致谢}'
],
# Project/funding
'funding': [
r'\\funding\{([^}]+)\}',
r'\\grant\{([^}]+)\}',
r'\\project\{([^}]+)\}',
r'\\support\{([^}]+)\}'
],
# Classification/identifier
'classification': [
r'\\classification\{([^}]+)\}',
r'\\serialnumber\{([^}]+)\}',
r'\\id\{([^}]+)\}',
r'\\doi\{([^}]+)\}'
],
# Language
'language': [
r'\\documentlanguage\{([^}]+)\}',
r'\\lang\{([^}]+)\}',
r'\\language\{([^}]+)\}'
]
}
latex_only_patterns = {
# Document class and package imports
r'\\documentclass(\[.*?\])?\{.*?\}',
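A hedged sketch of how the metadata tables might be consumed; extract_metadata is an illustrative helper, not part of the commit, and it takes the first matching alternative per field:

import re

def extract_metadata(tex: str) -> dict:
    found = {}
    for field_name, alternatives in LaTeXPatterns.metadata_patterns.items():
        for pattern in alternatives:
            match = re.search(pattern, tex, re.DOTALL)
            if match:
                found[field_name] = match.group(1).strip()
                break
    return found

print(extract_metadata(r"\title{A Study} \author{J. Doe} \keywords{latex, parsing}"))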

View file

@@ -1,15 +1,15 @@
import logging
import re
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Tuple
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SectionLevel(Enum):
CHAPTER = 0
@@ -39,6 +39,7 @@ class SectionLevel(Enum):
return NotImplemented
return self.value >= other.value
@dataclass
class Section:
level: SectionLevel
@@ -46,6 +47,7 @@ class Section:
content: str = ''
bibliography: str = ''
subsections: List['Section'] = field(default_factory=list)
def merge(self, other: 'Section') -> 'Section':
"""Merge this section with another section."""
if self.title != other.title or self.level != other.level:
@@ -78,6 +80,8 @@ class Section:
return content1
# Combine non-empty contents with a separator
return f"{content1}\n\n{content2}"
@dataclass
class LatexEnvironment:
"""表示LaTeX环境的数据类"""
@@ -409,4 +413,4 @@ f(x) = \int_0^x g(t) dt
if __name__ == "__main__":
test_enhanced_extractor()
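A small usage sketch of Section.merge based on the fields and merge rules shown above; SectionLevel.SECTION is assumed to exist alongside CHAPTER:

a = Section(level=SectionLevel.SECTION, title="Introduction", content="First draft.")
b = Section(level=SectionLevel.SECTION, title="Introduction", content="Expanded text.")
merged = a.merge(b)
print(merged.content)  # non-empty contents are joined with a blank-line separator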

View file

@@ -1,19 +1,14 @@
from dataclasses import dataclass
@dataclass
class SectionFragment:
"""Arxiv论文片段数据类"""
title: str # 论文标题
authors: str
abstract: str # 论文摘要
catalogs: str # 文章各章节的目录结构
catalogs: str # 文章各章节的目录结构
arxiv_id: str = "" # 添加 arxiv_id 属性
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
content: str = '' #当前片段的内容
bibliography: str = '' #当前片段的参考文献
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
content: str = '' # 当前片段的内容
bibliography: str = '' # 当前片段的参考文献

View file

@@ -1,10 +1,12 @@
import logging
import os
import re
from pathlib import Path
from typing import List, Set, Optional
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
class TexUtils:
"""TeX文档处理器类"""
@@ -21,9 +23,6 @@ class TexUtils:
self._init_patterns()
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
def _init_patterns(self):
"""初始化LaTeX模式匹配规则"""
# 特殊环境模式
@@ -234,6 +233,7 @@ class TexUtils:
processed_refs.append("\n".join(ref_lines))
return processed_refs
def _extract_inline_references(self, content: str) -> str:
"""
Extract references written inline in the TeX file content.
@@ -255,6 +255,7 @@ class TexUtils:
return content[start_match.start():end_match.end()]
return ""
def _preprocess_content(self, content: str) -> str:
"""预处理TeX内容"""
# 移除注释
@@ -263,9 +264,3 @@ class TexUtils:
# content = re.sub(r'\s+', ' ', content)
content = re.sub(r'\n\s*\n', '\n\n', content)
return content.strip()
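The preprocessing step strips unescaped % comments and collapses runs of blank lines. A standalone re-implementation for illustration (not the class method itself):

import re

def preprocess(tex: str) -> str:
    tex = re.sub(r'(?<!\\)%.*$', '', tex, flags=re.MULTILINE)  # drop comments, keep escaped \%
    tex = re.sub(r'\n\s*\n', '\n\n', tex)                      # collapse blank-line runs
    return tex.strip()

print(preprocess("Accuracy: 95\\% % internal note\n\n\n\nNext paragraph"))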

View file

@@ -1,17 +1,13 @@
import atexit
import llama_index
import os
from typing import List
from loguru import logger
from llama_index.core import Document
from llama_index.core import PromptTemplate
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.ingestion import run_transformations
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.schema import TextNode
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
@@ -127,7 +123,6 @@ class LlamaIndexRagWorker(SaveLoad):
logger.error(f"Error saving checkpoint: {str(e)}")
raise
def assign_embedding_model(self):
pass

View file

@@ -1,20 +1,14 @@
import atexit
import llama_index
import os
from typing import List
from loguru import logger
from llama_index.core import Document
from llama_index.core import PromptTemplate
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.ingestion import run_transformations
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.schema import TextNode
from llama_index.vector_stores.milvus import MilvusVectorStore
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
def create_new_vs(self, checkpoint_dir, overwrite=False):
vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension(),
overwrite=overwrite
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
embed_model=self.embed_model)
return index
def purge(self):
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
docstore = self.vs_index.storage_context.docstore.docs
if not docstore.items():
raise ValueError("cannot inspect")
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
except:
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
vector_store_preview = "\n".join(

View file

@@ -1,8 +1,8 @@
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
def read_docx_doc(file_path):
if file_path.split(".")[-1] == "docx":
@@ -25,9 +25,11 @@ def read_docx_doc(file_path):
raise RuntimeError('Please convert the .doc document to .docx first.')
return file_content
# Revised extract_text function, combining SimpleDirectoryReader with custom parsing logic
def extract_text(file_path):
_, ext = os.path.splitext(file_path.lower())
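The supported-format gate lowercases the path before splitting the extension, so mixed-case names pass. An illustrative helper (is_supported is not part of the commit):

import os

def is_supported(file_path: str) -> bool:
    _, ext = os.path.splitext(file_path.lower())
    return ext in supports_format

print(is_supported("Report.PDF"))   # True
print(is_supported("archive.tar"))  # False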

View file

@@ -1,6 +1,6 @@
from typing import Any, List, Optional
from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.schema import TransformComponent
from llama_index.core.service_context import ServiceContext
@@ -13,18 +13,18 @@ from llama_index.core.storage.storage_context import StorageContext
class GptacVectorStoreIndex(VectorStoreIndex):
@classmethod
def default_vector_store(
cls,
storage_context: Optional[StorageContext] = None,
show_progress: bool = False,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
embed_model=None,
**kwargs: Any,
):
"""Create index from documents.
@@ -36,15 +36,14 @@ class GptacVectorStoreIndex(VectorStoreIndex):
storage_context = storage_context or StorageContext.from_defaults()
docstore = storage_context.docstore
callback_manager = (
callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
)
transformations = transformations or transformations_from_settings_or_context(
Settings, service_context
)
with callback_manager.as_trace("index_construction"):
return cls(
nodes=[],
storage_context=storage_context,
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
embed_model=embed_model,
**kwargs,
)
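Usage mirrors the Milvus worker above: a StorageContext plus an embedding model yields an empty index ready for node insertion. A minimal sketch (embed_model=None is a placeholder; callers in this commit pass an OpenAiEmbeddingModel instance):

from llama_index.core import StorageContext

storage_context = StorageContext.from_defaults()
index = GptacVectorStoreIndex.default_vector_store(
    storage_context=storage_context,
    embed_model=None,  # placeholder; see MilvusSaveLoad.create_new_vs above
)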