Add batch document inquiry function

这个提交包含在:
lbykkkk
2024-11-03 17:17:16 +08:00
父节点 180550b8f0
当前提交 9172337695
共有 5 个文件被更改,包括 975 次插入30 次删除

查看文件

@@ -28,6 +28,7 @@ def get_crazy_functions():
from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap
from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录
from crazy_functions.辅助功能 import 清除缓存
from crazy_functions.批量文件询问 import 批量文件询问
from crazy_functions.Markdown_Translate import Markdown英译中
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
from crazy_functions.PDF_Translate import 批量翻译PDF文档
@@ -110,12 +111,13 @@ def get_crazy_functions():
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
"批量总结Word文档": {
"批量文件询问": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"Info": "批量总结word文档 | 输入参数为路径",
"Function": HotReload(总结word文档),
"AdvancedArgs": True,
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
"Function": HotReload(批量文件询问),
},
"解析整个Matlab项目": {
"Group": "编程",
@@ -238,7 +240,7 @@ def get_crazy_functions():
"AsButton": True, # 加入下拉菜单中
# "Info": "连接网络回答问题(需要访问谷歌)| 输入参数是一个问题",
"Function": HotReload(连接网络回答问题),
"Class": NetworkGPT_Wrap # 新一代插件需要注册Class
# "Class": NetworkGPT_Wrap # 新一代插件需要注册Class
},
"历史上的今天": {
"Group": "对话",

查看文件

@@ -0,0 +1,397 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from docx import Document
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.oxml.ns import qn
from docx.shared import Inches, Cm
from docx.shared import Pt, RGBColor, Inches
from typing import Dict, List, Tuple
class DocumentFormatter(ABC):
"""文档格式化基类,定义文档格式化的基本接口"""
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
self.final_summary = final_summary
self.file_summaries_map = file_summaries_map
self.failed_files = failed_files
@abstractmethod
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
pass
@abstractmethod
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
pass
@abstractmethod
def create_document(self) -> str:
"""创建完整文档"""
pass
class WordFormatter(DocumentFormatter):
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.doc = Document()
self._setup_document()
self._create_styles()
# 初始化标题编号系统 - 只使用两级编号
self.numbers = {
1: 0, # 一级标题编号
2: 0 # 二级标题编号
}
def _setup_document(self):
"""设置文档基本格式"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(14) # 调整正文字号为14号
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
style.paragraph_format.first_line_indent = Pt(28) # 首行缩进两个字符14pt * 2
# 创建各级标题样式(从大到小递减)
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER) # 大标题,增大字号到32
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT) # 一级标题
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT) # 二级标题
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT) # 三级标题
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
"""创建标题样式"""
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
style.font.name = font_name
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
style.font.size = Pt(font_size)
style.font.bold = True # 所有标题都加粗
style.paragraph_format.alignment = alignment
style.paragraph_format.space_before = Pt(12)
style.paragraph_format.space_after = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
return style
def _get_heading_number(self, level: int) -> str:
"""生成标题编号"""
if level == 0: # 主标题不需要编号
return ""
self.numbers[level] += 1 # 增加当前级别的编号
# 如果是一级标题,重置二级标题编号
if level == 1:
self.numbers[2] = 0
# 根据级别返回不同格式的编号
if level == 1:
return f"{self.numbers[1]}. "
elif level == 2:
return f"{self.numbers[1]}.{self.numbers[2]} "
return ""
def _add_heading(self, text: str, level: int):
"""添加带编号的标题"""
style_map = {
0: 'Title_Custom',
1: 'Heading1_Custom',
2: 'Heading2_Custom',
3: 'Heading3_Custom'
}
# 获取标题编号
number = self._get_heading_number(level)
# 创建段落
paragraph = self.doc.add_paragraph(style=style_map[level])
# 分别添加编号和文本,并设置样式
if number:
number_run = paragraph.add_run(number)
self._get_run_style(number_run, '黑体', 22 if level == 1 else 18, True)
text_run = paragraph.add_run(text)
font_size = 32 if level == 0 else (22 if level == 1 else 18) # 主标题32号,一级标题22号,其他18号
self._get_run_style(text_run, '黑体', font_size, True)
# 特殊处理:主标题添加日期
if level == 0:
date_paragraph = self.doc.add_paragraph()
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d'))
date_run.font.name = '仿宋'
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
date_run.font.size = Pt(16)
return paragraph
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
"""设置文本运行对象的样式"""
run.font.name = font_name
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
run.font.size = Pt(font_size)
run.font.bold = bold
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
result = []
if not self.failed_files:
return "\n".join(result)
result.append("处理失败文件:")
for fp, reason in self.failed_files:
result.append(f"{os.path.basename(fp)}: {reason}")
# 在文档中添加内容
self._add_heading("处理失败文件", 1)
for fp, reason in self.failed_files:
self._add_content(f"{os.path.basename(fp)}: {reason}", indent=False)
self.doc.add_paragraph()
return "\n".join(result)
def _add_content(self, text: str, indent: bool = True):
"""添加正文内容"""
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
if not indent:
paragraph.paragraph_format.first_line_indent = Pt(0) # 不缩进的段落
return paragraph
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
result = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
result.append(f"\n📁 {dir_path}")
self._add_heading(f"📁 {dir_path}", 2)
current_dir = dir_path
# 添加文件名和内容到结果字符串
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 在文档中添加文件名作为带编号的二级标题
self._add_heading(f"📄 {file_name}", 2)
self._add_content(self.file_summaries_map[path])
self.doc.add_paragraph()
return "\n".join(result)
def create_document(self):
"""创建完整Word文档并返回文档对象"""
# 重置所有编号
for level in self.numbers:
self.numbers[level] = 0
# 添加主标题(更大字号和加粗)
self._add_heading("文档总结报告", 0)
self.doc.add_paragraph()
# 添加总体摘要
self._add_heading("总体摘要", 1)
self._add_content(self.final_summary)
self.doc.add_paragraph()
# 添加失败文件列表(如果有)
if self.failed_files:
self.format_failed_files()
# 添加文件详细总结
self._add_heading("各文件详细总结", 1)
self.format_file_summaries()
return self.doc # 返回文档对象
class MarkdownFormatter(DocumentFormatter):
"""Markdown格式文档生成器"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
formatted_text = ["\n## ⚠️ 处理失败的文件"]
for fp, reason in self.failed_files:
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
formatted_text.append("\n---")
return "\n".join(formatted_text)
def format_file_summaries(self) -> str:
formatted_text = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_text.append(f"\n## 📁 {dir_path}")
current_dir = dir_path
file_name = os.path.basename(path)
formatted_text.append(f"\n### 📄 {file_name}")
formatted_text.append(self.file_summaries_map[path])
formatted_text.append("\n---")
return "\n".join(formatted_text)
def create_document(self) -> str:
document = [
"# 📑 文档总结报告",
"\n## 总体摘要",
self.final_summary
]
if self.failed_files:
document.append(self.format_failed_files())
document.extend([
"\n# 📚 各文件详细总结",
self.format_file_summaries()
])
return "\n".join(document)
class HtmlFormatter(DocumentFormatter):
"""HTML格式文档生成器"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.css_styles = """
body {
font-family: "Microsoft YaHei", Arial, sans-serif;
line-height: 1.6;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
color: #333;
}
h1 {
color: #2c3e50;
border-bottom: 2px solid #eee;
padding-bottom: 10px;
font-size: 24px;
text-align: center;
}
h2 {
color: #34495e;
margin-top: 30px;
font-size: 20px;
border-left: 4px solid #3498db;
padding-left: 10px;
}
h3 {
color: #2c3e50;
font-size: 18px;
margin-top: 20px;
}
.summary {
background-color: #f8f9fa;
padding: 20px;
border-radius: 5px;
margin: 20px 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.details {
margin-top: 40px;
}
.failed-files {
background-color: #fff3f3;
padding: 15px;
border-left: 4px solid #e74c3c;
margin: 20px 0;
}
.file-summary {
background-color: #fff;
padding: 15px;
margin: 15px 0;
border-radius: 4px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
failed_files_html = ['<div class="failed-files">']
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
failed_files_html.append("<ul>")
for fp, reason in self.failed_files:
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
failed_files_html.append("</ul></div>")
return "\n".join(failed_files_html)
def format_file_summaries(self) -> str:
formatted_html = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
current_dir = dir_path
file_name = os.path.basename(path)
formatted_html.append('<div class="file-summary">')
formatted_html.append(f'<h3>📄 {file_name}</h3>')
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
formatted_html.append('</div>')
return "\n".join(formatted_html)
def create_document(self) -> str:
return f"""
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'>
<title>文档总结报告</title>
<style>{self.css_styles}</style>
</head>
<body>
<h1>📑 文档总结报告</h1>
<h2>总体摘要</h2>
<div class="summary">{self.final_summary}</div>
{self.format_failed_files()}
<div class="details">
<h2>📚 各文件详细总结</h2>
{self.format_file_summaries()}
</div>
</body>
</html>
"""

查看文件

@@ -1,17 +1,13 @@
import llama_index
import os
import atexit
from loguru import logger
from typing import List
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.schema import TextNode
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
@@ -63,7 +59,7 @@ class SaveLoad():
def purge(self):
import shutil
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
self.vs_index = self.create_new_vs()
self.vs_index = self.create_new_vs(self.checkpoint_dir)
class LlamaIndexRagWorker(SaveLoad):
@@ -75,7 +71,7 @@ class LlamaIndexRagWorker(SaveLoad):
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else:
self.vs_index = self.create_new_vs(checkpoint_dir)
self.vs_index = self.create_new_vs()
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
def assign_embedding_model(self):
@@ -91,17 +87,21 @@ class LlamaIndexRagWorker(SaveLoad):
logger.info('oo --------inspect_vector_store end--------')
return vector_store_preview
def add_documents_to_vector_store(self, document_list):
documents = [Document(text=t) for t in document_list]
def add_documents_to_vector_store(self, document_list: List[Document]):
"""
Adds a list of Document objects to the vector store after processing.
"""
documents = document_list
documents_nodes = run_transformations(
documents, # type: ignore
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode: self.inspect_vector_store()
if self.debug_mode:
self.inspect_vector_store()
def add_text_to_vector_store(self, text):
def add_text_to_vector_store(self, text: str):
node = TextNode(text=text)
documents_nodes = run_transformations(
[node],
@@ -109,14 +109,16 @@ class LlamaIndexRagWorker(SaveLoad):
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode: self.inspect_vector_store()
if self.debug_mode:
self.inspect_vector_store()
def remember_qa(self, question, answer):
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
self.add_text_to_vector_store(formatted_str)
def retrieve_from_store_with_query(self, query):
if self.debug_mode: self.inspect_vector_store()
if self.debug_mode:
self.inspect_vector_store()
retriever = self.vs_index.as_retriever()
return retriever.retrieve(query)
@@ -128,3 +130,9 @@ class LlamaIndexRagWorker(SaveLoad):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: logger.info(buf)
return buf
def purge_vector_store(self):
"""
Purges the current vector store and creates a new one.
"""
self.purge()

查看文件

@@ -0,0 +1,45 @@
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m']
def read_docx_doc(file_path):
if file_path.split(".")[-1] == "docx":
from docx import Document
doc = Document(file_path)
file_content = "\n".join([para.text for para in doc.paragraphs])
else:
try:
import win32com.client
word = win32com.client.Dispatch("Word.Application")
word.visible = False
# 打开文件
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
# file_content = doc.Content.Text
doc = word.ActiveDocument
file_content = doc.Range().Text
doc.Close()
word.Quit()
except:
raise RuntimeError('请先将.doc文档转换为.docx文档。')
return file_content
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
import os
def extract_text(file_path):
_, ext = os.path.splitext(file_path.lower())
# 使用 SimpleDirectoryReader 处理它支持的文件格式
if ext in ['.docx', '.doc']:
return read_docx_doc(file_path)
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
return None

查看文件

@@ -0,0 +1,493 @@
import os
import threading
import time
from dataclasses import dataclass
from typing import List, Tuple, Dict, Generator
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.rag_fns.rag_file_support import extract_text
from request_llms.bridge_all import model_info
from toolbox import update_ui, CatchException, report_exception
@dataclass
class FileFragment:
"""文件片段数据类,用于组织处理单元"""
file_path: str
content: str
rel_path: str
fragment_index: int
total_fragments: int
class BatchDocumentSummarizer:
"""优化的文档总结器 - 批处理版本"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化总结器"""
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.failed_files = []
self.file_summaries_map = {}
def _get_token_limit(self) -> int:
"""获取模型token限制"""
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
return max_token * 3 // 4
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
"""创建批处理输入"""
inputs_array = []
inputs_show_user_array = []
history_array = []
for frag in fragments:
if self.plugin_kwargs.get("advanced_arg"):
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)}'
f'用户要求为:{self.plugin_kwargs["advanced_arg"]}'
f'文件内容是 ```{frag.content}```')
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
else:
i_say = (f'请对下面的内容用中文做概述,文件名是{os.path.basename(frag.file_path)}'
f'内容是 ```{frag.content}```')
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
return inputs_array, inputs_show_user_array, history_array
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
"""包装了超时控制的文件处理函数"""
def timeout_handler():
thread = threading.current_thread()
if hasattr(thread, '_timeout_occurred'):
thread._timeout_occurred = True
# 设置超时标记
thread = threading.current_thread()
thread._timeout_occurred = False
# 设置超时定时器
timer = threading.Timer(self.watch_dog_patience, timeout_handler)
timer.start()
try:
fp, project_folder = file_info
fragments = []
# 定期检查是否超时
def check_timeout():
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
raise TimeoutError("处理超时")
# 更新状态
mutable_status[0] = "检查文件大小"
mutable_status[1] = time.time()
check_timeout()
# 文件大小检查
if os.path.getsize(fp) > self.max_file_size:
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
mutable_status[2] = "文件过大"
return fragments
check_timeout()
# 更新状态
mutable_status[0] = "提取文件内容"
mutable_status[1] = time.time()
# 提取内容
content = extract_text(fp)
if content is None:
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
mutable_status[2] = "格式不支持"
return fragments
elif not content.strip():
self.failed_files.append((fp, "文件内容为空"))
mutable_status[2] = "内容为空"
return fragments
check_timeout()
# 更新状态
mutable_status[0] = "分割文本"
mutable_status[1] = time.time()
# 分割文本
try:
paper_fragments = breakdown_text_to_satisfy_token_limit(
txt=content,
limit=self._get_token_limit(),
llm_model=self.llm_kwargs['llm_model']
)
except Exception as e:
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
mutable_status[2] = "分割失败"
return fragments
check_timeout()
# 处理片段
rel_path = os.path.relpath(fp, project_folder)
for i, frag in enumerate(paper_fragments):
if frag.strip():
fragments.append(FileFragment(
file_path=fp,
content=frag,
rel_path=rel_path,
fragment_index=i,
total_fragments=len(paper_fragments)
))
mutable_status[2] = "处理完成"
return fragments
except TimeoutError as e:
self.failed_files.append((fp, "处理超时"))
mutable_status[2] = "处理超时"
return []
except Exception as e:
self.failed_files.append((fp, f"处理失败:{str(e)}"))
mutable_status[2] = "处理异常"
return []
finally:
timer.cancel()
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Generator, List
"""并行准备所有文件的处理片段"""
all_fragments = []
total_files = len(file_paths)
# 配置参数
self.refresh_interval = 0.2 # UI刷新间隔
self.watch_dog_patience = 5 # 看门狗超时时间
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
self.max_workers = min(32, len(file_paths)) # 最多32个线程
# 创建有超时控制的线程池
executor = ThreadPoolExecutor(max_workers=self.max_workers)
# 用于跨线程状态传递的可变列表 - 增加文件名信息
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
# 创建文件处理任务
file_infos = [(fp, project_folder) for fp in file_paths]
# 提交所有任务,使用带超时控制的处理函数
futures = [
executor.submit(
self._process_single_file_with_timeout,
file_info,
mutable_status_array[i]
) for i, file_info in enumerate(file_infos)
]
# 更新UI的计数器
cnt = 0
try:
# 监控任务执行
while True:
time.sleep(self.refresh_interval)
cnt += 1
# 检查任务完成状态
worker_done = [f.done() for f in futures]
# 更新状态显示
status_str = ""
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
# 获取文件名(去掉路径)
file_name = os.path.basename(file_path)
if worker_done[i]:
status_str += f"文件 {file_name}: {desc}\n"
else:
status_str += f"文件 {file_name}: {status} {desc}\n"
# 更新UI
self.chatbot[-1] = [
"处理进度",
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 检查是否所有任务完成
if all(worker_done):
break
finally:
# 确保线程池正确关闭
executor.shutdown(wait=False)
# 收集结果
processed_files = 0
for future in futures:
try:
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
all_fragments.extend(fragments)
processed_files += 1
except concurrent.futures.TimeoutError:
# 处理获取结果超时
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], "结果获取超时"))
continue
except Exception as e:
# 处理其他异常
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
continue
# 最终进度更新
self.chatbot.append([
"文件处理完成",
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return all_fragments
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
"""批量处理文件片段"""
from collections import defaultdict
batch_size = 64 # 每批处理的片段数
max_retries = 3 # 最大重试次数
retry_delay = 5 # 重试延迟(秒)
results = defaultdict(list)
# 按批次处理
for i in range(0, len(fragments), batch_size):
batch = fragments[i:i + batch_size]
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
sys_prompt_array = ["请总结以下内容:"] * len(batch)
# 添加重试机制
for retry in range(max_retries):
try:
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# 处理响应
for j, frag in enumerate(batch):
summary = response_collection[j * 2 + 1]
if summary and summary.strip():
results[frag.rel_path].append({
'index': frag.fragment_index,
'summary': summary,
'total': frag.total_fragments
})
break # 成功处理,跳出重试循环
except Exception as e:
if retry == max_retries - 1: # 最后一次重试失败
for frag in batch:
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
else:
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
time.sleep(retry_delay)
return results
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
"""准备最终总结请求"""
if not self.file_summaries_map:
return (["无可用的文件总结"], ["生成最终总结"], [[]])
summaries = list(self.file_summaries_map.values())
if all(not summary for summary in summaries):
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
if self.plugin_kwargs.get("advanced_arg"):
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
else:
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
return ([i_say], [i_say], [summaries])
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
"""处理所有文件"""
total_files = len(file_paths)
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 1. 准备所有文件片段
# 在 process_files 函数中:
fragments = yield from self.prepare_fragments(project_folder, file_paths)
if not fragments:
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
return "没有可处理的文件内容"
# 2. 批量处理所有文件片段
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
file_summaries = yield from self._process_fragments_batch(fragments)
except Exception as e:
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
return "处理过程发生错误"
# 3. 为每个文件生成整体总结
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 处理每个文件的总结
for rel_path, summaries in file_summaries.items():
if len(summaries) > 1: # 多片段文件需要生成整体总结
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
if self.plugin_kwargs.get("advanced_arg"):
i_say = (f"根据以下内容,按要求:{self.plugin_kwargs['advanced_arg']}"
f"总结文件 {os.path.basename(rel_path)} 的主要内容。")
else:
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
try:
summary_texts = [s['summary'] for s in sorted_summaries]
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[i_say],
inputs_show_user_array=[f"生成 {rel_path} 的总结"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[summary_texts],
sys_prompt_array=["总结文件内容。"],
)
self.file_summaries_map[rel_path] = response_collection[1]
except Exception as e:
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
self.file_summaries_map[rel_path] = "总结生成失败"
else: # 单片段文件直接使用其唯一的总结
self.file_summaries_map[rel_path] = summaries[0]['summary']
# 4. 生成最终总结
try:
# 收集所有文件的总结用于生成最终总结
file_summaries_for_final = []
for rel_path, summary in self.file_summaries_map.items():
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
if self.plugin_kwargs.get("advanced_arg"):
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
self.plugin_kwargs['advanced_arg'])
else:
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[final_summary_prompt],
inputs_show_user_array=["生成最终总结报告"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[file_summaries_for_final],
sys_prompt_array=["总结所有文件内容。"],
max_workers=1
)
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
except Exception as e:
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
return "生成总结失败"
def save_results(self, final_summary: str):
"""保存结果到文件"""
from toolbox import promote_file_to_downloadzone, write_history_to_file
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
import os
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 创建各种格式化器
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
result_files = []
# 保存 Markdown
md_content = md_formatter.create_document()
result_file_md = write_history_to_file(
history=[md_content], # 直接传入内容列表
file_basename=f"文档总结_{timestamp}.md"
)
result_files.append(result_file_md)
# 保存 HTML
html_content = html_formatter.create_document()
result_file_html = write_history_to_file(
history=[html_content],
file_basename=f"文档总结_{timestamp}.html"
)
result_files.append(result_file_html)
# 保存 Word
doc = word_formatter.create_document()
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
result_file_docx = os.path.join(
os.path.dirname(result_file_md),
f"文档总结_{timestamp}.docx"
)
doc.save(result_file_docx)
result_files.append(result_file_docx)
# 添加到下载区
for file in result_files:
promote_file_to_downloadzone(file, chatbot=self.chatbot)
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
@CatchException
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 优化版本"""
# 初始化
import glob
import re
from crazy_functions.rag_fns.rag_file_support import supports_format
from toolbox import report_exception
summarizer = BatchDocumentSummarizer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
chatbot.append(["函数插件功能", f"作者lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
yield from update_ui(chatbot=chatbot, history=history)
# 验证输入路径
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# 获取文件列表
project_folder = txt
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
if not file_manifest:
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
yield from update_ui(chatbot=chatbot, history=history)
return
# 处理所有文件并生成总结
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
yield from update_ui(chatbot=chatbot, history=history)
# 保存结果
summarizer.save_results(final_summary)
yield from update_ui(chatbot=chatbot, history=history)