这个提交包含在:
lbykkkk
2024-11-17 23:15:34 +08:00
父节点 21626a44d5
当前提交 cbef9a908c
共有 3 个文件被更改,包括 512 次插入314 次删除

查看文件

@@ -15,7 +15,7 @@ def get_crazy_functions():
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
from crazy_functions.Arxiv_论文对话 import Rag论文对话
from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
from crazy_functions.高级功能函数模板 import Demo_Wrap
from crazy_functions.Latex全文润色 import Latex英文润色
@@ -31,6 +31,8 @@ def get_crazy_functions():
from crazy_functions.Markdown_Translate import Markdown英译中
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
from crazy_functions.PDF_Translate import 批量翻译PDF文档
from crazy_functions.批量文件询问 import 批量文件询问
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
from crazy_functions.Latex全文润色 import Latex中文润色
@@ -74,12 +76,25 @@ def get_crazy_functions():
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
"Rag论文对话": {
"批量文件询问": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Rag论文对话), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"AdvancedArgs": True,
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
"ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例“请总结下面的内容”,此时,文档内容将添加在“”后 ",
"Function": HotReload(批量文件询问),
},
"Arxiv论文对话": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
"ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例“这篇文章的方法是什么",
"Function": HotReload(Arxiv论文对话),
},
"翻译README或MD": {
"Group": "编程",
@@ -604,6 +619,23 @@ def get_crazy_functions():
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
try:
from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
function_plugins.update(
{
"Arxiv论文对话": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info": "将问答数据记录到向量库中,作为长期参考。",
"Function": HotReload(Arxiv论文对话),
},
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")

查看文件

@@ -1,63 +1,416 @@
import os.path
import os
import logging
import asyncio
from pathlib import Path
from typing import List, Optional, Generator, Dict, Union
from datetime import datetime
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
import aiohttp
from toolbox import CatchException, update_ui
from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor
from shared_utils.fastapi_server import validate_path_safety
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
# 全局常量配置
MAX_HISTORY_ROUND = 5 # 最大历史对话轮数
MAX_CONTEXT_TOKEN_LIMIT = 4096 # 上下文最大token数
REMEMBER_PREVIEW = 1000 # 记忆预览长度
VECTOR_STORE_TYPE = "Simple" # 向量存储类型Simple或Milvus
MAX_CONCURRENT_PAPERS = 5 # 最大并行处理论文数
MAX_WORKERS = 3 # 最大工作线程数
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class ProcessingTask:
"""论文处理任务数据类"""
arxiv_id: str
status: str = "pending" # pending, processing, completed, failed
error: Optional[str] = None
fragments: List[Fragment] = None
class ArxivRagWorker:
def __init__(self, user_name: str, llm_kwargs: Dict):
self.user_name = user_name
self.llm_kwargs = llm_kwargs
# 初始化存储目录
self.checkpoint_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
self.vector_store_dir = self.checkpoint_dir / "vector_store"
self.fragment_store_dir = self.checkpoint_dir / "fragments"
# 创建必要的目录
self.vector_store_dir.mkdir(parents=True, exist_ok=True)
self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Vector store directory: {self.vector_store_dir}")
logger.info(f"Fragment store directory: {self.fragment_store_dir}")
# 初始化RAG worker
self.rag_worker = LlamaIndexRagWorker(
user_name=user_name,
llm_kwargs=llm_kwargs,
checkpoint_dir=str(self.vector_store_dir),
auto_load_checkpoint=True
)
# 初始化arxiv splitter
self.arxiv_splitter = ArxivSplitter(
char_range=(1000, 1200),
root_dir=str(self.checkpoint_dir / "arxiv_cache")
)
# 初始化并行处理组件
self.processing_queue = {}
self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS)
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
async def _process_fragments(self, fragments: List[Fragment]) -> None:
"""并行处理论文片段"""
if not fragments:
logger.warning("No fragments to process")
return
# 首先添加论文概述
overview = {
"title": fragments[0].title,
"abstract": fragments[0].abstract,
"arxiv_id": fragments[0].arxiv_id,
}
overview_text = (
f"Paper Title: {overview['title']}\n"
f"ArXiv ID: {overview['arxiv_id']}\n"
f"Abstract: {overview['abstract']}\n"
f"Type: OVERVIEW"
)
try:
# 同步添加概述
self.rag_worker.add_text_to_vector_store(overview_text)
logger.info(f"Added paper overview for {overview['arxiv_id']}")
# 并行处理其余片段
tasks = []
for i, fragment in enumerate(fragments):
task = asyncio.get_event_loop().run_in_executor(
self.thread_pool,
self._process_single_fragment,
fragment,
i
)
tasks.append(task)
await asyncio.gather(*tasks)
logger.info(f"Processed {len(fragments)} fragments successfully")
# 保存到本地文件用于调试
save_fragments_to_file(
fragments,
str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
)
except Exception as e:
logger.error(f"Error processing fragments: {str(e)}")
raise
def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
"""处理单个论文片段"""
try:
text = (
f"Paper Title: {fragment.title}\n"
f"ArXiv ID: {fragment.arxiv_id}\n"
f"Section: {fragment.section}\n"
f"Fragment Index: {index}\n"
f"Content: {fragment.content}\n"
f"Type: FRAGMENT"
)
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
self.rag_worker.add_text_to_vector_store(text)
logger.info(f"Successfully added fragment {index} to vector store")
except Exception as e:
logger.error(f"Error processing fragment {index}: {str(e)}")
raise
async def process_paper(self, arxiv_id: str) -> bool:
"""处理论文主函数"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id)
logger.info(f"Starting to process paper: {arxiv_id}")
paper_path = self.checkpoint_dir / f"{arxiv_id}.processed"
if paper_path.exists():
logger.info(f"Paper {arxiv_id} already processed")
return True
# 创建处理任务
task = ProcessingTask(arxiv_id=arxiv_id)
self.processing_queue[arxiv_id] = task
task.status = "processing"
async with self.semaphore:
# 下载和分割论文
fragments = await self.arxiv_splitter.process(arxiv_id)
if not fragments:
raise ValueError(f"No fragments extracted from paper {arxiv_id}")
logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
# 处理片段
await self._process_fragments(fragments)
# 标记完成
paper_path.touch()
task.status = "completed"
task.fragments = fragments
logger.info(f"Successfully processed paper {arxiv_id}")
return True
except Exception as e:
logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
if arxiv_id in self.processing_queue:
self.processing_queue[arxiv_id].status = "failed"
self.processing_queue[arxiv_id].error = str(e)
return False
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
"""等待论文处理完成"""
try:
start_time = datetime.now()
while True:
task = self.processing_queue.get(arxiv_id)
if not task:
return False
if task.status == "completed":
return True
if task.status == "failed":
return False
# 检查超时
if (datetime.now() - start_time).total_seconds() > timeout:
logger.error(f"Processing paper {arxiv_id} timed out")
return False
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
return False
def retrieve_and_generate(self, query: str) -> str:
"""检索相关内容并生成提示词"""
try:
nodes = self.rag_worker.retrieve_from_store_with_query(query)
return self.rag_worker.build_prompt(query=query, nodes=nodes)
except Exception as e:
logger.error(f"Error in retrieve and generate: {str(e)}")
return ""
def remember_qa(self, question: str, answer: str) -> None:
"""记忆问答对"""
try:
self.rag_worker.remember_qa(question, answer)
except Exception as e:
logger.error(f"Error remembering QA: {str(e)}")
async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None:
"""自动分析论文的关键问题"""
key_questions = [
"What is the main research question or problem addressed in this paper?",
"What methods or approaches did the authors use to investigate the problem?",
"What are the key findings or results presented in the paper?",
"How do the findings of this paper contribute to the broader field or topic of study?",
"What are the limitations of this study, and what future research directions do the authors suggest?"
]
results = []
for question in key_questions:
try:
prompt = self.retrieve_and_generate(question)
if prompt:
response = await request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt,
inputs_show_user=question,
llm_kwargs=self.llm_kwargs,
chatbot=chatbot,
history=history,
sys_prompt=system_prompt
)
results.append(f"Q: {question}\nA: {response}\n")
self.remember_qa(question, response)
except Exception as e:
logger.error(f"Error in auto analysis: {str(e)}")
# 合并所有结果
summary = "\n\n".join(results)
chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。")
@CatchException
def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, web_port: str) -> Generator:
"""
txt: 用户输入,通常是arxiv论文链接
功能RAG论文总结和对话
Arxiv论文对话主函数
Args:
txt: arxiv ID/URL
llm_kwargs: LLM配置参数
plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令
chatbot: 对话历史
history: 聊天历史
system_prompt: 系统提示词
web_port: Web端口
"""
if_project, if_arxiv = False, False
if os.path.exists(txt):
from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter
splitter = SmartDocumentSplitter(
char_range=(1000, 1200),
max_workers=32 # 可选,默认会根据CPU核心数自动设置
)
if_project = True
# 初始化时,提示用户需要 arxiv ID/URL
if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
yield from update_ui(chatbot=chatbot, history=history)
return
user_name = chatbot.get_user()
worker = ArxivRagWorker(user_name, llm_kwargs)
# 处理新论文的情况
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')):
chatbot.append((txt, "正在处理论文,请稍等..."))
yield from update_ui(chatbot=chatbot, history=history)
# 创建事件循环来处理异步调用
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 运行异步处理函数
success = loop.run_until_complete(worker.process_paper(txt))
if success:
arxiv_id = worker._normalize_arxiv_id(txt)
success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
if success:
# 执行自动分析
yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
finally:
loop.close()
if not success:
chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
yield from update_ui(chatbot=chatbot, history=history)
return
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理用户询问的情况
# 获取用户询问指令
user_query = plugin_kwargs.get("advanced_arg", "")
if not user_query:
chatbot.append((txt, "请提供您的问题。"))
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理历史对话长度
if len(history) > MAX_HISTORY_ROUND * 2:
history = history[-(MAX_HISTORY_ROUND * 2):]
# 处理询问指令
query_clip, history, flags = input_clipping(
user_query,
history,
max_token_limit=MAX_CONTEXT_TOKEN_LIMIT,
return_clip_flags=True
)
if flags["original_input_len"] != flags["clipped_input_len"]:
yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
if len(user_query) > REMEMBER_PREVIEW:
HALF = REMEMBER_PREVIEW // 2
query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[-HALF:]
else:
query_to_remember = query_clip
else:
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import SmartArxivSplitter
splitter = SmartArxivSplitter(
char_range=(1000, 1200),
root_dir="gpt_log/arxiv_cache"
)
if_arxiv = True
for fragment in splitter.process(txt):
pass
# 初始化处理器
processor = ArxivPaperProcessor()
rag_handler = RagHandler()
query_to_remember = query_clip
# Step 1: 下载和提取论文
download_result = processor.download_and_extract(txt, chatbot, history)
project_folder, arxiv_id = None, None
for result in download_result:
if isinstance(result, tuple) and len(result) == 2:
project_folder, arxiv_id = result
break
chatbot.append((user_query, "正在思考中..."))
yield from update_ui(chatbot=chatbot, history=history)
if not project_folder or not arxiv_id:
# 生成提示词
prompt = worker.retrieve_and_generate(query_clip)
if not prompt:
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
yield from update_ui(chatbot=chatbot, history=history)
return
# Step 2: 合并TEX文件
paper_content = processor.merge_tex_files(project_folder, chatbot, history)
if not paper_content:
return
# Step 3: RAG处理
chatbot.append(["正在构建知识图谱...", "处理中..."])
# 获取回答
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt,
inputs_show_user=query_clip,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=history,
sys_prompt=system_prompt
)
# 记忆问答对
worker.remember_qa(query_to_remember, response)
history.extend([user_query, response])
yield from update_ui(chatbot=chatbot, history=history)
# 处理论文内容
rag_handler.process_paper_content(paper_content)
# 生成初始摘要
summary = rag_handler.query("请总结这篇论文的主要内容,包括研究目的、方法、结果和结论。")
chatbot.append(["论文摘要", summary])
yield from update_ui(chatbot=chatbot, history=history)
# 交互式问答
if __name__ == "__main__":
# 测试代码
llm_kwargs = {
'api_key': os.getenv("one_api_key"),
'client_ip': '127.0.0.1',
'embed_model': 'text-embedding-3-small',
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
'max_length': 4096,
'most_recent_uploaded': None,
'temperature': 1,
'top_p': 1
}
plugin_kwargs = {}
chatbot = []
history = []
system_prompt = "You are a helpful assistant."
web_port = "8080"
# 测试论文导入
arxiv_url = "https://arxiv.org/abs/2312.12345"
for response in Arxiv论文对话(
arxiv_url, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, web_port
):
print(response)
# 测试问答
question = "这篇论文的主要贡献是什么?"
for response in Arxiv论文对话(
question, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, web_port
):
print(response)

查看文件

@@ -1,15 +1,18 @@
import llama_index
import os
import atexit
from typing import List, Dict, Optional, Any, Tuple
from llama_index.core import Document
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import TextNode, NodeWithScore
from loguru import logger
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from typing import List
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
import json
import numpy as np
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
@@ -60,7 +63,7 @@ class SaveLoad():
def purge(self):
import shutil
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
self.vs_index = self.create_new_vs(self.checkpoint_dir)
self.vs_index = self.create_new_vs()
class LlamaIndexRagWorker(SaveLoad):
@@ -69,11 +72,61 @@ class LlamaIndexRagWorker(SaveLoad):
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
# 确保checkpoint_dir存在
if checkpoint_dir:
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
# 初始化向量存储
if auto_load_checkpoint and self.does_checkpoint_exist():
logger.info("Loading existing vector store from checkpoint")
self.vs_index = self.load_from_checkpoint()
else:
logger.info("Creating new vector store")
self.vs_index = self.create_new_vs()
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
# 注册退出时保存
atexit.register(self.save_to_checkpoint)
def add_text_to_vector_store(self, text: str) -> None:
"""添加文本到向量存储"""
try:
logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
node = TextNode(text=text)
nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(nodes)
# 立即保存
self.save_to_checkpoint()
if self.debug_mode:
self.inspect_vector_store()
except Exception as e:
logger.error(f"Error adding text to vector store: {str(e)}")
raise
def save_to_checkpoint(self, checkpoint_dir=None):
"""保存向量存储到检查点"""
try:
if checkpoint_dir is None:
checkpoint_dir = self.checkpoint_dir
logger.info(f'Saving vector store to: {checkpoint_dir}')
if checkpoint_dir:
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
logger.info('Vector store saved successfully')
else:
logger.warning('No checkpoint directory specified, skipping save')
except Exception as e:
logger.error(f"Error saving checkpoint: {str(e)}")
raise
def assign_embedding_model(self):
pass
@@ -82,44 +135,28 @@ class LlamaIndexRagWorker(SaveLoad):
# This function is for debugging
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
logger.info('\n++ --------inspect_vector_store begin--------')
logger.info(vector_store_preview)
logger.info('oo --------inspect_vector_store end--------')
return vector_store_preview
def add_documents_to_vector_store(self, document_list: List[Document]):
"""
Adds a list of Document objects to the vector store after processing.
"""
documents = document_list
def add_documents_to_vector_store(self, document_list):
documents = [Document(text=t) for t in document_list]
documents_nodes = run_transformations(
documents, # type: ignore
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
def add_text_to_vector_store(self, text: str):
node = TextNode(text=text)
documents_nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
if self.debug_mode: self.inspect_vector_store()
def remember_qa(self, question, answer):
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
self.add_text_to_vector_store(formatted_str)
def retrieve_from_store_with_query(self, query):
if self.debug_mode:
self.inspect_vector_store()
if self.debug_mode: self.inspect_vector_store()
retriever = self.vs_index.as_retriever()
return retriever.retrieve(query)
@@ -128,230 +165,6 @@ class LlamaIndexRagWorker(SaveLoad):
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
def generate_node_array_preview(self, nodes):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: logger.info(buf)
return buf
def purge_vector_store(self):
"""
Purges the current vector store and creates a new one.
"""
self.purge()
"""
以下是添加的新方法,原有方法保持不变
"""
def add_text_with_metadata(self, text: str, metadata: dict) -> str:
"""
添加带元数据的文本到向量存储
Args:
text: 文本内容
metadata: 元数据字典
Returns:
添加的节点ID
"""
node = TextNode(text=text, metadata=metadata)
nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(nodes)
return nodes[0].node_id if nodes else None
def batch_add_texts_with_metadata(self, texts: List[Tuple[str, dict]]) -> List[str]:
"""
批量添加带元数据的文本
Args:
texts: (text, metadata)元组列表
Returns:
添加的节点ID列表
"""
nodes = [TextNode(text=t, metadata=m) for t, m in texts]
transformed_nodes = run_transformations(
nodes,
self.vs_index._transformations,
show_progress=True
)
if transformed_nodes:
self.vs_index.insert_nodes(transformed_nodes)
return [node.node_id for node in transformed_nodes]
return []
def get_node_metadata(self, node_id: str) -> Optional[dict]:
"""
获取节点的元数据
Args:
node_id: 节点ID
Returns:
节点的元数据字典
"""
node = self.vs_index.storage_context.docstore.docs.get(node_id)
return node.metadata if node else None
def update_node_metadata(self, node_id: str, metadata: dict, merge: bool = True) -> bool:
"""
更新节点的元数据
Args:
node_id: 节点ID
metadata: 新的元数据
merge: 是否与现有元数据合并
Returns:
是否更新成功
"""
docstore = self.vs_index.storage_context.docstore
if node_id in docstore.docs:
node = docstore.docs[node_id]
if merge:
node.metadata.update(metadata)
else:
node.metadata = metadata
return True
return False
def filter_nodes_by_metadata(self, filters: Dict[str, Any]) -> List[TextNode]:
"""
按元数据过滤节点
Args:
filters: 元数据过滤条件
Returns:
符合条件的节点列表
"""
docstore = self.vs_index.storage_context.docstore
results = []
for node in docstore.docs.values():
if all(node.metadata.get(k) == v for k, v in filters.items()):
results.append(node)
return results
def retrieve_with_metadata_filter(
self,
query: str,
metadata_filters: Dict[str, Any],
top_k: int = 5
) -> List[NodeWithScore]:
"""
结合元数据过滤的检索
Args:
query: 查询文本
metadata_filters: 元数据过滤条件
top_k: 返回结果数量
Returns:
检索结果节点列表
"""
retriever = self.vs_index.as_retriever(similarity_top_k=top_k)
nodes = retriever.retrieve(query)
# 应用元数据过滤
filtered_nodes = []
for node in nodes:
if all(node.metadata.get(k) == v for k, v in metadata_filters.items()):
filtered_nodes.append(node)
return filtered_nodes
def get_node_stats(self, node_id: str) -> dict:
"""
获取单个节点的统计信息
Args:
node_id: 节点ID
Returns:
节点统计信息字典
"""
node = self.vs_index.storage_context.docstore.docs.get(node_id)
if not node:
return {}
return {
"text_length": len(node.text),
"token_count": len(node.text.split()),
"has_embedding": node.embedding is not None,
"metadata_keys": list(node.metadata.keys()),
}
def get_nodes_by_content_pattern(self, pattern: str) -> List[TextNode]:
"""
按内容模式查找节点
Args:
pattern: 正则表达式模式
Returns:
匹配的节点列表
"""
import re
docstore = self.vs_index.storage_context.docstore
matched_nodes = []
for node in docstore.docs.values():
if re.search(pattern, node.text):
matched_nodes.append(node)
return matched_nodes
def export_nodes(
self,
output_file: str,
format: str = "json",
include_embeddings: bool = False
) -> None:
"""
Export nodes to file
Args:
output_file: Output file path
format: "json" or "csv"
include_embeddings: Whether to include embeddings
"""
docstore = self.vs_index.storage_context.docstore
data = []
for node_id, node in docstore.docs.items():
node_data = {
"node_id": node_id,
"text": node.text,
"metadata": node.metadata,
}
if include_embeddings and node.embedding is not None:
node_data["embedding"] = node.embedding.tolist()
data.append(node_data)
if format == "json":
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
elif format == "csv":
import csv
import pandas as pd
df = pd.DataFrame(data)
df.to_csv(output_file, index=False, quoting=csv.QUOTE_NONNUMERIC)
else:
raise ValueError(f"Unsupported format: {format}")
def get_statistics(self) -> Dict[str, Any]:
"""Get vector store statistics"""
docstore = self.vs_index.storage_context.docstore
docs = list(docstore.docs.values())
return {
"total_nodes": len(docs),
"total_tokens": sum(len(node.text.split()) for node in docs),
"avg_text_length": np.mean([len(node.text) for node in docs]) if docs else 0,
"embedding_dimension": len(docs[0].embedding) if docs and docs[0].embedding is not None else 0
}