Mirrored from
https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-08 07:26:48 +00:00
Frontier (#1958)
* update welcome svg
* fix loading chatglm3 (#1937)
  * update welcome svg
  * update welcome message
  * fix loading chatglm3
  ---------
  Co-authored-by: binary-husky <qingxu.fu@outlook.com>
  Co-authored-by: binary-husky <96192199+binary-husky@users.noreply.github.com>
* begin rag project with llama index
* rag version one
* rag beta release
* add social worker (proto)
* fix llamaindex version
---------
Co-authored-by: moetayuko <loli@yuko.moe>
This commit is contained in:
crazy_functions/Rag_Interface.py (+75 lines, normal file)
@@ -0,0 +1,75 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

RAG_WORKER_REGISTER = {}

MAX_HISTORY_ROUND = 5
MAX_CONTEXT_TOKEN_LIMIT = 4096
REMEMBER_PREVIEW = 1000

@CatchException
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):

    # 1. we retrieve rag worker from global context
    user_name = chatbot.get_user()
    if user_name in RAG_WORKER_REGISTER:
        rag_worker = RAG_WORKER_REGISTER[user_name]
    else:
        rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
            user_name,
            llm_kwargs,
            checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'),
            auto_load_checkpoint=True)

    chatbot.append([txt, '正在召回知识 ...'])
    yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI

    # 2. clip history to reduce token consumption
    # 2-1. reduce chat round
    txt_origin = txt

    if len(history) > MAX_HISTORY_ROUND * 2:
        history = history[-(MAX_HISTORY_ROUND * 2):]
    txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
    input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])

    # 2-2. if input is clipped, add input to vector store before retrieve
    if input_is_clipped_flag:
        yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0)  # refresh the UI
        # save input to vector store
        rag_worker.add_text_to_vector_store(txt_origin)
        yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0)  # refresh the UI
        if len(txt_origin) > REMEMBER_PREVIEW:
            HALF = REMEMBER_PREVIEW//2
            i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
            if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
                txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
            else:
                pass
            i_say = txt_clip
        else:
            i_say_to_remember = i_say = txt_clip
    else:
        i_say_to_remember = i_say = txt_clip

    # 3. we search vector store and build prompts
    nodes = rag_worker.retrieve_from_store_with_query(i_say)
    prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)

    # 4. it is time to query llms
    if len(chatbot) != 0: chatbot.pop(-1)  # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
    model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
        inputs=prompt, inputs_show_user=i_say,
        llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
        sys_prompt=system_prompt,
        retry_times_at_unknown_error=0
    )

    # 5. remember what has been asked / answered
    yield from update_ui_lastest_msg(model_say + '</br></br>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5)  # refresh the UI
    rag_worker.remember_qa(i_say_to_remember, model_say)
    history.extend([i_say, model_say])

    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0)  # refresh the UI
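For reference, the middle-elision that builds i_say_to_remember above can be exercised in isolation. A minimal self-contained sketch (the helper name and the sample input are illustrative, not part of the commit):

def elide_middle(text: str, preview_limit: int = 1000) -> str:
    # keep the head and tail halves; replace the middle with an omission marker,
    # mirroring the REMEMBER_PREVIEW logic in Rag问答 above
    if len(text) <= preview_limit:
        return text
    half = preview_limit // 2
    omitted = len(text) - preview_limit
    return text[:half] + f" ...\n...({omitted} chars omitted)...\n... " + text[-half:]

print(elide_middle("x" * 2500))  # 500 chars of head, marker, 500 chars of tail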
crazy_functions/Social_Helper.py (+65 lines, normal file)
@@ -0,0 +1,65 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
import pickle, os

SOCIAL_NETWOK_WORKER_REGISTER = {}

class SocialNetwork():
    def __init__(self):
        self.people = []

class SocialNetworkWorker():
    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
        self.user_name = user_name
        self.checkpoint_dir = checkpoint_dir
        if auto_load_checkpoint:
            self.social_network = self.load_from_checkpoint(checkpoint_dir)
        else:
            self.social_network = SocialNetwork()

    def does_checkpoint_exist(self, checkpoint_dir=None):
        import os, glob
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if not os.path.exists(checkpoint_dir): return False
        if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
        return True

    def save_to_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
            pickle.dump(self.social_network, f)
        return

    def load_from_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
            with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
                social_network = pickle.load(f)
            return social_network
        else:
            return SocialNetwork()


@CatchException
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, num_day=5):

    # 1. we retrieve worker from global context
    user_name = chatbot.get_user()
    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
    if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
        social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
    else:
        social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
            user_name,
            llm_kwargs,
            checkpoint_dir=checkpoint_dir,
            auto_load_checkpoint=True
        )

    # 2. save
    social_network_worker.social_network.people.append("张三")
    social_network_worker.save_to_checkpoint(checkpoint_dir)
    chatbot.append(["good", "work"])
    yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
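The checkpointing in this prototype is a plain pickle round-trip. A minimal sketch of exercising it, assuming it runs in the same module as the classes above (the directory path is illustrative; llm_kwargs is unused by the class, so None suffices here):

import os
ckpt = "gpt_log/demo_user/experimental_rag"   # illustrative path
os.makedirs(ckpt, exist_ok=True)

w = SocialNetworkWorker("demo_user", llm_kwargs=None, checkpoint_dir=ckpt)
w.social_network.people.append("张三")
w.save_to_checkpoint()                         # writes social_network.pkl

w2 = SocialNetworkWorker("demo_user", llm_kwargs=None, checkpoint_dir=ckpt)
print(w2.social_network.people)                # -> ['张三'], restored from disk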
crazy_functions/crazy_utils.py
@@ -4,7 +4,7 @@ import threading
 import os
 import logging
 
-def input_clipping(inputs, history, max_token_limit):
+def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
     """
     When the input text plus the history text exceed the maximum limit, part of the text is discarded.
     Input:
@@ -20,17 +20,20 @@ def input_clipping(inputs, history, max_token_limit):
     enc = model_info["gpt-3.5-turbo"]['tokenizer']
     def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
 
     mode = 'input-and-history'
     # when the input accounts for less than half of the token budget, clip only the history
     input_token_num = get_token_num(inputs)
+    original_input_len = len(inputs)
     if input_token_num < max_token_limit//2:
         mode = 'only-history'
         max_token_limit = max_token_limit - input_token_num
 
     everything = [inputs] if mode == 'input-and-history' else ['']
     everything.extend(history)
-    n_token = get_token_num('\n'.join(everything))
+    full_token_num = n_token = get_token_num('\n'.join(everything))
     everything_token = [get_token_num(e) for e in everything]
+    everything_token_num = sum(everything_token)
     delta = max(everything_token) // 16  # truncation granularity
 
     while n_token > max_token_limit:
@@ -43,10 +46,24 @@ def input_clipping(inputs, history, max_token_limit):
     if mode == 'input-and-history':
         inputs = everything[0]
+        full_token_num = everything_token_num
     else:
         pass
+        full_token_num = everything_token_num + input_token_num
 
     history = everything[1:]
-    return inputs, history
+
+    flags = {
+        "mode": mode,
+        "original_input_token_num": input_token_num,
+        "original_full_token_num": full_token_num,
+        "original_input_len": original_input_len,
+        "clipped_input_len": len(inputs),
+    }
+
+    if not return_clip_flags:
+        return inputs, history
+    else:
+        return inputs, history, flags
 
 def request_gpt_model_in_new_thread_with_ui_alive(
     inputs, inputs_show_user, llm_kwargs,
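For reference, the new return_clip_flags path is what Rag_Interface.py uses to detect clipping. A minimal sketch, assuming a working gpt_academic checkout (the oversized input and toy history are made up for illustration):

from crazy_functions.crazy_utils import input_clipping

long_input = "please summarize this paper. " * 2000   # illustrative oversized input
history = ["earlier question", "earlier answer"]       # illustrative chat history

inputs, history, flags = input_clipping(
    long_input, history, max_token_limit=4096, return_clip_flags=True)

# the caller compares these two lengths to decide whether the full original
# text should be vectorized before retrieval
if flags["original_input_len"] != flags["clipped_input_len"]:
    print("input was clipped:", flags)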
crazy_functions/rag_fns/llama_index_worker.py (new file)
@@ -0,0 +1,122 @@
import llama_index
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize

DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
---------------------
{query_str}
"""

QUESTION_ANSWER_RECORD = """\
{{
    "type": "This is a previous conversation with the user",
    "question": "{question}",
    "answer": "{answer}",
}}
"""


class SaveLoad():

    def does_checkpoint_exist(self, checkpoint_dir=None):
        import os, glob
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if not os.path.exists(checkpoint_dir): return False
        if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
        return True

    def save_to_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)

    def load_from_checkpoint(self, checkpoint_dir=None):
        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
            print('loading checkpoint from disk')
            from llama_index.core import StorageContext, load_index_from_storage
            storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
            self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
            return self.vs_index
        else:
            return self.create_new_vs()

    def create_new_vs(self):
        return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)


class LlamaIndexRagWorker(SaveLoad):
    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
        self.debug_mode = True
        self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
        self.user_name = user_name
        self.checkpoint_dir = checkpoint_dir
        if auto_load_checkpoint:
            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
        else:
            self.vs_index = self.create_new_vs()

    def assign_embedding_model(self):
        pass

    def inspect_vector_store(self):
        # This function is for debugging
        self.vs_index.storage_context.index_store.to_dict()
        docstore = self.vs_index.storage_context.docstore.docs
        vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
        print('\n++ --------inspect_vector_store begin--------')
        print(vector_store_preview)
        print('oo --------inspect_vector_store end--------')
        return vector_store_preview

    def add_documents_to_vector_store(self, document_list):
        documents = [Document(text=t) for t in document_list]
        documents_nodes = run_transformations(
            documents,  # type: ignore
            self.vs_index._transformations,
            show_progress=True
        )
        self.vs_index.insert_nodes(documents_nodes)
        if self.debug_mode: self.inspect_vector_store()

    def add_text_to_vector_store(self, text):
        node = TextNode(text=text)
        documents_nodes = run_transformations(
            [node],
            self.vs_index._transformations,
            show_progress=True
        )
        self.vs_index.insert_nodes(documents_nodes)
        if self.debug_mode: self.inspect_vector_store()

    def remember_qa(self, question, answer):
        formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
        self.add_text_to_vector_store(formatted_str)

    def retrieve_from_store_with_query(self, query):
        if self.debug_mode: self.inspect_vector_store()
        retriever = self.vs_index.as_retriever()
        return retriever.retrieve(query)

    def build_prompt(self, query, nodes):
        context_str = self.generate_node_array_preview(nodes)
        return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)

    def generate_node_array_preview(self, nodes):
        buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
        if self.debug_mode: print(buf)
        return buf
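For reference, a hypothetical end-to-end driver for this worker, assuming a configured gpt_academic environment (llm_kwargs supplied by the host app with valid embedding credentials; the user name, texts, and path are illustrative):

from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

worker = LlamaIndexRagWorker(
    user_name="demo_user",
    llm_kwargs=llm_kwargs,   # provided by the host application
    checkpoint_dir="gpt_log/demo_user/experimental_rag",  # illustrative path
    auto_load_checkpoint=True,
)
worker.add_text_to_vector_store("The release codename of this branch is Frontier.")
nodes = worker.retrieve_from_store_with_query("What is the release codename?")
print(worker.build_prompt(query="What is the release codename?", nodes=nodes))
worker.save_to_checkpoint()  # persists the index as *.json for the next session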
crazy_functions/rag_fns/vector_store_index.py (new file)
@@ -0,0 +1,58 @@
from llama_index.core import VectorStoreIndex
from typing import Any, List, Optional

from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.schema import TransformComponent
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import (
    Settings,
    callback_manager_from_settings_or_context,
    transformations_from_settings_or_context,
)
from llama_index.core.storage.storage_context import StorageContext


class GptacVectorStoreIndex(VectorStoreIndex):

    @classmethod
    def default_vector_store(
        cls,
        storage_context: Optional[StorageContext] = None,
        show_progress: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        transformations: Optional[List[TransformComponent]] = None,
        # deprecated
        service_context: Optional[ServiceContext] = None,
        embed_model = None,
        **kwargs: Any,
    ):
        """Create index from documents.

        Args:
            documents (Optional[Sequence[BaseDocument]]): List of documents to
                build the index from.

        """
        storage_context = storage_context or StorageContext.from_defaults()
        docstore = storage_context.docstore
        callback_manager = (
            callback_manager
            or callback_manager_from_settings_or_context(Settings, service_context)
        )
        transformations = transformations or transformations_from_settings_or_context(
            Settings, service_context
        )

        with callback_manager.as_trace("index_construction"):
            return cls(
                nodes=[],
                storage_context=storage_context,
                callback_manager=callback_manager,
                show_progress=show_progress,
                transformations=transformations,
                service_context=service_context,
                embed_model=embed_model,
                **kwargs,
            )
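The point of this subclass is that default_vector_store follows the same construction path as VectorStoreIndex but passes nodes=[], so an empty index with its embed_model attached exists before any content arrives; create_new_vs in llama_index_worker.py relies on exactly this. A minimal sketch, assuming llama-index-core is installed and embed_model is any embedding model llama_index accepts (illustrative):

from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex

index = GptacVectorStoreIndex.default_vector_store(embed_model=embed_model)
assert len(index.storage_context.docstore.docs) == 0  # starts empty by design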