This commit is contained in:
lbykkkk
2024-11-16 00:35:31 +08:00
Parent dd902e9519
Commit 21626a44d5
12 files changed, with 2385 additions and 1169 deletions

View file

@@ -0,0 +1,111 @@
import logging
import requests
import tarfile
from pathlib import Path
from typing import Optional, Dict
class ArxivDownloader:
"""用于下载arXiv论文源码的下载器"""
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
"""
Initialize the downloader
Args:
root_dir: Root directory where downloaded files are saved
proxies: Proxy settings, e.g. {"http": "http://proxy:port", "https": "https://proxy:port"}
"""
self.root_dir = Path(root_dir)
self.root_dir.mkdir(exist_ok=True)
self.proxies = proxies
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def _download_and_extract(self, arxiv_id: str) -> str:
"""
Download and extract the arXiv paper source
Args:
arxiv_id: arXiv paper ID, e.g. "2103.00020"
Returns:
str: Path to the directory containing the extracted files
Raises:
RuntimeError: Raised when the download fails
"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# Check the cache
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
response = requests.get(url, proxies=self.proxies, timeout=60)  # Timeout avoids hanging on a stalled connection
if response.status_code == 200:
tar_path.write_bytes(response.content)
with tarfile.open(tar_path, 'r:gz') as tar:
tar.extractall(path=paper_dir)
return str(paper_dir)
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def download_paper(self, arxiv_id: str) -> str:
"""
Download the specified arXiv paper
Args:
arxiv_id: arXiv paper ID
Returns:
str: Path to the directory containing the paper files
"""
return self._download_and_extract(arxiv_id)
def main():
"""测试下载功能"""
# 配置代理(如果需要)
proxies = {
"http": "http://your-proxy:port",
"https": "https://your-proxy:port"
}
# Create a downloader instance; if no proxy is needed, the proxies argument can be omitted
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
# Download a sample paper; an example arXiv ID is used here
try:
paper_id = "2103.00020" # 这是一个示例ID
paper_dir = downloader.download_paper(paper_id)
print(f"Successfully downloaded paper to: {paper_dir}")
# Inspect the downloaded files
paper_path = Path(paper_dir)
if paper_path.exists():
print("Downloaded files:")
for file in paper_path.rglob("*"):
if file.is_file():
print(f"- {file.relative_to(paper_path)}")
except Exception as e:
print(f"Error downloading paper: {e}")
if __name__ == "__main__":
main()

View file

@@ -0,0 +1,55 @@
from dataclasses import dataclass
@dataclass
class ArxivFragment:
"""Arxiv论文片段数据类"""
file_path: str # 文件路径
content: str # 内容
segment_index: int # 片段索引
total_segments: int # 总片段数
rel_path: str # 相对路径
segment_type: str # 片段类型(text/math/table/figure等)
title: str # 论文标题
abstract: str # 论文摘要
section: str # 所属章节
is_appendix: bool # 是否是附录
importance: float = 1.0 # 重要性得分
@staticmethod
def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment':
"""
Static method that merges two fragments
Args:
seg1: First fragment
seg2: Second fragment
Returns:
ArxivFragment: The merged fragment
"""
# Merge the content
merged_content = f"{seg1.content}\n{seg2.content}"
# Determine the type of the merged segment
def _merge_segment_type(type1: str, type2: str) -> str:
if type1 == type2:
return type1
if type1 == 'text':
return type2
if type2 == 'text':
return type1
return 'mixed'
return ArxivFragment(
file_path=seg1.file_path,
content=merged_content,
segment_index=seg1.segment_index,
total_segments=seg1.total_segments - 1,
rel_path=seg1.rel_path,
segment_type=_merge_segment_type(seg1.segment_type, seg2.segment_type),
title=seg1.title,
abstract=seg1.abstract,
section=seg1.section,
is_appendix=seg1.is_appendix,
importance=max(seg1.importance, seg2.importance)
)
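# Illustrative usage sketch (hypothetical values, assuming only the class above):
if __name__ == "__main__":
    intro = ArxivFragment(file_path="main.tex", content="Introduction text.", segment_index=0,
                          total_segments=4, rel_path="main.tex", segment_type="text",
                          title="Sample Title", abstract="Sample abstract.", section="Introduction",
                          is_appendix=False, importance=0.8)
    formula = ArxivFragment(file_path="main.tex", content=r"$E = mc^2$", segment_index=1,
                            total_segments=4, rel_path="main.tex", segment_type="math",
                            title="Sample Title", abstract="Sample abstract.", section="Introduction",
                            is_appendix=False, importance=1.0)
    merged = ArxivFragment.merge_segments(intro, formula)
    # 'text' merged with 'math' keeps the more specific type; total_segments drops from 4 to 3
    print(merged.segment_type, merged.total_segments, merged.importance)  # -> math 3 1.0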

View file

@@ -0,0 +1,449 @@
import os
import re
import time
import aiohttp
import asyncio
import requests
import tarfile
import logging
from pathlib import Path
from typing import Generator, List, Tuple, Optional, Dict, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment
def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"):
"""
Save all fragments to a single structured markdown file
Args:
fragments: List of fragments
output_dir: Output directory
"""
from datetime import datetime
from pathlib import Path
import re
# Create the output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Generate the output file name
filename = f"fragments_{timestamp}.md"
file_path = output_path / filename
current_section = ""
section_count = {}  # Tracks the number of fragments per section
with open(file_path, "w", encoding="utf-8") as f:
# Write the document header
f.write("# Document Fragments Analysis\n\n")
f.write("## Overview\n")
f.write(f"- Total Fragments: {len(fragments)}\n")
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
# If a title and abstract are available, add them at the beginning
if fragments and (fragments[0].title or fragments[0].abstract):
f.write("\n## Paper Information\n")
if fragments[0].title:
f.write(f"### Title\n{fragments[0].title}\n")
if fragments[0].abstract:
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
# Generate the table of contents
f.write("\n## Table of Contents\n")
# First collect all section information
sections = {}
for fragment in fragments:
section = fragment.section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
# Write the table of contents entries
for section, section_fragments in sections.items():
clean_section = section.strip()
if not clean_section:
clean_section = "Uncategorized"
f.write(
f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n")
# Write the body content
f.write("\n## Content\n")
# Organize the content by section
for section, section_fragments in sections.items():
clean_section = section.strip() or "Uncategorized"
f.write(f"\n### {clean_section}\n")
# Write each fragment
for i, fragment in enumerate(section_fragments, 1):
f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n")
# Metadata
f.write("**Metadata:**\n")
f.write(f"- Type: {fragment.segment_type}\n")
f.write(f"- Length: {len(fragment.content)} chars\n")
f.write(f"- Importance: {fragment.importance:.2f}\n")
f.write(f"- Is Appendix: {fragment.is_appendix}\n")
f.write(f"- File: {fragment.rel_path}\n")
# Content
f.write("\n**Content:**\n")
f.write("```tex\n")
f.write(fragment.content)
f.write("\n```\n")
# Add a separator between fragments
if i < len(section_fragments):
f.write("\n---\n")
# Append statistics
f.write("\n## Statistics\n")
f.write("\n### Fragment Type Distribution\n")
type_stats = {}
for fragment in fragments:
type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1
for ftype, count in type_stats.items():
percentage = (count / len(fragments)) * 100
f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n")
# Length distribution
f.write("\n### Length Distribution\n")
lengths = [len(f.content) for f in fragments]
f.write(f"- Minimum: {min(lengths)} chars\n")
f.write(f"- Maximum: {max(lengths)} chars\n")
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
print(f"Fragments saved to: {file_path}")
return file_path
class ArxivSplitter:
"""Arxiv论文智能分割器"""
def __init__(self,
char_range: Tuple[int, int],
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
cache_ttl: int = 7 * 24 * 60 * 60):
"""
Initialize the splitter
Args:
char_range: Character count range (min, max)
root_dir: Cache root directory
proxies: Proxy settings
cache_ttl: Cache expiration time in seconds
"""
self.min_chars, self.max_chars = char_range
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.cache_ttl = cache_ttl
# Dynamically compute an appropriate number of worker threads
import multiprocessing
cpu_count = multiprocessing.cpu_count()
# Scale with the CPU core count, capped to avoid excessive concurrency
self.max_workers = min(32, cpu_count * 2)
# Initialize the TeX processor
self.tex_processor = TexProcessor(char_range)
# Configure logging
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
# 处理URL格式
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
# 移除版本号和其他后缀
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
Check whether the cache is valid, including a file integrity check
Args:
paper_dir: Path to the paper directory
Returns:
bool: True if the cache is valid, False otherwise
"""
if not paper_dir.exists():
return False
# Check that the required files exist in the directory
has_tex_files = False
has_main_tex = False
for file_path in paper_dir.rglob("*"):
if file_path.suffix == '.tex':
has_tex_files = True
content = self.tex_processor.read_file(str(file_path))
if content and r'\documentclass' in content:
has_main_tex = True
break
if not (has_tex_files and has_main_tex):
return False
# Check the cache age
cache_time = paper_dir.stat().st_mtime
if (time.time() - cache_time) < self.cache_ttl:
self.logger.info(f"Using valid cache for {paper_dir.name}")
return True
return False
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
"""
Download the paper asynchronously, with retries and temporary-file handling
Args:
arxiv_id: arXiv paper ID
paper_dir: Target directory path
Returns:
bool: True if the download succeeded, False otherwise
"""
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# Make sure the directory exists
paper_dir.mkdir(parents=True, exist_ok=True)
# Try downloading with ArxivDownloader first
try:
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
downloaded_dir = downloader.download_paper(arxiv_id)
if downloaded_dir:
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
return True
except Exception as e:
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
# If ArxivDownloader fails, fall back to the direct download URLs
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
max_retries = 3
retry_delay = 1  # Initial retry delay in seconds
for url in urls:
for attempt in range(max_retries):
try:
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.proxies.get('http')) as response:
if response.status == 200:
content = await response.read()
# Write to a temporary file
temp_tar_path.write_bytes(content)
try:
# Verify the tar file's integrity and extract it
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
# After a successful download, move the temporary file to its final location
temp_tar_path.rename(final_tar_path)
return True
except Exception as e:
self.logger.warning(f"Invalid tar file: {str(e)}")
if temp_tar_path.exists():
temp_tar_path.unlink()
except Exception as e:
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
await asyncio.sleep(retry_delay * (attempt + 1))  # Back off between retries (delay grows linearly)
continue
return False
def _process_tar_file(self, tar_path: Path, extract_path: Path):
"""处理tar文件的同步操作"""
with tarfile.open(tar_path, 'r:gz') as tar:
tar.testall() # 验证文件完整性
tar.extractall(path=extract_path) # 解压文件
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TeX文件"""
try:
content = self.tex_processor.read_file(file_path)
if not content:
return []
# Extract metadata
is_main = r'\documentclass' in content
title, abstract = "", ""
if is_main:
title, abstract = self.tex_processor.extract_metadata(content)
# Split the content
segments = self.tex_processor.split_content(content)
fragments = []
for i, (segment_content, section, is_appendix) in enumerate(segments):
if not segment_content.strip():
continue
segment_type = self.tex_processor.detect_segment_type(segment_content)
importance = self.tex_processor.calculate_importance(
segment_content, segment_type, is_main
)
fragments.append(ArxivFragment(
file_path=file_path,
content=segment_content,
segment_index=i,
total_segments=len(segments),
rel_path=str(Path(file_path).relative_to(self.root_dir)),
segment_type=segment_type,
title=title,
abstract=abstract,
section=section,
is_appendix=is_appendix,
importance=importance
))
return fragments
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return []
async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]:
"""处理ArXiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self.root_dir / arxiv_id
# Check the cache
if not self._check_cache(paper_dir):
paper_dir.mkdir(exist_ok=True)
if not await self.download_paper(arxiv_id, paper_dir):
raise RuntimeError(f"Failed to download paper {arxiv_id}")
# Find the main TeX file
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
if not main_tex:
raise RuntimeError(f"No main TeX file found in {paper_dir}")
# Collect all related TeX files
tex_files = self.tex_processor.resolve_includes(main_tex)
if not tex_files:
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
# Process all TeX files in parallel
fragments = []
chunk_size = max(1, len(tex_files) // self.max_workers)  # Number of files handled per worker
loop = asyncio.get_event_loop()
async def process_chunk(chunk_files):
chunk_fragments = []
for file_path in chunk_files:
try:
result = await loop.run_in_executor(None, self._process_single_tex, file_path)
chunk_fragments.extend(result)
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return chunk_fragments
# Split the files into chunks
file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)]
# Process each chunk asynchronously
chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks])
for result in chunk_results:
fragments.extend(result)
# Re-sort the fragments and recompute their indices
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
total_fragments = len(fragments)
for i, fragment in enumerate(fragments):
fragment.segment_index = i
fragment.total_segments = total_fragments
# Filter the fragments before returning
fragments = self.tex_processor.filter_fragments(fragments)
return fragments
except Exception as e:
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
raise
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
test_cases = [
{
"arxiv_id": "2411.03663",
"expected_title": "Large Language Models and Simple Scripts",
"min_fragments": 10,
},
# {
# "arxiv_id": "1805.10988",
# "expected_title": "RAG vs Fine-tuning",
# "min_fragments": 15,
# }
]
# Create a splitter instance
splitter = ArxivSplitter(
char_range=(800, 1800),
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# Save the fragments
output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_dir}")
# Content checks
for fragment in fragments:
# Length check
print(fragment.content)
print(len(fragment.content))
# Type check
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
raise
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())

View file

@@ -0,0 +1,395 @@
from dataclasses import dataclass, field
@dataclass
class LaTeXPatterns:
"""LaTeX模式存储类,用于集中管理所有LaTeX相关的正则表达式模式"""
special_envs = {
'math': [
# Basic math environments
r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
r'\$\$.*?\$\$',
r'\$[^$]+\$',
# Matrix environments
r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
# Array environments
r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
# Other math environments
r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
],
'table': [
# Basic table environments
r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
# Complex table environments
r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
# Custom table environments
r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
# Table note environments
r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
],
'figure': [
# Figure environments
r'\\begin{figure\*?}.*?\\end{figure\*?}',
r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
# Image inclusion commands
r'\\includegraphics(\[.*?\])?\{.*?\}',
# TikZ graphics environments
r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
# Other graphics environments
r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
],
'algorithm': [
# Algorithm environments
r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
# Code listing environments
r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
# Pseudocode environments
r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
],
'list': [
# List environments
r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
# Custom list environments
r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
],
'theorem': [
# Theorem-like environments
r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
# Other proof-related environments
r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
],
'box': [
# Text box environments
r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
# Emphasis environments
r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
],
'quote': [
# Quotation environments
r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
],
'bibliography': [
# Bibliography environments
r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
],
'index': [
# Index environments
r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
]
}
# Section patterns
section_patterns = [
# Basic sectioning commands
r'\\chapter\{([^}]+)\}',
r'\\section\{([^}]+)\}',
r'\\subsection\{([^}]+)\}',
r'\\subsubsection\{([^}]+)\}',
r'\\paragraph\{([^}]+)\}',
r'\\subparagraph\{([^}]+)\}',
# Starred variants (unnumbered)
r'\\chapter\*\{([^}]+)\}',
r'\\section\*\{([^}]+)\}',
r'\\subsection\*\{([^}]+)\}',
r'\\subsubsection\*\{([^}]+)\}',
r'\\paragraph\*\{([^}]+)\}',
r'\\subparagraph\*\{([^}]+)\}',
# Special sectioning commands
r'\\part\{([^}]+)\}',
r'\\part\*\{([^}]+)\}',
r'\\appendix\{([^}]+)\}',
# Front matter
r'\\frontmatter\{([^}]+)\}',
r'\\mainmatter\{([^}]+)\}',
r'\\backmatter\{([^}]+)\}',
# Table of contents related
r'\\tableofcontents',
r'\\listoffigures',
r'\\listoftables',
# Custom sectioning commands
r'\\addchap\{([^}]+)\}',  # KOMA-Script classes
r'\\addsec\{([^}]+)\}',  # KOMA-Script classes
r'\\minisec\{([^}]+)\}',  # KOMA-Script classes
# Sectioning commands with optional arguments
r'\\chapter\[([^]]+)\]\{([^}]+)\}',
r'\\section\[([^]]+)\]\{([^}]+)\}',
r'\\subsection\[([^]]+)\]\{([^}]+)\}'
]
# Include patterns
include_patterns = [
r'\\(input|include|subfile)\{([^}]+)\}'
]
metadata_patterns = {
# Title-related
'title': [
r'\\title\{([^}]+)\}',
r'\\Title\{([^}]+)\}',
r'\\doctitle\{([^}]+)\}',
r'\\subtitle\{([^}]+)\}',
r'\\chapter\*?\{([^}]+)\}',  # The first chapter may serve as the title
r'\\maketitle\s*\\section\*?\{([^}]+)\}'  # The first section may serve as the title
],
# Abstract-related
'abstract': [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract\{([^}]+)\}',
r'\\begin{摘要}(.*?)\\end{摘要}',
r'\\begin{Summary}(.*?)\\end{Summary}',
r'\\begin{synopsis}(.*?)\\end{synopsis}',
r'\\begin{abstracten}(.*?)\\end{abstracten}'  # English abstract
],
# Author information
'author': [
r'\\author\{([^}]+)\}',
r'\\Author\{([^}]+)\}',
r'\\authorinfo\{([^}]+)\}',
r'\\authors\{([^}]+)\}',
r'\\author\[([^]]+)\]\{([^}]+)\}',  # Author with additional information
r'\\begin{authors}(.*?)\\end{authors}'
],
# Date-related
'date': [
r'\\date\{([^}]+)\}',
r'\\Date\{([^}]+)\}',
r'\\submitdate\{([^}]+)\}',
r'\\publishdate\{([^}]+)\}',
r'\\revisiondate\{([^}]+)\}'
],
# Keywords
'keywords': [
r'\\keywords\{([^}]+)\}',
r'\\Keywords\{([^}]+)\}',
r'\\begin{keywords}(.*?)\\end{keywords}',
r'\\key\{([^}]+)\}',
r'\\begin{关键词}(.*?)\\end{关键词}'
],
# Institution/affiliation
'institution': [
r'\\institute\{([^}]+)\}',
r'\\institution\{([^}]+)\}',
r'\\affiliation\{([^}]+)\}',
r'\\organization\{([^}]+)\}',
r'\\department\{([^}]+)\}'
],
# Subject/field
'subject': [
r'\\subject\{([^}]+)\}',
r'\\Subject\{([^}]+)\}',
r'\\field\{([^}]+)\}',
r'\\discipline\{([^}]+)\}'
],
# Version information
'version': [
r'\\version\{([^}]+)\}',
r'\\revision\{([^}]+)\}',
r'\\release\{([^}]+)\}'
],
# License/copyright
'license': [
r'\\license\{([^}]+)\}',
r'\\copyright\{([^}]+)\}',
r'\\begin{license}(.*?)\\end{license}'
],
# Contact information
'contact': [
r'\\email\{([^}]+)\}',
r'\\phone\{([^}]+)\}',
r'\\address\{([^}]+)\}',
r'\\contact\{([^}]+)\}'
],
# Acknowledgments
'acknowledgments': [
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
r'\\acknowledgments\{([^}]+)\}',
r'\\thanks\{([^}]+)\}',
r'\\begin{致谢}(.*?)\\end{致谢}'
],
# Project/funding
'funding': [
r'\\funding\{([^}]+)\}',
r'\\grant\{([^}]+)\}',
r'\\project\{([^}]+)\}',
r'\\support\{([^}]+)\}'
],
# Classification/identifiers
'classification': [
r'\\classification\{([^}]+)\}',
r'\\serialnumber\{([^}]+)\}',
r'\\id\{([^}]+)\}',
r'\\doi\{([^}]+)\}'
],
# Language
'language': [
r'\\documentlanguage\{([^}]+)\}',
r'\\lang\{([^}]+)\}',
r'\\language\{([^}]+)\}'
]
}
latex_only_patterns = {
# Document class and package imports
r'\\documentclass(\[.*?\])?\{.*?\}',
r'\\usepackage(\[.*?\])?\{.*?\}',
# Common document setup commands
r'\\setlength\{.*?\}\{.*?\}',
r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
# Page setup
r'\\pagestyle\{.*?\}',
r'\\thispagestyle\{.*?\}',
# Other common setup commands
r'\\bibliographystyle\{.*?\}',
r'\\bibliography\{.*?\}',
r'\\setcounter\{.*?\}\{.*?\}',
# Font and text setup commands
r'\\makeFNbottom',
r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}',  # Match font size settings
r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
r'\\renewcommand\\footnoterule\{.*?\}',
r'\\color\{.*?\}',
# Page and section heading setup
r'\\setcounter\{secnumdepth\}\{.*?\}',
r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
r'\\renewcommand\{?\\figurename\}?\{.*?\}',
# Font style settings
r'\\sectionfont\{.*?\}',
r'\\subsectionfont\{.*?\}',
r'\\subsubsectionfont\{.*?\}',
# Spacing and layout settings
r'\\setstretch\{.*?\}',
r'\\setlength\{\\skip\\footins\}\{.*?\}',
r'\\setlength\{\\footnotesep\}\{.*?\}',
r'\\setlength\{\\jot\}\{.*?\}',
r'\\hrule\s+width\s+.*?\s+height\s+.*?',
# makeatletter and makeatother
r'\\makeatletter\s*',
r'\\makeatother\s*',
r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}',  # Footnotes containing a superscript
# r'\\footnotetext\{[^}]*\}',  # Plain footnotes
# r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}',  # Footnotes containing an email address
# r'\\footnotetext\{.*?(?:ESI|DOI).*?\}',  # Footnotes citing a DOI or ESI
# Document structure commands
r'\\begin\{document\}',
r'\\end\{document\}',
r'\\maketitle',
r'\\printbibliography',
r'\\newpage',
# Input file commands
r'\\input\{[^}]*\}',
r'\\input\{.*?\.tex\}',  # Specifically match inputs with a .tex suffix
# Footnote-related
# r'\\footnotetext\[\d+\]\{[^}]*\}',  # Numbered footnotes
# Acknowledgment environment
r'\\begin\{ack\}',
r'\\end\{ack\}',
r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}',  # Match the entire ack environment and its content
# Other document control commands
r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
}
math_envs = [
# Basic math environments
(r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'),  # Single display equation
(r'\\begin{align\*?}.*?\\end{align\*?}', 'align'),  # Multi-line aligned equations
(r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'),  # Multi-line centered equations
(r'\$\$.*?\$\$', 'display'),  # Display math
(r'\$.*?\$', 'inline'),  # Inline math
# Matrix environments
(r'\\begin{matrix}.*?\\end{matrix}', 'matrix'),  # Plain matrix
(r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'),  # Parenthesized matrix
(r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'),  # Bracketed matrix
(r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'),  # Single-bar (determinant) matrix
(r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'),  # Double-bar matrix
(r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'),  # Small matrix
# Array environments
(r'\\begin{array}.*?\\end{array}', 'array'),  # Array
(r'\\begin{cases}.*?\\end{cases}', 'cases'),  # Piecewise definitions
# Multi-line equation environments
(r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'),  # One equation split across lines
(r'\\begin{split}.*?\\end{split}', 'split'),  # Split a long equation
(r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'),  # Alignment with spacing control
(r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'),  # Fully left-aligned
# Special math environments
(r'\\begin{subequations}.*?\\end{subequations}', 'subequations'),  # Sub-equation numbering
(r'\\begin{gathered}.*?\\end{gathered}', 'gathered'),  # Centered group
(r'\\begin{aligned}.*?\\end{aligned}', 'aligned'),  # Inner aligned group
# Theorem-like environments
(r'\\begin{theorem}.*?\\end{theorem}', 'theorem'),  # Theorem
(r'\\begin{lemma}.*?\\end{lemma}', 'lemma'),  # Lemma
(r'\\begin{proof}.*?\\end{proof}', 'proof'),  # Proof
# Table environments inside math mode
(r'\\begin{tabular}.*?\\end{tabular}', 'tabular'),  # Table
(r'\\begin{array}.*?\\end{array}', 'array'),  # Array
# Other specialized math environments
(r'\\begin{CD}.*?\\end{CD}', 'CD'),  # Commutative diagram
(r'\\begin{boxed}.*?\\end{boxed}', 'boxed'),  # Boxed equation
(r'\\begin{empheq}.*?\\end{empheq}', 'empheq'),  # Emphasized equation
# Chemical equation environments (require the mhchem package)
(r'\\begin{reaction}.*?\\end{reaction}', 'reaction'),  # Chemical reaction
(r'\\ce\{.*?\}', 'chemequation'),  # Chemical equation
# Physical unit environments (require the siunitx package)
(r'\\SI\{.*?\}\{.*?\}', 'SI'),  # Quantity with unit
(r'\\si\{.*?\}', 'si'),  # Unit
# Additional environments
(r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'),  # breqn auto-breaking equation
(r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'),  # breqn display math
(r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'),  # breqn equation group
]
# Example usage
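# A minimal usage sketch: `tex_source` is a hypothetical LaTeX string, and re.DOTALL
# lets '.' span the multi-line bodies these environment patterns expect.
import re

def find_math_environments(tex_source: str) -> list:
    """Return the raw text of every fragment matched by LaTeXPatterns.special_envs['math']."""
    matches = []
    for pattern in LaTeXPatterns.special_envs['math']:
        for m in re.finditer(pattern, tex_source, re.DOTALL):
            matches.append(m.group(0))
    return matches

if __name__ == "__main__":
    sample = r"Euler's identity \begin{equation} e^{i\pi} + 1 = 0 \end{equation} and inline $a^2 + b^2 = c^2$."
    print(find_math_environments(sample))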

The diff for this file is too large to display.

View file

@@ -13,18 +13,19 @@ from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
T = TypeVar('T')
@dataclass
class StorageBase:
"""Base class for all storage implementations"""
namespace: str
working_dir: str
async def index_done_callback(self):
"""Hook called after indexing operations"""
pass
async def query_done_callback(self):
"""Hook called after query operations"""
"""Hook called after query operations"""
pass
@@ -32,37 +33,37 @@ class StorageBase:
class JsonKVStorage(StorageBase, Generic[T]):
"""
Key-Value storage using JSON files
Attributes:
namespace (str): Storage namespace
working_dir (str): Working directory for storage files
_file_name (str): JSON file path
_data (Dict[str, T]): In-memory storage
"""
def __post_init__(self):
"""Initialize storage file and load data"""
self._file_name = os.path.join(self.working_dir, f"kv_{self.namespace}.json")
self._file_name = os.path.join(self.working_dir, f"kv_store_{self.namespace}.json")
self._data: Dict[str, T] = {}
self.load()
def load(self):
"""Load data from JSON file"""
if os.path.exists(self._file_name):
with open(self._file_name, 'r', encoding='utf-8') as f:
self._data = json.load(f)
logger.info(f"Loaded {len(self._data)} items from {self._file_name}")
async def save(self):
"""Save data to JSON file"""
os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
with open(self._file_name, 'w', encoding='utf-8') as f:
json.dump(self._data, f, ensure_ascii=False, indent=2)
async def get_by_id(self, id: str) -> Optional[T]:
"""Get item by ID"""
return self._data.get(id)
async def get_by_ids(self, ids: List[str], fields: Optional[Set[str]] = None) -> List[Optional[T]]:
"""Get multiple items by IDs with optional field filtering"""
if fields is None:
@@ -70,16 +71,16 @@ class JsonKVStorage(StorageBase, Generic[T]):
return [{k: v for k, v in self._data[id].items() if k in fields}
if id in self._data else None
for id in ids]
async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Return keys that don't exist in storage"""
return set(k for k in keys if k not in self._data)
async def upsert(self, data: Dict[str, T]):
"""Insert or update items"""
self._data.update(data)
await self.save()
async def drop(self):
"""Clear all data"""
self._data = {}
@@ -95,148 +96,225 @@ class JsonKVStorage(StorageBase, Generic[T]):
await self.save()
@dataclass
class VectorStorage(StorageBase):
"""
Vector storage using LlamaIndex
Vector storage using LlamaIndexRagWorker
Attributes:
namespace (str): Storage namespace
namespace (str): Storage namespace (e.g., 'entities', 'relationships', 'chunks')
working_dir (str): Working directory for storage files
llm_kwargs (dict): LLM configuration
embedding_func (OpenAiEmbeddingModel): Embedding function
meta_fields (Set[str]): Additional fields to store
cosine_better_than_threshold (float): Similarity threshold
meta_fields (Set[str]): Additional metadata fields to store
"""
llm_kwargs: dict
embedding_func: OpenAiEmbeddingModel
meta_fields: Set[str] = field(default_factory=set)
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
"""Initialize LlamaIndex worker"""
checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}")
# Use the correct file naming format
self._vector_file = os.path.join(self.working_dir, f"vdb_{self.namespace}.json")
# Set up the checkpoint directory
checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}_checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize the vector store
self.vector_store = LlamaIndexRagWorker(
user_name=self.namespace,
llm_kwargs=self.llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True  # Automatically load the checkpoint
auto_load_checkpoint=True
)
async def query(self, query: str, top_k: int = 5) -> List[dict]:
logger.info(f"Initialized vector storage for {self.namespace}")
async def query(self, query: str, top_k: int = 5, metadata_filters: Optional[Dict[str, Any]] = None) -> List[dict]:
"""
Query vectors by similarity
Query vectors by similarity with optional metadata filtering
Args:
query: Query text
top_k: Maximum number of results
top_k: Maximum number of results to return
metadata_filters: Optional metadata filters
Returns:
List of similar documents with scores
"""
nodes = self.vector_store.retrieve_from_store_with_query(query)
results = [{
"id": node.node_id,
"text": node.text,
"score": node.score,
**{k: getattr(node, k) for k in self.meta_fields if hasattr(node, k)}
} for node in nodes[:top_k]]
return [r for r in results if r.get('score', 0) > self.cosine_better_than_threshold]
try:
if metadata_filters:
nodes = self.vector_store.retrieve_with_metadata_filter(query, metadata_filters, top_k)
else:
nodes = self.vector_store.retrieve_from_store_with_query(query)[:top_k]
results = []
for node in nodes:
result = {
"id": node.node_id,
"text": node.text,
"score": node.score if hasattr(node, 'score') else 0.0,
}
# Add metadata fields if they exist and are in meta_fields
if hasattr(node, 'metadata'):
result.update({
k: node.metadata[k]
for k in self.meta_fields
if k in node.metadata
})
results.append(result)
return results
except Exception as e:
logger.error(f"Error in vector query: {e}")
raise
async def upsert(self, data: Dict[str, dict]):
"""
Insert or update vectors
Args:
data: Dictionary of documents to insert/update
data: Dictionary of documents to insert/update with format:
{id: {"content": text, "metadata": dict}}
"""
for id, item in data.items():
content = item["content"]
metadata = {k: item[k] for k in self.meta_fields if k in item}
self.vector_store.add_text_with_metadata(content, metadata=metadata)
try:
for doc_id, item in data.items():
content = item["content"]
# Extract metadata
metadata = {
k: item[k]
for k in self.meta_fields
if k in item
}
# Add the document ID to the metadata
metadata["doc_id"] = doc_id
# Add to the vector store
self.vector_store.add_text_with_metadata(content, metadata)
# Export the vector data to a JSON file
self.vector_store.export_nodes(
self._vector_file,
format="json",
include_embeddings=True
)
except Exception as e:
logger.error(f"Error in vector upsert: {e}")
raise
async def save(self):
"""Save vector store to checkpoint and export data"""
try:
# Save the checkpoint
self.vector_store.save_to_checkpoint()
# Export the vector data
self.vector_store.export_nodes(
self._vector_file,
format="json",
include_embeddings=True
)
except Exception as e:
logger.error(f"Error saving vector storage: {e}")
raise
async def index_done_callback(self):
"""Save after indexing"""
self.vector_store.save_to_checkpoint()
await self.save()
def get_statistics(self) -> Dict[str, Any]:
"""Get vector store statistics"""
return self.vector_store.get_statistics()
@dataclass
class NetworkStorage(StorageBase):
"""
Graph storage using NetworkX
Attributes:
namespace (str): Storage namespace
working_dir (str): Working directory for storage files
"""
def __post_init__(self):
"""Initialize graph and storage file"""
self._file_name = os.path.join(self.working_dir, f"graph_{self.namespace}.graphml")
self._graph = self._load_graph() or nx.Graph()
logger.info(f"Initialized graph storage for {self.namespace}")
def _load_graph(self) -> Optional[nx.Graph]:
"""Load graph from GraphML file"""
if os.path.exists(self._file_name):
try:
return nx.read_graphml(self._file_name)
graph = nx.read_graphml(self._file_name)
logger.info(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
return graph
except Exception as e:
logger.error(f"Error loading graph from {self._file_name}: {e}")
return None
return None
async def save_graph(self):
"""Save graph to GraphML file"""
os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
logger.info(f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
nx.write_graphml(self._graph, self._file_name)
try:
os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
logger.info(
f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
nx.write_graphml(self._graph, self._file_name)
except Exception as e:
logger.error(f"Error saving graph: {e}")
raise
async def has_node(self, node_id: str) -> bool:
"""Check if node exists"""
return self._graph.has_node(node_id)
async def has_edge(self, source_id: str, target_id: str) -> bool:
"""Check if edge exists"""
return self._graph.has_edge(source_id, target_id)
async def get_node(self, node_id: str) -> Optional[dict]:
"""Get node attributes"""
if not self._graph.has_node(node_id):
return None
return dict(self._graph.nodes[node_id])
async def get_edge(self, source_id: str, target_id: str) -> Optional[dict]:
"""Get edge attributes"""
if not self._graph.has_edge(source_id, target_id):
return None
return dict(self._graph.edges[source_id, target_id])
async def node_degree(self, node_id: str) -> int:
"""Get node degree"""
return self._graph.degree(node_id)
async def edge_degree(self, source_id: str, target_id: str) -> int:
"""Get sum of degrees of edge endpoints"""
return self._graph.degree(source_id) + self._graph.degree(target_id)
async def get_node_edges(self, source_id: str) -> Optional[List[Tuple[str, str]]]:
"""Get all edges connected to node"""
if not self._graph.has_node(source_id):
return None
return list(self._graph.edges(source_id))
async def upsert_node(self, node_id: str, node_data: Dict[str, str]):
"""Insert or update node"""
# Clean and normalize node data
cleaned_data = {k: html.escape(str(v).upper().strip()) for k, v in node_data.items()}
self._graph.add_node(node_id, **cleaned_data)
await self.save_graph()
async def upsert_edge(self, source_id: str, target_id: str, edge_data: Dict[str, str]):
"""Insert or update edge"""
# Clean and normalize edge data
cleaned_data = {k: html.escape(str(v).strip()) for k, v in edge_data.items()}
self._graph.add_edge(source_id, target_id, **cleaned_data)
await self.save_graph()
async def index_done_callback(self):
"""Save after indexing"""
await self.save_graph()
@@ -245,47 +323,47 @@ class NetworkStorage(StorageBase):
"""Get the largest connected component of the graph"""
if not self._graph:
return nx.Graph()
components = list(nx.connected_components(self._graph))
if not components:
return nx.Graph()
largest_component = max(components, key=len)
return self._graph.subgraph(largest_component).copy()
async def embed_nodes(self, algorithm: str, **kwargs) -> Tuple[np.ndarray, List[str]]:
"""
Embed nodes using specified algorithm
Args:
algorithm: Node embedding algorithm name
**kwargs: Additional algorithm parameters
Returns:
Tuple of (node embeddings, node IDs)
"""
async def embed_nodes(
self,
algorithm: str = "node2vec",
dimensions: int = 128,
walk_length: int = 30,
num_walks: int = 200,
workers: int = 4,
window: int = 10,
min_count: int = 1,
**kwargs
) -> Tuple[np.ndarray, List[str]]:
"""Generate node embeddings using specified algorithm"""
if algorithm == "node2vec":
from node2vec import Node2Vec
# Create node2vec model
node2vec = Node2Vec(
# Create and train node2vec model
n2v = Node2Vec(
self._graph,
dimensions=kwargs.get('dimensions', 128),
walk_length=kwargs.get('walk_length', 30),
num_walks=kwargs.get('num_walks', 200),
workers=kwargs.get('workers', 4)
dimensions=dimensions,
walk_length=walk_length,
num_walks=num_walks,
workers=workers
)
# Train model
model = node2vec.fit(
window=kwargs.get('window', 10),
min_count=kwargs.get('min_count', 1)
model = n2v.fit(
window=window,
min_count=min_count
)
# Get embeddings
# Get embeddings for all nodes
node_ids = list(self._graph.nodes())
embeddings = np.array([model.wv[node] for node in node_ids])
return embeddings, node_ids
else:
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")

View file

@@ -23,25 +23,29 @@ class ExtractionExample:
def __init__(self):
"""Initialize RAG system components"""
# Set up the working directory
self.working_dir = f"private_upload/default_user/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
self.working_dir = f"crazy_functions/rag_fns/LightRAG/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
os.makedirs(self.working_dir, exist_ok=True)
logger.info(f"Working directory: {self.working_dir}")
# Initialize the embedding model
self.llm_kwargs = {'api_key': os.getenv("one_api_key"), 'client_ip': '127.0.0.1',
'embed_model': 'text-embedding-3-small', 'llm_model': 'one-api-Qwen2.5-72B-Instruct',
'max_length': 4096, 'most_recent_uploaded': None, 'temperature': 1, 'top_p': 1}
self.llm_kwargs = {
'api_key': os.getenv("one_api_key"),
'client_ip': '127.0.0.1',
'embed_model': 'text-embedding-3-small',
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
'max_length': 4096,
'most_recent_uploaded': None,
'temperature': 1,
'top_p': 1
}
self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs)
# Initialize the prompt templates and the extractor
self.prompt_templates = PromptTemplates()
self.extractor = EntityRelationExtractor(
prompt_templates=self.prompt_templates,
required_prompts = {
'entity_extraction'
},
required_prompts={'entity_extraction'},
entity_extract_max_gleaning=1
)
# Initialize the storage system
@@ -63,18 +67,33 @@ class ExtractionExample:
working_dir=self.working_dir
)
# Vector storage - used for similarity retrieval
self.vector_store = VectorStorage(
namespace="vectors",
# Vector storage - vector representations of entities, relationships, and text chunks
self.entities_vdb = VectorStorage(
namespace="entities",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func,
meta_fields={"entity_name", "entity_type"}
)
self.relationships_vdb = VectorStorage(
namespace="relationships",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}
)
self.chunks_vdb = VectorStorage(
namespace="chunks",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func
)
# Graph storage - used for entity relations
self.graph_store = NetworkStorage(
namespace="graph",
namespace="chunk_entity_relation",
working_dir=self.working_dir
)
@@ -152,7 +171,7 @@ class ExtractionExample:
try:
# Vector storage
logger.info("Adding chunks to vector store...")
await self.vector_store.upsert(chunks)
await self.chunks_vdb.upsert(chunks)
# Initialize the conversation history
self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()}
@@ -178,14 +197,32 @@ class ExtractionExample:
# Collect the results
nodes, edges = self.extractor.get_results()
# Store in the graph database
logger.info("Storing extracted information in graph database...")
# Store entities in the vector database and the graph database
for node_name, node_instances in nodes.items():
for node in node_instances:
# Store in the vector database
await self.entities_vdb.upsert({
f"entity_{node_name}": {
"content": f"{node_name}: {node['description']}",
"entity_name": node_name,
"entity_type": node['entity_type']
}
})
# Store in the graph database
await self.graph_store.upsert_node(node_name, node)
# Store relationships in the vector database and the graph database
for (src, tgt), edge_instances in edges.items():
for edge in edge_instances:
# Store in the vector database
await self.relationships_vdb.upsert({
f"rel_{src}_{tgt}": {
"content": f"{edge['description']} | {edge['keywords']}",
"src_id": src,
"tgt_id": tgt
}
})
# Store in the graph database
await self.graph_store.upsert_edge(src, tgt, edge)
return nodes, edges
@@ -197,26 +234,39 @@ class ExtractionExample:
async def query_knowledge_base(self, query: str, top_k: int = 5):
"""Query the knowledge base using various methods"""
try:
# Vector similarity search
vector_results = await self.vector_store.query(query, top_k=top_k)
# Vector similarity search - text chunks
chunk_results = await self.chunks_vdb.query(query, top_k=top_k)
# Vector similarity search - entities
entity_results = await self.entities_vdb.query(query, top_k=top_k)
# Fetch the related text chunks
chunk_ids = [r["id"] for r in vector_results]
chunk_ids = [r["id"] for r in chunk_results]
chunks = await self.text_chunks.get_by_ids(chunk_ids)
# Fetch the related entities
# Assumes the query contains entity names
relevant_nodes = []
for word in query.split():
if await self.graph_store.has_node(word.upper()):
node_data = await self.graph_store.get_node(word.upper())
if node_data:
relevant_nodes.append(node_data)
# Fetch graph-structure information related to the entities
relevant_edges = []
for entity in entity_results:
if "entity_name" in entity:
entity_name = entity["entity_name"]
if await self.graph_store.has_node(entity_name):
edges = await self.graph_store.get_node_edges(entity_name)
if edges:
edge_data = []
for edge in edges:
edge_info = await self.graph_store.get_edge(edge[0], edge[1])
if edge_info:
edge_data.append({
"source": edge[0],
"target": edge[1],
"data": edge_info
})
relevant_edges.extend(edge_data)
return {
"vector_results": vector_results,
"text_chunks": chunks,
"relevant_entities": relevant_nodes
"chunks": chunks,
"entities": entity_results,
"relationships": relevant_edges
}
except Exception as e:
@@ -228,30 +278,27 @@ class ExtractionExample:
os.makedirs(export_dir, exist_ok=True)
try:
# Export the vector store
self.vector_store.vector_store.export_nodes(
os.path.join(export_dir, "vector_nodes.json"),
include_embeddings=True
)
# Export graph statistics
graph_stats = {
"total_nodes": len(list(self.graph_store._graph.nodes())),
"total_edges": len(list(self.graph_store._graph.edges())),
"node_degrees": dict(self.graph_store._graph.degree()),
"largest_component_size": len(self.graph_store.get_largest_connected_component())
}
with open(os.path.join(export_dir, "graph_stats.json"), "w") as f:
json.dump(graph_stats, f, indent=2)
# Export storage statistics
# Export statistics
storage_stats = {
"chunks": len(self.text_chunks._data),
"docs": len(self.full_docs._data),
"vector_store": self.vector_store.vector_store.get_statistics()
"chunks": {
"total": len(self.text_chunks._data),
"vector_stats": self.chunks_vdb.get_statistics()
},
"entities": {
"vector_stats": self.entities_vdb.get_statistics()
},
"relationships": {
"vector_stats": self.relationships_vdb.get_statistics()
},
"graph": {
"total_nodes": len(list(self.graph_store._graph.nodes())),
"total_edges": len(list(self.graph_store._graph.edges())),
"node_degrees": dict(self.graph_store._graph.degree()),
"largest_component_size": len(self.graph_store.get_largest_connected_component())
}
}
# Export the statistics
with open(os.path.join(export_dir, "storage_stats.json"), "w") as f:
json.dump(storage_stats, f, indent=2)
@@ -299,19 +346,6 @@ async def main():
the company's commitment to innovation and sustainability. The new iPhone
features groundbreaking AI capabilities.
""",
# "business_news": """
# Microsoft and OpenAI expanded their partnership today.
# Satya Nadella emphasized the importance of AI development while
# Sam Altman discussed the future of large language models. The collaboration
# aims to accelerate AI research and deployment.
# """,
#
# "science_paper": """
# Researchers at DeepMind published a breakthrough paper on quantum computing.
# The team demonstrated novel approaches to quantum error correction.
# Dr. Sarah Johnson led the research, collaborating with Google's quantum lab.
# """
}
try: