This commit is contained in:
lbykkkk
2024-11-16 00:35:31 +08:00
Parent dd902e9519
Commit 21626a44d5
12 changed files with 2385 additions and 1169 deletions


@@ -13,18 +13,19 @@ from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
T = TypeVar('T')
@dataclass
class StorageBase:
"""Base class for all storage implementations"""
namespace: str
working_dir: str
async def index_done_callback(self):
"""Hook called after indexing operations"""
pass
async def query_done_callback(self):
"""Hook called after query operations"""
"""Hook called after query operations"""
pass
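Both callbacks default to no-ops; the concrete storages below override index_done_callback to persist state. As an illustrative sketch only (not part of this file), a custom storage could plug into the same hooks like this, assuming the module-level logger used elsewhere in the file:

@dataclass
class LoggingStorage(StorageBase):
    """Illustrative subclass: log when indexing/query phases complete."""

    async def index_done_callback(self):
        logger.info(f"[{self.namespace}] indexing finished")

    async def query_done_callback(self):
        logger.info(f"[{self.namespace}] query finished")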
@@ -32,37 +33,37 @@ class StorageBase:
class JsonKVStorage(StorageBase, Generic[T]):
"""
Key-Value storage using JSON files
Attributes:
namespace (str): Storage namespace
working_dir (str): Working directory for storage files
_file_name (str): JSON file path
_data (Dict[str, T]): In-memory storage
"""
def __post_init__(self):
"""Initialize storage file and load data"""
self._file_name = os.path.join(self.working_dir, f"kv_{self.namespace}.json")
self._file_name = os.path.join(self.working_dir, f"kv_store_{self.namespace}.json")
self._data: Dict[str, T] = {}
self.load()
def load(self):
"""Load data from JSON file"""
if os.path.exists(self._file_name):
with open(self._file_name, 'r', encoding='utf-8') as f:
self._data = json.load(f)
logger.info(f"Loaded {len(self._data)} items from {self._file_name}")
async def save(self):
"""Save data to JSON file"""
os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
with open(self._file_name, 'w', encoding='utf-8') as f:
json.dump(self._data, f, ensure_ascii=False, indent=2)
async def get_by_id(self, id: str) -> Optional[T]:
"""Get item by ID"""
return self._data.get(id)
async def get_by_ids(self, ids: List[str], fields: Optional[Set[str]] = None) -> List[Optional[T]]:
"""Get multiple items by IDs with optional field filtering"""
if fields is None:
@@ -70,16 +71,16 @@ class JsonKVStorage(StorageBase, Generic[T]):
return [{k: v for k, v in self._data[id].items() if k in fields}
if id in self._data else None
for id in ids]
async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Return keys that don't exist in storage"""
return set(k for k in keys if k not in self._data)
async def upsert(self, data: Dict[str, T]):
"""Insert or update items"""
self._data.update(data)
await self.save()
async def drop(self):
"""Clear all data"""
self._data = {}
@@ -95,148 +96,225 @@ class JsonKVStorage(StorageBase, Generic[T]):
await self.save()
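For orientation, a minimal usage sketch of JsonKVStorage as defined above; the import path is hypothetical and should be adjusted to wherever this module actually lives:

import asyncio

# Hypothetical import path -- adjust to the real module location.
from crazy_functions.rag_fns.LightRAG.storage import JsonKVStorage

async def demo_kv():
    # Creates/loads <working_dir>/kv_store_chunks.json
    kv = JsonKVStorage(namespace="chunks", working_dir="private_upload/demo")

    # upsert() updates the in-memory dict and writes the JSON file back out.
    await kv.upsert({"chunk-1": {"content": "Apple unveiled a new iPhone."}})

    print(await kv.get_by_id("chunk-1"))
    print(await kv.filter_keys(["chunk-1", "chunk-2"]))  # -> {'chunk-2'}

asyncio.run(demo_kv())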
@dataclass
class VectorStorage(StorageBase):
"""
Vector storage using LlamaIndexRagWorker
Attributes:
namespace (str): Storage namespace (e.g., 'entities', 'relationships', 'chunks')
working_dir (str): Working directory for storage files
llm_kwargs (dict): LLM configuration
embedding_func (OpenAiEmbeddingModel): Embedding function
meta_fields (Set[str]): Additional metadata fields to store
"""
llm_kwargs: dict
embedding_func: OpenAiEmbeddingModel
meta_fields: Set[str] = field(default_factory=set)
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
"""Initialize LlamaIndex worker"""
# Use the corrected file naming format
self._vector_file = os.path.join(self.working_dir, f"vdb_{self.namespace}.json")
# Set up the checkpoint directory
checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}_checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize the vector store
self.vector_store = LlamaIndexRagWorker(
user_name=self.namespace,
llm_kwargs=self.llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True
)
logger.info(f"Initialized vector storage for {self.namespace}")
async def query(self, query: str, top_k: int = 5, metadata_filters: Optional[Dict[str, Any]] = None) -> List[dict]:
"""
Query vectors by similarity with optional metadata filtering
Args:
query: Query text
top_k: Maximum number of results to return
metadata_filters: Optional metadata filters
Returns:
List of similar documents with scores
"""
try:
if metadata_filters:
nodes = self.vector_store.retrieve_with_metadata_filter(query, metadata_filters, top_k)
else:
nodes = self.vector_store.retrieve_from_store_with_query(query)[:top_k]
results = []
for node in nodes:
result = {
"id": node.node_id,
"text": node.text,
"score": node.score if hasattr(node, 'score') else 0.0,
}
# Add metadata fields if they exist and are in meta_fields
if hasattr(node, 'metadata'):
result.update({
k: node.metadata[k]
for k in self.meta_fields
if k in node.metadata
})
results.append(result)
return results
except Exception as e:
logger.error(f"Error in vector query: {e}")
raise
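A hedged sketch of how the new metadata-aware query could be driven; entities_vdb stands for a VectorStorage instance like the one constructed in the example script later in this commit, and the filter value is a placeholder:

import asyncio

async def demo_entity_query(entities_vdb):
    """entities_vdb: a VectorStorage built as in the example script below."""
    # Plain similarity search.
    hits = await entities_vdb.query("Who leads Apple?", top_k=3)

    # The same search constrained by metadata; this branch assumes
    # LlamaIndexRagWorker.retrieve_with_metadata_filter is available.
    filtered = await entities_vdb.query(
        "Who leads Apple?",
        top_k=3,
        metadata_filters={"entity_type": "organization"},  # placeholder filter
    )

    for hit in hits + filtered:
        print(hit["id"], round(hit["score"], 3), hit["text"][:60])

# asyncio.run(demo_entity_query(my_entities_vdb))  # store constructed elsewhere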
async def upsert(self, data: Dict[str, dict]):
"""
Insert or update vectors
Args:
data: Dictionary of documents to insert/update with format:
{id: {"content": text, <meta_field>: value, ...}}
"""
try:
for doc_id, item in data.items():
content = item["content"]
# Extract metadata
metadata = {
k: item[k]
for k in self.meta_fields
if k in item
}
# Add the document ID to the metadata
metadata["doc_id"] = doc_id
# Add to the vector store
self.vector_store.add_text_with_metadata(content, metadata)
# Export vector data to a JSON file
self.vector_store.export_nodes(
self._vector_file,
format="json",
include_embeddings=True
)
except Exception as e:
logger.error(f"Error in vector upsert: {e}")
raise
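Payload keys other than "content" are only persisted as metadata when they also appear in meta_fields; an illustrative payload for an entity store with meta_fields={"entity_name", "entity_type"} (IDs and values are made up):

# Illustrative payload for VectorStorage.upsert(); with
# meta_fields = {"entity_name", "entity_type"}, only those two keys
# are copied into the node metadata alongside "doc_id".
entity_payload = {
    "entity_APPLE": {
        "content": "APPLE: a consumer electronics company led by Tim Cook.",
        "entity_name": "APPLE",
        "entity_type": "organization",
    },
}
# await entities_vdb.upsert(entity_payload)  # call from an async context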
async def save(self):
"""Save vector store to checkpoint and export data"""
try:
# Save the checkpoint
self.vector_store.save_to_checkpoint()
# Export vector data
self.vector_store.export_nodes(
self._vector_file,
format="json",
include_embeddings=True
)
except Exception as e:
logger.error(f"Error saving vector storage: {e}")
raise
async def index_done_callback(self):
"""Save after indexing"""
await self.save()
def get_statistics(self) -> Dict[str, Any]:
"""Get vector store statistics"""
return self.vector_store.get_statistics()
@dataclass
class NetworkStorage(StorageBase):
"""
Graph storage using NetworkX
Attributes:
namespace (str): Storage namespace
working_dir (str): Working directory for storage files
"""
def __post_init__(self):
"""Initialize graph and storage file"""
self._file_name = os.path.join(self.working_dir, f"graph_{self.namespace}.graphml")
self._graph = self._load_graph() or nx.Graph()
logger.info(f"Initialized graph storage for {self.namespace}")
def _load_graph(self) -> Optional[nx.Graph]:
"""Load graph from GraphML file"""
if os.path.exists(self._file_name):
try:
graph = nx.read_graphml(self._file_name)
logger.info(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
return graph
except Exception as e:
logger.error(f"Error loading graph from {self._file_name}: {e}")
return None
return None
async def save_graph(self):
"""Save graph to GraphML file"""
try:
os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
logger.info(
f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
nx.write_graphml(self._graph, self._file_name)
except Exception as e:
logger.error(f"Error saving graph: {e}")
raise
async def has_node(self, node_id: str) -> bool:
"""Check if node exists"""
return self._graph.has_node(node_id)
async def has_edge(self, source_id: str, target_id: str) -> bool:
"""Check if edge exists"""
return self._graph.has_edge(source_id, target_id)
async def get_node(self, node_id: str) -> Optional[dict]:
"""Get node attributes"""
if not self._graph.has_node(node_id):
return None
return dict(self._graph.nodes[node_id])
async def get_edge(self, source_id: str, target_id: str) -> Optional[dict]:
"""Get edge attributes"""
if not self._graph.has_edge(source_id, target_id):
return None
return dict(self._graph.edges[source_id, target_id])
async def node_degree(self, node_id: str) -> int:
"""Get node degree"""
return self._graph.degree(node_id)
async def edge_degree(self, source_id: str, target_id: str) -> int:
"""Get sum of degrees of edge endpoints"""
return self._graph.degree(source_id) + self._graph.degree(target_id)
async def get_node_edges(self, source_id: str) -> Optional[List[Tuple[str, str]]]:
"""Get all edges connected to node"""
if not self._graph.has_node(source_id):
return None
return list(self._graph.edges(source_id))
async def upsert_node(self, node_id: str, node_data: Dict[str, str]):
"""Insert or update node"""
# Clean and normalize node data
cleaned_data = {k: html.escape(str(v).upper().strip()) for k, v in node_data.items()}
self._graph.add_node(node_id, **cleaned_data)
await self.save_graph()
async def upsert_edge(self, source_id: str, target_id: str, edge_data: Dict[str, str]):
"""Insert or update edge"""
# Clean and normalize edge data
cleaned_data = {k: html.escape(str(v).strip()) for k, v in edge_data.items()}
self._graph.add_edge(source_id, target_id, **cleaned_data)
await self.save_graph()
async def index_done_callback(self):
"""Save after indexing"""
await self.save_graph()
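A minimal usage sketch of the graph API above; the import path is hypothetical, and note that upsert_node upper-cases and HTML-escapes attribute values before storing them:

import asyncio

# Hypothetical import path -- adjust to the real module location.
from crazy_functions.rag_fns.LightRAG.storage import NetworkStorage

async def demo_graph():
    graph = NetworkStorage(namespace="chunk_entity_relation",
                           working_dir="private_upload/demo")

    # Every upsert writes graph_chunk_entity_relation.graphml to disk.
    await graph.upsert_node("APPLE", {"entity_type": "organization",
                                      "description": "Consumer electronics company"})
    await graph.upsert_node("TIM COOK", {"entity_type": "person",
                                         "description": "CEO of Apple"})
    await graph.upsert_edge("APPLE", "TIM COOK",
                            {"description": "Tim Cook leads Apple", "keywords": "CEO"})

    print(await graph.get_node("APPLE"))          # values come back upper-cased
    print(await graph.node_degree("APPLE"))       # -> 1
    print(await graph.get_node_edges("TIM COOK")) # -> [('TIM COOK', 'APPLE')]

asyncio.run(demo_graph())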
@@ -245,47 +323,47 @@ class NetworkStorage(StorageBase):
"""Get the largest connected component of the graph"""
if not self._graph:
return nx.Graph()
components = list(nx.connected_components(self._graph))
if not components:
return nx.Graph()
largest_component = max(components, key=len)
return self._graph.subgraph(largest_component).copy()
async def embed_nodes(
self,
algorithm: str = "node2vec",
dimensions: int = 128,
walk_length: int = 30,
num_walks: int = 200,
workers: int = 4,
window: int = 10,
min_count: int = 1,
**kwargs
) -> Tuple[np.ndarray, List[str]]:
"""Generate node embeddings using specified algorithm"""
if algorithm == "node2vec":
from node2vec import Node2Vec
# Create and train node2vec model
n2v = Node2Vec(
self._graph,
dimensions=dimensions,
walk_length=walk_length,
num_walks=num_walks,
workers=workers
)
model = n2v.fit(
window=window,
min_count=min_count
)
# Get embeddings for all nodes
node_ids = list(self._graph.nodes())
embeddings = np.array([model.wv[node] for node in node_ids])
return embeddings, node_ids
else:
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")


@@ -23,25 +23,29 @@ class ExtractionExample:
def __init__(self):
"""Initialize RAG system components"""
# Set up the working directory
self.working_dir = f"crazy_functions/rag_fns/LightRAG/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
os.makedirs(self.working_dir, exist_ok=True)
logger.info(f"Working directory: {self.working_dir}")
# Initialize embedding
self.llm_kwargs = {
'api_key': os.getenv("one_api_key"),
'client_ip': '127.0.0.1',
'embed_model': 'text-embedding-3-small',
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
'max_length': 4096,
'most_recent_uploaded': None,
'temperature': 1,
'top_p': 1
}
self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs)
# Initialize the prompt templates and extractor
self.prompt_templates = PromptTemplates()
self.extractor = EntityRelationExtractor(
prompt_templates=self.prompt_templates,
required_prompts={'entity_extraction'},
entity_extract_max_gleaning=1
)
# Initialize the storage systems
@@ -63,18 +67,33 @@ class ExtractionExample:
working_dir=self.working_dir
)
# Vector stores - for vector representations of entities, relationships, and text chunks
self.entities_vdb = VectorStorage(
namespace="entities",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func,
meta_fields={"entity_name", "entity_type"}
)
self.relationships_vdb = VectorStorage(
namespace="relationships",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}
)
self.chunks_vdb = VectorStorage(
namespace="chunks",
working_dir=self.working_dir,
llm_kwargs=self.llm_kwargs,
embedding_func=self.embedding_func
)
# Graph storage - for entity relations
self.graph_store = NetworkStorage(
namespace="graph",
namespace="chunk_entity_relation",
working_dir=self.working_dir
)
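Given these namespaces and the file-name patterns in the storage classes, the working directory should end up roughly as follows (a sketch; the KV-store namespaces are set earlier in __init__ and are not visible in this hunk):

# Expected working_dir layout after indexing (sketch):
#   kv_store_<namespace>.json             JsonKVStorage data files
#   vdb_entities.json                     VectorStorage export (entities)
#   vdb_relationships.json                VectorStorage export (relationships)
#   vdb_chunks.json                       VectorStorage export (chunks)
#   vector_entities_checkpoint/           LlamaIndexRagWorker checkpoint dirs
#   vector_relationships_checkpoint/
#   vector_chunks_checkpoint/
#   graph_chunk_entity_relation.graphml   NetworkStorage graph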
@@ -152,7 +171,7 @@ class ExtractionExample:
try:
# Vector storage
logger.info("Adding chunks to vector store...")
await self.chunks_vdb.upsert(chunks)
# Initialize the conversation history
self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()}
@@ -178,14 +197,32 @@ class ExtractionExample:
# Get the extraction results
nodes, edges = self.extractor.get_results()
# Store entities in the vector database and graph database
for node_name, node_instances in nodes.items():
for node in node_instances:
# Store in the vector database
await self.entities_vdb.upsert({
f"entity_{node_name}": {
"content": f"{node_name}: {node['description']}",
"entity_name": node_name,
"entity_type": node['entity_type']
}
})
# Store in the graph database
await self.graph_store.upsert_node(node_name, node)
# Store relationships in the vector database and graph database
for (src, tgt), edge_instances in edges.items():
for edge in edge_instances:
# Store in the vector database
await self.relationships_vdb.upsert({
f"rel_{src}_{tgt}": {
"content": f"{edge['description']} | {edge['keywords']}",
"src_id": src,
"tgt_id": tgt
}
})
# Store in the graph database
await self.graph_store.upsert_edge(src, tgt, edge)
return nodes, edges
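Judging from the keys accessed above, the extractor's results presumably have roughly this shape (illustrative values only, not actual extractor output):

# Illustrative shapes only -- real values come from EntityRelationExtractor.
nodes = {
    "APPLE": [
        {"entity_type": "organization",
         "description": "Consumer electronics company announcing a new iPhone."},
    ],
}
edges = {
    ("APPLE", "TIM COOK"): [
        {"description": "Tim Cook presented Apple's latest products.",
         "keywords": "CEO, product launch"},
    ],
}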
@@ -197,26 +234,39 @@ class ExtractionExample:
async def query_knowledge_base(self, query: str, top_k: int = 5):
"""Query the knowledge base using various methods"""
try:
# Vector similarity search - text chunks
chunk_results = await self.chunks_vdb.query(query, top_k=top_k)
# Vector similarity search - entities
entity_results = await self.entities_vdb.query(query, top_k=top_k)
# Get the related text chunks
chunk_ids = [r["id"] for r in chunk_results]
chunks = await self.text_chunks.get_by_ids(chunk_ids)
# Get graph structure information related to the entities
relevant_edges = []
for entity in entity_results:
if "entity_name" in entity:
entity_name = entity["entity_name"]
if await self.graph_store.has_node(entity_name):
edges = await self.graph_store.get_node_edges(entity_name)
if edges:
edge_data = []
for edge in edges:
edge_info = await self.graph_store.get_edge(edge[0], edge[1])
if edge_info:
edge_data.append({
"source": edge[0],
"target": edge[1],
"data": edge_info
})
relevant_edges.extend(edge_data)
return {
"vector_results": vector_results,
"text_chunks": chunks,
"relevant_entities": relevant_nodes
"chunks": chunks,
"entities": entity_results,
"relationships": relevant_edges
}
except Exception as e:
@@ -228,30 +278,27 @@ class ExtractionExample:
os.makedirs(export_dir, exist_ok=True)
try:
# Export statistics
storage_stats = {
"chunks": len(self.text_chunks._data),
"docs": len(self.full_docs._data),
"vector_store": self.vector_store.vector_store.get_statistics()
"chunks": {
"total": len(self.text_chunks._data),
"vector_stats": self.chunks_vdb.get_statistics()
},
"entities": {
"vector_stats": self.entities_vdb.get_statistics()
},
"relationships": {
"vector_stats": self.relationships_vdb.get_statistics()
},
"graph": {
"total_nodes": len(list(self.graph_store._graph.nodes())),
"total_edges": len(list(self.graph_store._graph.edges())),
"node_degrees": dict(self.graph_store._graph.degree()),
"largest_component_size": len(self.graph_store.get_largest_connected_component())
}
}
# Write statistics to file
with open(os.path.join(export_dir, "storage_stats.json"), "w") as f:
json.dump(storage_stats, f, indent=2)
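For reference, the exported storage_stats.json then has roughly this shape (a sketch; the contents of each "vector_stats" block depend on LlamaIndexRagWorker.get_statistics()):

# Shape of storage_stats.json (sketch, values omitted):
# {
#   "chunks":        {"total": <int>, "vector_stats": {...}},
#   "entities":      {"vector_stats": {...}},
#   "relationships": {"vector_stats": {...}},
#   "graph": {
#     "total_nodes": <int>,
#     "total_edges": <int>,
#     "node_degrees": {"<node_id>": <int>, ...},
#     "largest_component_size": <int>
#   }
# }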
@@ -299,19 +346,6 @@ async def main():
the company's commitment to innovation and sustainability. The new iPhone
features groundbreaking AI capabilities.
""",
# "business_news": """
# Microsoft and OpenAI expanded their partnership today.
# Satya Nadella emphasized the importance of AI development while
# Sam Altman discussed the future of large language models. The collaboration
# aims to accelerate AI research and deployment.
# """,
#
# "science_paper": """
# Researchers at DeepMind published a breakthrough paper on quantum computing.
# The team demonstrated novel approaches to quantum error correction.
# Dr. Sarah Johnson led the research, collaborating with Google's quantum lab.
# """
}
try: