Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 14:36:48 +00:00

Commit: first_version
@@ -1,7 +1,12 @@
-from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
-from crazy_functions.crazy_utils import input_clipping
-from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
+import os
+from typing import List
 
+from llama_index.core import Document
+from shared_utils.fastapi_server import validate_path_safety
+from crazy_functions.crazy_utils import input_clipping, request_gpt_model_in_new_thread_with_ui_alive
+from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
+from toolbox import report_exception
+from crazy_functions.rag_fns.rag_file_support import extract_text
 VECTOR_STORE_TYPE = "Milvus"
 
 if VECTOR_STORE_TYPE == "Milvus":
@@ -21,11 +26,22 @@ MAX_CONTEXT_TOKEN_LIMIT = 4096
 REMEMBER_PREVIEW = 1000
 
 @CatchException
-def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
-    # 1. we retrieve rag worker from global context
+def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+    """
+    Handles document uploads by extracting text and adding it to the vector store.
+
+    Args:
+        files (List[str]): List of file paths to process.
+        llm_kwargs: Language model keyword arguments.
+        plugin_kwargs: Plugin keyword arguments.
+        chatbot: Chatbot instance.
+        history: Chat history.
+        system_prompt: System prompt.
+        user_request: User request.
+    """
     user_name = chatbot.get_user()
     checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
 
     if user_name in RAG_WORKER_REGISTER:
         rag_worker = RAG_WORKER_REGISTER[user_name]
     else:
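
Note: `handle_document_upload` is a generator (its body, in the next hunk, ends with `yield from update_ui(...)`), so a caller has to drive it with `yield from`, exactly as the new `Rag问答` does. A toy sketch of that pattern, independent of this codebase (all names below are illustrative, not part of the commit):

```python
# Toy sketch of the generator-based plugin pattern used here (illustrative names).
def update_ui_stub(message):
    # stands in for update_ui(...): yields a UI-refresh event upward
    yield f"refresh: {message}"

def handle_upload_stub(files):
    for f in files:
        # ... ingest the file here ...
        yield from update_ui_stub(f"{f} added")

def plugin_entry(files):
    yield from handle_upload_stub(files)  # mirrors: yield from handle_document_upload(...)

for event in plugin_entry(["a.txt", "b.txt"]):
    print(event)
```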
@@ -33,21 +49,86 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
             user_name,
             llm_kwargs,
             checkpoint_dir=checkpoint_dir,
-            auto_load_checkpoint=True)
+            auto_load_checkpoint=True
+        )
 
+    for file_path in files:
+        try:
+            validate_path_safety(file_path, user_name)
+            text = extract_text(file_path)
+            document = Document(text=text, metadata={"source": file_path})
+            rag_worker.add_documents_to_vector_store([document])
+            chatbot.append([f"上传文件: {os.path.basename(file_path)}", "文件已成功添加到知识库。"])
+        except Exception as e:
+            report_exception(chatbot, history, a=f"处理文件: {file_path}", b=str(e))
+
+    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+
+@CatchException
+def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
+    """
+    Handles RAG-based Q&A, including special commands and document uploads.
+
+    Args:
+        txt (str): User input text.
+        llm_kwargs: Language model keyword arguments.
+        plugin_kwargs: Plugin keyword arguments.
+        chatbot: Chatbot instance.
+        history: Chat history.
+        system_prompt: System prompt.
+        user_request: User request.
+    """
+    # Define commands
+    CLEAR_VECTOR_DB_CMD = "清空向量数据库"
+    UPLOAD_DOCUMENT_CMD = "上传文档"
+
+    # 1. Retrieve RAG worker from global context
+    user_name = chatbot.get_user()
+    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
+
+    if user_name in RAG_WORKER_REGISTER:
+        rag_worker = RAG_WORKER_REGISTER[user_name]
+    else:
+        rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
+            user_name,
+            llm_kwargs,
+            checkpoint_dir=checkpoint_dir,
+            auto_load_checkpoint=True
+        )
+
     current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
     tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
-    if txt == "清空向量数据库":
-        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
-        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
-        rag_worker.purge()
-        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
+
+    # 2. Handle special commands
+    if txt.startswith(UPLOAD_DOCUMENT_CMD):
+        # Extract file paths from the user input
+        # Assuming the user inputs file paths separated by commas after the command
+        file_paths = txt[len(UPLOAD_DOCUMENT_CMD):].strip().split(',')
+        file_paths = [path.strip() for path in file_paths if path.strip()]
+
+        if not file_paths:
+            report_exception(chatbot, history, a="上传文档", b="未提供任何文件路径。")
+            yield from update_ui(chatbot=chatbot, history=history)
+            return
+
+        chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+        yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
         return
 
-    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
-    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
-    # 2. clip history to reduce token consumption
-    # 2-1. reduce chat round
+    elif txt == CLEAR_VECTOR_DB_CMD:
+        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+        rag_worker.purge_vector_store()
+        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
+        return
+
+    # 3. Normal Q&A processing
+    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
+    yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+
+    # 4. Clip history to reduce token consumption
     txt_origin = txt
 
     if len(history) > MAX_HISTORY_ROUND * 2:
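
The command dispatch above parses everything after `上传文档` as a comma-separated list of file paths. A minimal standalone sketch of that parsing step (the function name is illustrative, not part of the commit):

```python
# Standalone sketch of the upload-command parsing introduced above.
UPLOAD_DOCUMENT_CMD = "上传文档"

def parse_upload_command(txt: str):
    """Return the file paths following the upload command, or None if absent."""
    if not txt.startswith(UPLOAD_DOCUMENT_CMD):
        return None
    paths = txt[len(UPLOAD_DOCUMENT_CMD):].strip().split(',')
    return [p.strip() for p in paths if p.strip()]

assert parse_upload_command("上传文档 a.pdf, b.docx") == ["a.pdf", "b.docx"]
assert parse_upload_command("hello") is None
assert parse_upload_command("上传文档") == []   # triggers the "no paths" error branch
```

One limitation of this convention: paths that themselves contain commas cannot be expressed.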
@@ -55,41 +136,48 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
     txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
     input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
 
-    # 2-2. if input is clipped, add input to vector store before retrieve
+    # 5. If input is clipped, add input to vector store before retrieve
     if input_is_clipped_flag:
         yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
-        # save input to vector store
+        # Save input to vector store
         rag_worker.add_text_to_vector_store(txt_origin)
         yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
 
         if len(txt_origin) > REMEMBER_PREVIEW:
-            HALF = REMEMBER_PREVIEW//2
+            HALF = REMEMBER_PREVIEW // 2
             i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
             if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
                 txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
-            else:
-                pass
-            i_say = txt_clip
         else:
             i_say_to_remember = i_say = txt_clip
     else:
         i_say_to_remember = i_say = txt_clip
 
-    # 3. we search vector store and build prompts
+    # 6. Search vector store and build prompts
     nodes = rag_worker.retrieve_from_store_with_query(i_say)
     prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
 
-    # 4. it is time to query llms
-    if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
+    # 7. Query language model
+    if len(chatbot) != 0:
+        chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
+
     model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
-        inputs=prompt, inputs_show_user=i_say,
-        llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
+        inputs=prompt,
+        inputs_show_user=i_say,
+        llm_kwargs=llm_kwargs,
+        chatbot=chatbot,
+        history=history,
         sys_prompt=system_prompt,
         retry_times_at_unknown_error=0
     )
 
-    # 5. remember what has been asked / answered
-    yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
+    # 8. Remember Q&A
+    yield from update_ui_lastest_msg(
+        model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...',
+        chatbot, history, delay=0.5
+    )
     rag_worker.remember_qa(i_say_to_remember, model_say)
     history.extend([i_say, model_say])
 
-    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面
+    # 9. Final UI Update
+    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip)
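
For long inputs, `Rag问答` stores the full text in the vector store but remembers only a head-and-tail preview of `REMEMBER_PREVIEW` characters. A standalone sketch of that preview rule (the helper name is illustrative):

```python
# Standalone sketch of the long-input preview logic above.
REMEMBER_PREVIEW = 1000

def preview_long_text(text: str) -> str:
    if len(text) <= REMEMBER_PREVIEW:
        return text
    half = REMEMBER_PREVIEW // 2
    omitted = len(text) - REMEMBER_PREVIEW
    return text[:half] + f" ...\n...(省略{omitted}字)...\n... " + text[-half:]

assert len(preview_long_text("x" * 300)) == 300   # short input unchanged
print(preview_long_text("x" * 2500))              # middle 1500 chars elided
```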
@@ -1,16 +1,12 @@
-import llama_index
-import os
 import atexit
 from typing import List
 
 from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
 from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core.schema import TextNode
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -74,7 +70,7 @@ class LlamaIndexRagWorker(SaveLoad):
         if auto_load_checkpoint:
             self.vs_index = self.load_from_checkpoint(checkpoint_dir)
         else:
-            self.vs_index = self.create_new_vs(checkpoint_dir)
+            self.vs_index = self.create_new_vs()
         atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
 
     def assign_embedding_model(self):
@@ -84,38 +80,44 @@ class LlamaIndexRagWorker(SaveLoad):
         # This function is for debugging
         self.vs_index.storage_context.index_store.to_dict()
         docstore = self.vs_index.storage_context.docstore.docs
-        vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
+        vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
         print('\n++ --------inspect_vector_store begin--------')
         print(vector_store_preview)
         print('oo --------inspect_vector_store end--------')
         return vector_store_preview
 
-    def add_documents_to_vector_store(self, document_list):
-        documents = [Document(text=t) for t in document_list]
+    def add_documents_to_vector_store(self, document_list: List[Document]):
+        """
+        Adds a list of Document objects to the vector store after processing.
+        """
+        documents = document_list
         documents_nodes = run_transformations(
             documents, # type: ignore
             self.vs_index._transformations,
             show_progress=True
         )
         self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
 
-    def add_text_to_vector_store(self, text):
+    def add_text_to_vector_store(self, text: str):
         node = TextNode(text=text)
         documents_nodes = run_transformations(
             [node],
             self.vs_index._transformations,
             show_progress=True
         )
         self.vs_index.insert_nodes(documents_nodes)
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
 
     def remember_qa(self, question, answer):
         formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
         self.add_text_to_vector_store(formatted_str)
 
     def retrieve_from_store_with_query(self, query):
-        if self.debug_mode: self.inspect_vector_store()
+        if self.debug_mode:
+            self.inspect_vector_store()
         retriever = self.vs_index.as_retriever()
         return retriever.retrieve(query)
 
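
Both `add_*` methods follow the same llama_index ingestion pattern: run the index's configured transformations over the input, then insert the resulting nodes. A usage sketch, under the assumption of a configured gpt_academic environment (the `llm_kwargs`, user name, and path below are placeholders, not tested values):

```python
# Usage sketch for the worker API above (placeholder arguments).
from llama_index.core import Document

worker = LlamaIndexRagWorker(
    "demo_user", llm_kwargs,  # llm_kwargs is supplied by the host application
    checkpoint_dir="gpt_log/demo_user/experimental_rag",
    auto_load_checkpoint=False,
)
worker.add_documents_to_vector_store([Document(text="hello", metadata={"source": "inline"})])
nodes = worker.retrieve_from_store_with_query("hello")
print(worker.generate_node_array_preview(nodes))
```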
@@ -124,6 +126,13 @@ class LlamaIndexRagWorker(SaveLoad):
         return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
 
     def generate_node_array_preview(self, nodes):
-        buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
-        if self.debug_mode: print(buf)
+        buf = "\n".join([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])
+        if self.debug_mode:
+            print(buf)
         return buf
+
+    def purge_vector_store(self):
+        """
+        Purges the current vector store and creates a new one.
+        """
+        self.purge()
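
With this wrapper, the clear command in `Rag问答` calls `purge_vector_store()` instead of reaching for `purge()` directly, which keeps the worker's public surface explicit. Continuing the sketch above:

```python
worker.purge_vector_store()  # delegates to self.purge(), leaving an empty store
```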
@@ -0,0 +1,47 @@
+import os
+
+import markdown
+from PyPDF2 import PdfReader
+from bs4 import BeautifulSoup
+from docx import Document as DocxDocument
+
+
+def extract_text_from_pdf(file_path):
+    reader = PdfReader(file_path)
+    text = ""
+    for page in reader.pages:
+        text += page.extract_text() + "\n"
+    return text
+
+def extract_text_from_docx(file_path):
+    doc = DocxDocument(file_path)
+    return "\n".join([para.text for para in doc.paragraphs])
+
+def extract_text_from_txt(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+def extract_text_from_md(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        md_content = f.read()
+    return markdown.markdown(md_content)
+
+def extract_text_from_html(file_path):
+    with open(file_path, 'r', encoding='utf-8') as f:
+        soup = BeautifulSoup(f, 'html.parser')
+    return soup.get_text()
+
+def extract_text(file_path):
+    _, ext = os.path.splitext(file_path.lower())
+    if ext == '.pdf':
+        return extract_text_from_pdf(file_path)
+    elif ext in ['.docx', '.doc']:
+        return extract_text_from_docx(file_path)
+    elif ext == '.txt':
+        return extract_text_from_txt(file_path)
+    elif ext == '.md':
+        return extract_text_from_md(file_path)
+    elif ext == '.html':
+        return extract_text_from_html(file_path)
+    else:
+        raise ValueError(f"Unsupported file extension: {ext}")
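
A usage sketch for the new `crazy_functions/rag_fns/rag_file_support.py` module (the file names below are placeholders). One caveat worth flagging in review: the `.doc` branch routes to python-docx, which only parses the OOXML `.docx` format, so legacy `.doc` files will likely fail when opened rather than hit the `ValueError` fallback.

```python
# Usage sketch (placeholder file names).
from crazy_functions.rag_fns.rag_file_support import extract_text

for path in ["notes.md", "paper.pdf", "report.docx", "image.png"]:
    try:
        text = extract_text(path)
        print(f"{path}: {len(text)} characters extracted")
    except ValueError as e:
        print(f"{path}: skipped ({e})")  # unsupported extension, e.g. .png
```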