Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 06:26:47 +00:00
fix local vector store bug
@@ -26,10 +26,6 @@ EMBEDDING_MODEL = "text2vec"
 # Embedding running device
 EMBEDDING_DEVICE = "cpu"
 
-VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
-
-UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
-
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """已知信息:
 {context}
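The two removed constants pinned both the vector store and the upload directory to the package's own install location at import time; the rest of the commit threads a caller-supplied vs_path through the API instead. A minimal sketch of the difference (make_vs_path is a hypothetical helper for illustration, not part of the patch):

import os

# Before: storage root fixed at import time, relative to the vendored module:
# VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")

# After: the caller chooses the root, e.g. one per user or per workspace.
def make_vs_path(workspace: str, vs_id: str) -> str:
    """Hypothetical helper: derive a store path from a caller-supplied root."""
    return os.path.join(workspace, "vector_store", vs_id)

print(make_vs_path("/tmp/workspace", "default"))  # -> /tmp/workspace/vector_store/default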
@@ -159,7 +155,7 @@ class LocalDocQA:
         elif os.path.isfile(filepath):
             file = os.path.split(filepath)[-1]
             try:
-                docs = load_file(filepath, sentence_size)
+                docs = load_file(filepath, SENTENCE_SIZE)
                 print(f"{file} 已成功加载")
                 loaded_files.append(filepath)
             except Exception as e:
@@ -171,7 +167,7 @@ class LocalDocQA:
             for file in tqdm(os.listdir(filepath), desc="加载文件"):
                 fullfilepath = os.path.join(filepath, file)
                 try:
-                    docs += load_file(fullfilepath, sentence_size)
+                    docs += load_file(fullfilepath, SENTENCE_SIZE)
                     loaded_files.append(fullfilepath)
                 except Exception as e:
                     print(e)
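Both hunks pin the chunk size to the module-level SENTENCE_SIZE constant rather than the sentence_size parameter, and a later hunk also fixes a list-branch call that passed no size at all. A hypothetical stand-in for load_file (not the project's real loader) to show what the explicit constant buys:

SENTENCE_SIZE = 100  # module-level default, as in the patched calls

def load_file(filepath, sentence_size=SENTENCE_SIZE):
    # Hypothetical stand-in: chunk a UTF-8 text file into ~sentence_size-char pieces.
    text = open(filepath, encoding="utf-8").read()
    return [text[i:i + sentence_size] for i in range(0, len(text), sentence_size)]

# After the patch every call site states its chunk size explicitly:
# docs = load_file("a.txt", SENTENCE_SIZE)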
@@ -185,21 +181,19 @@ class LocalDocQA:
         else:
             docs = []
             for file in filepath:
-                try:
-                    docs += load_file(file)
-                    print(f"{file} 已成功加载")
-                    loaded_files.append(file)
-                except Exception as e:
-                    print(e)
-                    print(f"{file} 未能成功加载")
+                docs += load_file(file, SENTENCE_SIZE)
+                print(f"{file} 已成功加载")
+                loaded_files.append(file)
 
         if len(docs) > 0:
             print("文件加载完毕,正在生成向量库")
             if vs_path and os.path.isdir(vs_path):
-                self.vector_store = FAISS.load_local(vs_path, text2vec)
-                self.vector_store.add_documents(docs)
+                try:
+                    self.vector_store = FAISS.load_local(vs_path, text2vec)
+                    self.vector_store.add_documents(docs)
+                except:
+                    self.vector_store = FAISS.from_documents(docs, text2vec)
             else:
+                if not vs_path: assert False
                 self.vector_store = FAISS.from_documents(docs, text2vec)  # docs 为Document列表
 
             self.vector_store.save_local(vs_path)
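Wrapping load_local and add_documents in try/except makes the update path self-healing: a missing or unreadable on-disk index now falls back to rebuilding from the freshly loaded documents instead of raising. A minimal sketch of the same pattern, assuming langchain's FAISS wrapper and any Embeddings object:

from langchain.vectorstores import FAISS

def update_or_build(vs_path, docs, embeddings):
    # Append docs to an existing FAISS store; rebuild from scratch if loading fails.
    try:
        store = FAISS.load_local(vs_path, embeddings)   # raises if index absent/corrupt
        store.add_documents(docs)
    except Exception:
        store = FAISS.from_documents(docs, embeddings)  # cold start: build a new index
    store.save_local(vs_path)
    return store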
@@ -208,9 +202,9 @@ class LocalDocQA:
             self.vector_store = FAISS.load_local(vs_path, text2vec)
         return vs_path, loaded_files
 
-    def get_loaded_file(self):
+    def get_loaded_file(self, vs_path):
         ds = self.vector_store.docstore
-        return set([ds._dict[k].metadata['source'].split(UPLOAD_ROOT_PATH)[-1] for k in ds._dict])
+        return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])
 
 
     # query 查询内容
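get_loaded_file now strips the caller's vs_path from each document's source path instead of the deleted UPLOAD_ROOT_PATH constant; the docstore walk itself is unchanged. A sketch of what the method reads, assuming the FAISS wrapper's in-memory docstore (a private _dict of id -> Document whose metadata carries a 'source' key):

def loaded_files(vector_store, vs_path):
    # Collect the set of source files behind a store, relative to vs_path.
    ds = vector_store.docstore  # InMemoryDocstore
    return {doc.metadata['source'].split(vs_path)[-1] for doc in ds._dict.values()}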
@@ -228,7 +222,7 @@ class LocalDocQA:
         self.vector_store.score_threshold = score_threshold
         self.vector_store.chunk_size = chunk_size
 
-        embedding = self.vector_store.embedding_function(query)
+        embedding = self.vector_store.embedding_function.embed_query(query)
         related_docs_with_score = similarity_search_with_score_by_vector(self.vector_store, embedding, k=vector_search_top_k)
 
         if not related_docs_with_score:
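This one-line change is likely the bug named in the commit message: in newer langchain releases, FAISS.embedding_function holds an Embeddings object rather than a bare callable, so invoking it directly fails; the fix goes through embed_query. A sketch under that assumption (the model name mirrors the text2vec model this project loads, but treat it as illustrative):

from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
query = "How do I use the local knowledge base?"
# old: vector_store.embedding_function(query)  -> TypeError: object is not callable
vector = embeddings.embed_query(query)          # new: returns the query vector (list of floats)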
@@ -247,27 +241,23 @@
 
 
-def construct_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
+def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
     for file in files:
         assert os.path.exists(file), "输入文件不存在"
     import nltk
     if NLTK_DATA_PATH not in nltk.data.path: nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg()
-    vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
-    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
-        os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
-    if isinstance(files, list):
-        for file in files:
-            file_name = file.name if not isinstance(file, str) else file
-            filename = os.path.split(file_name)[-1]
-            shutil.copyfile(file_name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-            filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size, text2vec)
-    else:
-        vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
-                                                               sentence_size, text2vec)
+    if not os.path.exists(os.path.join(vs_path, vs_id)):
+        os.makedirs(os.path.join(vs_path, vs_id))
+    for file in files:
+        file_name = file.name if not isinstance(file, str) else file
+        filename = os.path.split(file_name)[-1]
+        shutil.copyfile(file_name, os.path.join(vs_path, vs_id, filename))
+        filelist.append(os.path.join(vs_path, vs_id, filename))
+    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, os.path.join(vs_path, vs_id), sentence_size, text2vec)
 
     if len(loaded_files):
         file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
     else:
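After the rewrite, every path in construct_vector_store derives from the new vs_path argument: uploads are copied into vs_path/vs_id and the index is built in that same directory, with no global roots involved. A hedged usage sketch (the paths and file names are illustrative; text2vec expects an Embeddings instance such as the one sketched earlier):

qa_handle, kai_path = construct_vector_store(
    vs_id="default",
    vs_path="/tmp/workspace/vector_store",  # caller-chosen root (illustrative)
    files=["./docs/readme.md"],
    sentence_size=100,
    history=[],
    one_conent=None,                        # the 'one_conent' spelling is the source's own
    one_content_segmentation=None,
    text2vec=embeddings,
)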
@@ -297,12 +287,13 @@ class knowledge_archive_interface():
         return self.text2vec_large_chinese
 
 
-    def feed_archive(self, file_manifest, id="default"):
+    def feed_archive(self, file_manifest, vs_path, id="default"):
         self.threadLock.acquire()
         # import uuid
         self.current_id = id
         self.qa_handle, self.kai_path = construct_vector_store(
             vs_id=self.current_id,
+            vs_path=vs_path,
             files=file_manifest,
             sentence_size=100,
             history=[],
@@ -315,15 +306,16 @@ class knowledge_archive_interface():
     def get_current_archive_id(self):
         return self.current_id
 
-    def get_loaded_file(self):
-        return self.qa_handle.get_loaded_file()
+    def get_loaded_file(self, vs_path):
+        return self.qa_handle.get_loaded_file(vs_path)
 
-    def answer_with_archive_by_id(self, txt, id):
+    def answer_with_archive_by_id(self, txt, id, vs_path):
         self.threadLock.acquire()
         if not self.current_id == id:
             self.current_id = id
             self.qa_handle, self.kai_path = construct_vector_store(
                 vs_id=self.current_id,
+                vs_path=vs_path,
                 files=[],
                 sentence_size=100,
                 history=[],
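With vs_path threaded through feed_archive, get_loaded_file, and answer_with_archive_by_id, a single knowledge_archive_interface instance can serve vector stores in arbitrary locations. An end-to-end sketch under the same assumptions as above (illustrative paths and queries):

kai = knowledge_archive_interface()
vs_path = "/tmp/workspace/vector_store"  # illustrative

kai.feed_archive(file_manifest=["./docs/readme.md"], vs_path=vs_path, id="default")
print(kai.get_loaded_file(vs_path))      # source names relative to vs_path
# answer_with_archive_by_id rebuilds the handle when the id changes, then queries the store;
# its return value is not shown in this diff:
# kai.answer_with_archive_by_id("What does the readme cover?", "default", vs_path)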