镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-07 06:56:48 +00:00
up
这个提交包含在:
@@ -0,0 +1,85 @@
|
||||
from typing import Dict, List, Optional
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
import os
|
||||
from toolbox import get_conf
|
||||
import openai
|
||||
|
||||
class RagHandler:
|
||||
def __init__(self):
|
||||
# 初始化工作目录
|
||||
self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
|
||||
if not os.path.exists(self.working_dir):
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
# 初始化 LightRAG
|
||||
self.rag = LightRAG(
|
||||
working_dir=self.working_dir,
|
||||
llm_model_func=self._llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1536, # OpenAI embedding 维度
|
||||
max_token_size=8192,
|
||||
func=self._embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
async def _llm_model_func(self, prompt: str, system_prompt: str = None,
|
||||
history_messages: List = None, **kwargs) -> str:
|
||||
"""LLM 模型函数"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
if history_messages:
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 0),
|
||||
max_tokens=kwargs.get("max_tokens", 1000)
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _embedding_func(self, texts: List[str]) -> np.ndarray:
|
||||
"""Embedding 函数"""
|
||||
response = await openai.Embedding.acreate(
|
||||
model="text-embedding-ada-002",
|
||||
input=texts
|
||||
)
|
||||
embeddings = [item["embedding"] for item in response["data"]]
|
||||
return np.array(embeddings)
|
||||
|
||||
def process_paper_content(self, paper_content: Dict) -> None:
|
||||
"""处理论文内容,构建知识图谱"""
|
||||
# 处理标题和摘要
|
||||
content_list = []
|
||||
if paper_content['title']:
|
||||
content_list.append(f"Title: {paper_content['title']}")
|
||||
if paper_content['abstract']:
|
||||
content_list.append(f"Abstract: {paper_content['abstract']}")
|
||||
|
||||
# 添加分段内容
|
||||
content_list.extend(paper_content['segments'])
|
||||
|
||||
# 插入到 RAG 系统
|
||||
self.rag.insert(content_list)
|
||||
|
||||
def query(self, question: str, mode: str = "hybrid") -> str:
|
||||
"""查询论文内容
|
||||
mode: 查询模式,可选 naive/local/global/hybrid
|
||||
"""
|
||||
try:
|
||||
response = self.rag.query(
|
||||
question,
|
||||
param=QueryParam(
|
||||
mode=mode,
|
||||
top_k=5, # 返回相关度最高的5个结果
|
||||
max_token_for_text_unit=2048, # 每个文本单元的最大token数
|
||||
response_type="detailed" # 返回详细回答
|
||||
)
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
return f"查询出错: {str(e)}"
|
||||
在新工单中引用
屏蔽一个用户