This commit is contained in:
lbykkkk
2024-11-10 15:06:50 +08:00
Parent b8617921f4
Current commit 68aa846a89
12 files changed, 456 insertions and 287 deletions

View File

@@ -1,6 +1,7 @@
import os.path
from toolbox import CatchException, update_ui
from crazy_functions.rag_essay_fns.paper_processing import ArxivPaperProcessor
from crazy_functions.rag_essay_fns.rag_handler import RagHandler
import asyncio
@CatchException
@@ -9,7 +10,24 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
txt: user input, usually an arxiv paper link
Function: RAG-based paper summarization and conversation
"""
# Choose and initialize a document splitter based on the input type
if_project, if_arxiv = False, False
if os.path.exists(txt):
from crazy_functions.rag_essay_fns.document_splitter import SmartDocumentSplitter
splitter = SmartDocumentSplitter(
char_range=(1000, 1200),
max_workers=32  # optional; by default it is set automatically from the CPU core count
)
if_project = True
else:
from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter
splitter = SmartArxivSplitter(
char_range=(1000, 1200),
root_dir="gpt_log/arxiv_cache"
)
if_arxiv = True
for fragment in splitter.process(txt):
pass
# Initialize the paper processor and the RAG handler
processor = ArxivPaperProcessor()
rag_handler = RagHandler()
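The plugin currently only iterates the fragments; below is a hedged wiring sketch (not the author's final implementation) of how the splitter output could feed the RagHandler added later in this commit. It reuses only names that appear elsewhere in this diff (ArxivFragment.title/.abstract/.content, RagHandler.process_paper_content and .query):

    # Hypothetical wiring: collect fragments, index them, then ask a question.
    fragments = list(splitter.process(txt))
    paper_content = {
        "title": fragments[0].title if fragments else "",
        "abstract": fragments[0].abstract if fragments else "",
        "segments": [f.content for f in fragments],
    }
    rag_handler.process_paper_content(paper_content)
    answer = rag_handler.query("Summarize the main contributions.", mode="hybrid")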

View File

@@ -152,8 +152,6 @@ class Conversation_To_File_Wrap(GptAcademicPluginTemplate):
def hide_cwd(str):
import os
current_path = os.getcwd()

View File

@@ -20,9 +20,7 @@ class ArxivFragment:
segment_type: str
title: str
abstract: str
section: str  # stores the full section hierarchy path, e.g. "Introduction" or "Methods-Data Processing"
section_type: str  # new: marks the fragment type, e.g. "abstract", "section", "subsection"
section_level: int  # new: section depth; abstract is 0, a main section is 1, a subsection is 2, and so on
section: str
is_appendix: bool
@@ -116,6 +114,100 @@ class SmartArxivSplitter:
return result
def _smart_split(self, content: str) -> List[Tuple[str, str, bool]]:
"""智能分割TEX内容,确保在字符范围内并保持语义完整性"""
content = self._preprocess_content(content)
segments = []
current_buffer = []
current_length = 0
current_section = "Unknown Section"
is_appendix = False
# Protect special environments
protected_blocks = {}
content = self._protect_special_environments(content, protected_blocks)
# Split by paragraphs
paragraphs = re.split(r'\n\s*\n', content)
for para in paragraphs:
para = para.strip()
if not para:
continue
# Restore special environments
para = self._restore_special_environments(para, protected_blocks)
# Update the current section info
section_info = self._get_section_info(para, content)
if section_info:
current_section, is_appendix = section_info
# Check whether this paragraph is a special environment
if self._is_special_environment(para):
# Flush the current buffer
if current_buffer:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = []
current_length = 0
# Add the special environment as a standalone segment
segments.append((para, current_section, is_appendix))
continue
# Process an ordinary paragraph
sentences = self._split_into_sentences(para)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sent_length = len(sentence)
new_length = current_length + sent_length + (1 if current_buffer else 0)
if new_length <= self.max_chars:
current_buffer.append(sentence)
current_length = new_length
else:
# If the current buffer already meets the minimum length requirement
if current_length >= self.min_chars:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = [sentence]
current_length = sent_length
else:
# Try to split an overly long sentence
split_sentences = self._split_long_sentence(sentence)
for split_sent in split_sentences:
if current_length + len(split_sent) <= self.max_chars:
current_buffer.append(split_sent)
current_length += len(split_sent) + 1
else:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = [split_sent]
current_length = len(split_sent)
# Flush any remaining buffer
if current_buffer:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
return segments
def _split_into_sentences(self, text: str) -> List[str]:
"""将文本分割成句子"""
return re.split(r'(?<=[.!?。!?])\s+', text)
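    # Hedged illustration (not part of the commit): the lookbehind splits after
    # sentence-ending punctuation (Latin or CJK) that is followed by whitespace, e.g.
    #   _split_into_sentences("First sentence. Second one! Third?")
    #   -> ['First sentence.', 'Second one!', 'Third?']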
@@ -194,7 +286,7 @@ class SmartArxivSplitter:
content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content)
return content.strip()
def process_paper(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
def process(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
"""处理单篇arxiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
@@ -318,31 +410,16 @@ class SmartArxivSplitter:
return title.strip(), abstract.strip()
def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, str, int, bool]]:
"""获取段落所属的章节信息,返回(section_path, section_type, level, is_appendix)"""
current_path = []
section_type = "content"
level = 0
def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]:
"""获取段落所属的章节信息"""
section = "Unknown Section"
is_appendix = False
# Regex patterns for section levels
section_patterns = {
r'\\chapter\{([^}]+)\}': 1,
r'\\section\{([^}]+)\}': 1,
r'\\subsection\{([^}]+)\}': 2,
r'\\subsubsection\{([^}]+)\}': 3
}
# Find all section markers
all_sections = []
for pattern, sec_level in section_patterns.items():
for pattern in self.section_patterns:
for match in re.finditer(pattern, content):
all_sections.append((match.start(), match.group(1), sec_level))
# Check whether the paragraph is the abstract
abstract_match = re.search(r'\\begin{abstract}.*?' + re.escape(para), content, re.DOTALL)
if abstract_match:
return "Abstract", "abstract", 0, False
all_sections.append((match.start(), match.group(2)))
# Find the appendix marker
appendix_pos = content.find(r'\appendix')
@@ -350,118 +427,19 @@ class SmartArxivSplitter:
# Determine the current section
para_pos = content.find(para)
if para_pos >= 0:
is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
current_sections = []
current_level = 0
# Sort all section markers by position
for sec_pos, sec_title, sec_level in sorted(all_sections):
current_section = None
for sec_pos, sec_title in sorted(all_sections):
if sec_pos > para_pos:
break
# If a higher-level section is encountered, clear all lower-level sections
if sec_level <= current_level:
current_sections = [s for s in current_sections if s[1] < sec_level]
current_sections.append((sec_title, sec_level))
current_level = sec_level
current_section = sec_title
# Build the section path
if current_sections:
current_path = [s[0] for s in sorted(current_sections, key=lambda x: x[1])]
section_path = "-".join(current_path)
level = max(s[1] for s in current_sections)
section_type = "section" if level == 1 else "subsection"
return section_path, section_type, level, is_appendix
if current_section:
section = current_section
is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
return "Unknown Section", "content", 0, is_appendix
return section, is_appendix
def _smart_split(self, content: str) -> List[Tuple[str, str, str, int, bool]]:
"""智能分割TEX内容,确保在字符范围内并保持语义完整性"""
content = self._preprocess_content(content)
segments = []
current_buffer = []
current_length = 0
current_section_info = ("Unknown Section", "content", 0, False)
# Protect special environments
protected_blocks = {}
content = self._protect_special_environments(content, protected_blocks)
# Split by paragraphs
paragraphs = re.split(r'\n\s*\n', content)
for para in paragraphs:
para = para.strip()
if not para:
continue
# Restore special environments
para = self._restore_special_environments(para, protected_blocks)
# Update the current section info
section_info = self._get_section_info(para, content)
if section_info:
current_section_info = section_info
# Check whether this paragraph is a special environment
if self._is_special_environment(para):
# Flush the current buffer
if current_buffer:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = []
current_length = 0
# Add the special environment as a standalone segment
segments.append((para, *current_section_info))
continue
# Process an ordinary paragraph
sentences = self._split_into_sentences(para)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sent_length = len(sentence)
new_length = current_length + sent_length + (1 if current_buffer else 0)
if new_length <= self.max_chars:
current_buffer.append(sentence)
current_length = new_length
else:
# If the current buffer already meets the minimum length requirement
if current_length >= self.min_chars:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = [sentence]
current_length = sent_length
else:
# Try to split an overly long sentence
split_sentences = self._split_long_sentence(sentence)
for split_sent in split_sentences:
if current_length + len(split_sent) <= self.max_chars:
current_buffer.append(split_sent)
current_length += len(split_sent) + 1
else:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
current_buffer = [split_sent]
current_length = len(split_sent)
# Flush any remaining buffer
if current_buffer:
segments.append((
'\n'.join(current_buffer),
*current_section_info
))
return segments
return None
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TEX文件"""
@@ -481,12 +459,12 @@ class SmartArxivSplitter:
segments = self._smart_split(content)
fragments = []
for i, (segment_content, section_path, section_type, level, is_appendix) in enumerate(segments):
for i, (segment_content, section, is_appendix) in enumerate(segments):
if segment_content.strip():
segment_type = 'text'
for env_type, patterns in self.special_envs.items():
if any(re.search(pattern, segment_content, re.DOTALL)
for pattern in patterns):
for pattern in patterns):
segment_type = env_type
break
@@ -499,9 +477,7 @@ class SmartArxivSplitter:
segment_type=segment_type,
title=title,
abstract=abstract,
section=section_path,
section_type=section_type,
section_level=level,
section=section,
is_appendix=is_appendix
))
@@ -511,7 +487,6 @@ class SmartArxivSplitter:
logging.error(f"Error processing file {file_path}: {e}")
return []
def main():
"""使用示例"""
# 创建分割器实例
@@ -521,10 +496,11 @@ def main():
)
# Process a paper
for fragment in splitter.process_paper("2411.03663"):
for fragment in splitter.process("2411.03663"):
print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}")
print(f"Length: {len(fragment.content)}")
print(f"Section: {fragment.section}")
print(f"Title: {fragment.file_path}")
print(fragment.content)
print("-" * 80)

View File

@@ -1,7 +1,6 @@
from typing import Tuple, Optional, Generator, List
from toolbox import update_ui, update_ui_lastest_msg, get_conf
import os, tarfile, requests, time, re
class ArxivPaperProcessor:
"""Arxiv论文处理器类"""

View File

@@ -81,5 +81,84 @@ class RagHandler:
)
)
return response
except Exception as e:
return f"查询出错: {str(e)}"
class RagHandler:
def __init__(self):
# Initialize the working directory
self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
if not os.path.exists(self.working_dir):
os.makedirs(self.working_dir)
# Initialize LightRAG
self.rag = LightRAG(
working_dir=self.working_dir,
llm_model_func=self._llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536,  # OpenAI embedding dimension
max_token_size=8192,
func=self._embedding_func,
),
)
async def _llm_model_func(self, prompt: str, system_prompt: str = None,
history_messages: List = None, **kwargs) -> str:
"""LLM 模型函数"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await openai.ChatCompletion.acreate(
model="gpt-3.5-turbo",
messages=messages,
temperature=kwargs.get("temperature", 0),
max_tokens=kwargs.get("max_tokens", 1000)
)
return response.choices[0].message.content
async def _embedding_func(self, texts: List[str]) -> np.ndarray:
"""Embedding 函数"""
response = await openai.Embedding.acreate(
model="text-embedding-ada-002",
input=texts
)
embeddings = [item["embedding"] for item in response["data"]]
return np.array(embeddings)
def process_paper_content(self, paper_content: Dict) -> None:
"""处理论文内容,构建知识图谱"""
# 处理标题和摘要
content_list = []
if paper_content['title']:
content_list.append(f"Title: {paper_content['title']}")
if paper_content['abstract']:
content_list.append(f"Abstract: {paper_content['abstract']}")
# Add the segmented content
content_list.extend(paper_content['segments'])
# Insert into the RAG system
self.rag.insert(content_list)
def query(self, question: str, mode: str = "hybrid") -> str:
"""查询论文内容
mode: 查询模式,可选 naive/local/global/hybrid
"""
try:
response = self.rag.query(
question,
param=QueryParam(
mode=mode,
top_k=5,  # return the 5 most relevant results
max_token_for_text_unit=2048,  # maximum number of tokens per text unit
response_type="detailed"  # return a detailed answer
)
)
return response
except Exception as e:
return f"查询出错: {str(e)}"

View File

@@ -1,10 +1,10 @@
import atexit
from loguru import logger
from typing import List
from typing import List, Dict, Optional, Any, Tuple
from llama_index.core import Document
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import TextNode
from llama_index.core.schema import TextNode, NodeWithScore
from loguru import logger
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -135,4 +135,170 @@ class LlamaIndexRagWorker(SaveLoad):
"""
Purges the current vector store and creates a new one.
"""
self.purge()
self.purge()
"""
The methods below are newly added; the existing methods are unchanged.
"""
def add_text_with_metadata(self, text: str, metadata: dict) -> str:
"""
Add a text with metadata to the vector store
Args:
text: the text content
metadata: a metadata dict
Returns:
The ID of the added node
"""
node = TextNode(text=text, metadata=metadata)
nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(nodes)
return nodes[0].node_id if nodes else None
def batch_add_texts_with_metadata(self, texts: List[Tuple[str, dict]]) -> List[str]:
"""
Batch-add texts with metadata
Args:
texts: a list of (text, metadata) tuples
Returns:
A list of the added node IDs
"""
nodes = [TextNode(text=t, metadata=m) for t, m in texts]
transformed_nodes = run_transformations(
nodes,
self.vs_index._transformations,
show_progress=True
)
if transformed_nodes:
self.vs_index.insert_nodes(transformed_nodes)
return [node.node_id for node in transformed_nodes]
return []
def get_node_metadata(self, node_id: str) -> Optional[dict]:
"""
Get a node's metadata
Args:
node_id: the node ID
Returns:
The node's metadata dict
"""
node = self.vs_index.storage_context.docstore.docs.get(node_id)
return node.metadata if node else None
def update_node_metadata(self, node_id: str, metadata: dict, merge: bool = True) -> bool:
"""
Update a node's metadata
Args:
node_id: the node ID
metadata: the new metadata
merge: whether to merge with the existing metadata
Returns:
Whether the update succeeded
"""
docstore = self.vs_index.storage_context.docstore
if node_id in docstore.docs:
node = docstore.docs[node_id]
if merge:
node.metadata.update(metadata)
else:
node.metadata = metadata
return True
return False
def filter_nodes_by_metadata(self, filters: Dict[str, Any]) -> List[TextNode]:
"""
Filter nodes by metadata
Args:
filters: metadata filter conditions
Returns:
A list of nodes matching the filters
"""
docstore = self.vs_index.storage_context.docstore
results = []
for node in docstore.docs.values():
if all(node.metadata.get(k) == v for k, v in filters.items()):
results.append(node)
return results
def retrieve_with_metadata_filter(
self,
query: str,
metadata_filters: Dict[str, Any],
top_k: int = 5
) -> List[NodeWithScore]:
"""
Retrieval combined with metadata filtering
Args:
query: the query text
metadata_filters: metadata filter conditions
top_k: number of results to return
Returns:
A list of retrieved nodes
"""
retriever = self.vs_index.as_retriever(similarity_top_k=top_k)
nodes = retriever.retrieve(query)
# Apply the metadata filters
filtered_nodes = []
for node in nodes:
if all(node.metadata.get(k) == v for k, v in metadata_filters.items()):
filtered_nodes.append(node)
return filtered_nodes
def get_node_stats(self, node_id: str) -> dict:
"""
Get statistics for a single node
Args:
node_id: the node ID
Returns:
A dict of node statistics
"""
node = self.vs_index.storage_context.docstore.docs.get(node_id)
if not node:
return {}
return {
"text_length": len(node.text),
"token_count": len(node.text.split()),
"has_embedding": node.embedding is not None,
"metadata_keys": list(node.metadata.keys()),
}
def get_nodes_by_content_pattern(self, pattern: str) -> List[TextNode]:
"""
Find nodes whose content matches a pattern
Args:
pattern: a regular expression pattern
Returns:
A list of matching nodes
"""
import re
docstore = self.vs_index.storage_context.docstore
matched_nodes = []
for node in docstore.docs.values():
if re.search(pattern, node.text):
matched_nodes.append(node)
return matched_nodes
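A hedged usage sketch of the new metadata-aware methods. It assumes `worker` is an already-constructed LlamaIndexRagWorker (its constructor is not shown in this diff); all literal values are illustrative:

    # Add one fragment with section metadata, inspect it, then retrieve with a filter.
    node_id = worker.add_text_with_metadata(
        "Transformers rely on self-attention.",
        {"section": "Introduction", "is_appendix": False},
    )
    print(worker.get_node_stats(node_id))
    hits = worker.retrieve_with_metadata_filter(
        query="What does the introduction cover?",
        metadata_filters={"section": "Introduction"},
        top_k=3,
    )
    for hit in hits:
        print(hit.score, hit.metadata.get("section"))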

View File

@@ -1,22 +1,45 @@
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx']
supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
                   '.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
def read_docx_doc(file_path):
if file_path.split(".")[-1] == "docx":
from docx import Document
doc = Document(file_path)
file_content = "\n".join([para.text for para in doc.paragraphs])
else:
try:
import win32com.client
word = win32com.client.Dispatch("Word.Application")
word.Visible = False
# Open the .doc file through the Word COM interface (Windows only)
doc = word.Documents.Open(os.path.join(os.getcwd(), file_path))
file_content = doc.Range().Text
doc.Close()
word.Quit()
except Exception:
raise RuntimeError('请先将.doc文档转换为.docx文档。')
return file_content
# Modified extract_text: combines SimpleDirectoryReader with the custom parsing logic above
def extract_text(file_path):
_, ext = os.path.splitext(file_path.lower())
# Use SimpleDirectoryReader for the formats it supports
if ext in supports_format:
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
if ext in ['.docx', '.doc']:
return read_docx_doc(file_path)
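    # Last-resort fallback: retry SimpleDirectoryReader for anything not handled above.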
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
return None
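A quick usage sketch of extract_text; the file path is only illustrative:

    # Returns the extracted plain text, or None if nothing could be parsed.
    text = extract_text("gpt_log/example_paper.pdf")
    if text is not None:
        print(text[:500])
    else:
        print("Unsupported or unreadable file.")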