diff --git a/crazy_functions/Conversation_To_File.py b/crazy_functions/Conversation_To_File.py index b8408748..a111002c 100644 --- a/crazy_functions/Conversation_To_File.py +++ b/crazy_functions/Conversation_To_File.py @@ -172,7 +172,7 @@ def 载入对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, s user_request 当前用户的请求信息(IP地址等) """ from crazy_functions.crazy_utils import get_files_from_everything - success, file_manifest, _ = get_files_from_everything(txt, type='.html') + success, file_manifest, _ = get_files_from_everything(txt, type='.html',chatbot=chatbot) if not success: if txt == "": txt = '空空如也的输入栏' diff --git a/crazy_functions/Latex_Project_Polish.py b/crazy_functions/Latex_Project_Polish.py index 875e5ad4..7fd7ff36 100644 --- a/crazy_functions/Latex_Project_Polish.py +++ b/crazy_functions/Latex_Project_Polish.py @@ -1,3 +1,4 @@ +from shared_utils.fastapi_server import validate_path_safety from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder from toolbox import CatchException, report_exception, write_history_to_file, zip_folder from loguru import logger @@ -155,6 +156,7 @@ def Latex英文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p import glob, os if os.path.exists(txt): project_folder = txt + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}") @@ -193,6 +195,7 @@ def Latex中文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p import glob, os if os.path.exists(txt): project_folder = txt + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}") @@ -229,6 +232,7 @@ def Latex英文纠错(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p import glob, os if os.path.exists(txt): project_folder = txt + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}") diff --git a/crazy_functions/Markdown_Translate.py b/crazy_functions/Markdown_Translate.py index 45d4c712..f43ef23c 100644 --- a/crazy_functions/Markdown_Translate.py +++ b/crazy_functions/Markdown_Translate.py @@ -1,5 +1,6 @@ import glob, shutil, os, re from loguru import logger +from shared_utils.fastapi_server import validate_path_safety from toolbox import update_ui, trimmed_format_exc, gen_time_str from toolbox import CatchException, report_exception, get_log_folder from toolbox import write_history_to_file, promote_file_to_downloadzone @@ -118,7 +119,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 -def get_files_from_everything(txt, preference=''): +def get_files_from_everything(txt, preference='', chatbox=None): if txt == "": return False, None, None success = True if txt.startswith('http'): @@ -146,9 +147,11 @@ def get_files_from_everything(txt, preference=''): # 直接给定文件 file_manifest = [txt] project_folder = os.path.dirname(txt) + validate_path_safety(project_folder, chatbot.get_user()) elif os.path.exists(txt): # 本地路径,递归搜索 project_folder = txt + validate_path_safety(project_folder, chatbot.get_user()) file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.md', recursive=True)] else: project_folder = None @@ -177,7 +180,7 @@ def Markdown英译中(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p return history = [] # 清空历史,以免输入溢出 - success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github") + success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github", chatbox=chatbot) if not success: # 什么都没有 diff --git a/crazy_functions/PDF_Translate.py b/crazy_functions/PDF_Translate.py index a4d10837..127a1cfc 100644 --- a/crazy_functions/PDF_Translate.py +++ b/crazy_functions/PDF_Translate.py @@ -26,7 +26,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst # 清空历史,以免输入溢出 history = [] - success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf') + success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot) # 检测输入参数,如没有给定输入参数,直接退出 if (not success) and txt == "": txt = '空空如也的输入栏。提示:请先上传文件(把PDF文件拖入对话)。' diff --git a/crazy_functions/crazy_utils.py b/crazy_functions/crazy_utils.py index 5c8776ac..7db258a4 100644 --- a/crazy_functions/crazy_utils.py +++ b/crazy_functions/crazy_utils.py @@ -2,6 +2,7 @@ import os import threading from loguru import logger from shared_utils.char_visual_effect import scolling_visual_effect +from shared_utils.fastapi_server import validate_path_safety from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton def input_clipping(inputs, history, max_token_limit, return_clip_flags=False): @@ -539,7 +540,7 @@ def read_and_clean_pdf_text(fp): return meta_txt, page_one_meta -def get_files_from_everything(txt, type): # type='.md' +def get_files_from_everything(txt, type, chatbot=None): # type='.md' """ 这个函数是用来获取指定目录下所有指定类型(如.md)的文件,并且对于网络上的文件,也可以获取它。 下面是对每个参数和返回值的说明: @@ -551,6 +552,7 @@ def get_files_from_everything(txt, type): # type='.md' - file_manifest: 文件路径列表,里面包含以指定类型为后缀名的所有文件的绝对路径。 - project_folder: 字符串,表示文件所在的文件夹路径。如果是网络上的文件,就是临时文件夹的路径。 该函数详细注释已添加,请确认是否满足您的需要。 + - chatbot 带Cookies的Chatbot类,为实现更多强大的功能做基础 """ import glob, os @@ -573,9 +575,13 @@ def get_files_from_everything(txt, type): # type='.md' # 直接给定文件 file_manifest = [txt] project_folder = os.path.dirname(txt) + if chatbot is not None: + validate_path_safety(project_folder, chatbot.get_user()) elif os.path.exists(txt): # 本地路径,递归搜索 project_folder = txt + if chatbot is not None: + validate_path_safety(project_folder, chatbot.get_user()) file_manifest = [f for f in glob.glob(f'{project_folder}/**/*'+type, recursive=True)] if len(file_manifest) == 0: success = False diff --git a/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py b/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py index 93d45b1a..cb0fe379 100644 --- a/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py +++ b/crazy_functions/pdf_fns/parse_pdf_via_doc2x.py @@ -242,9 +242,7 @@ def 解析PDF_DOC2X_单文件( extract_archive(file_path=this_file_path, dest_dir=ex_folder) # edit markdown files - success, file_manifest, project_folder = get_files_from_everything( - ex_folder, type=".md" - ) + success, file_manifest, project_folder = get_files_from_everything(ex_folder, type='.md', chatbot=chatbot) for generated_fp in file_manifest: # 修正一些公式问题 with open(generated_fp, "r", encoding="utf8") as f: diff --git a/crazy_functions/pdf_fns/parse_word.py b/crazy_functions/pdf_fns/parse_word.py index 3664a9cb..79fe7d32 100644 --- a/crazy_functions/pdf_fns/parse_word.py +++ b/crazy_functions/pdf_fns/parse_word.py @@ -27,10 +27,10 @@ def extract_text_from_files(txt, chatbot, history): return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容 #查找输入区内容中的文件 - file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf') - file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md') - file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx') - file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc') + file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf', chatbot=chatbot) + file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md', chatbot=chatbot) + file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx', chatbot=chatbot) + file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc', chatbot=chatbot) if file_doc: excption = "word" diff --git a/crazy_functions/总结word文档.py b/crazy_functions/总结word文档.py index 99f0919b..83fe7e63 100644 --- a/crazy_functions/总结word文档.py +++ b/crazy_functions/总结word文档.py @@ -104,6 +104,8 @@ def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pr # 检测输入参数,如没有给定输入参数,直接退出 if os.path.exists(txt): project_folder = txt + from shared_utils.fastapi_server import validate_path_safety + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}") diff --git a/crazy_functions/批量翻译PDF文档_NOUGAT.py b/crazy_functions/批量翻译PDF文档_NOUGAT.py index 130dde8f..95b82aca 100644 --- a/crazy_functions/批量翻译PDF文档_NOUGAT.py +++ b/crazy_functions/批量翻译PDF文档_NOUGAT.py @@ -61,7 +61,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst history = [] from crazy_functions.crazy_utils import get_files_from_everything - success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf') + success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot) if len(file_manifest) > 0: # 尝试导入依赖,如果缺少依赖,则给出安装建议 try: @@ -73,7 +73,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade nougat-ocr tiktoken```。") yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 return - success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd') + success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd', chatbot=chatbot) success = success or success_mmd file_manifest += file_manifest_mmd chatbot.append(["文件列表:", ", ".join([e.split('/')[-1] for e in file_manifest])]); diff --git a/crazy_functions/理解PDF文档内容.py b/crazy_functions/理解PDF文档内容.py index 23e3ce4f..bb33634d 100644 --- a/crazy_functions/理解PDF文档内容.py +++ b/crazy_functions/理解PDF文档内容.py @@ -87,6 +87,8 @@ def 理解PDF文档内容标准文件输入(txt, llm_kwargs, plugin_kwargs, chat # 检测输入参数,如没有给定输入参数,直接退出 if os.path.exists(txt): project_folder = txt + from shared_utils.fastapi_server import validate_path_safety + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' diff --git a/crazy_functions/生成函数注释.py b/crazy_functions/生成函数注释.py index 64a3176c..436e7281 100644 --- a/crazy_functions/生成函数注释.py +++ b/crazy_functions/生成函数注释.py @@ -39,6 +39,8 @@ def 批量生成函数注释(txt, llm_kwargs, plugin_kwargs, chatbot, history, s import glob, os if os.path.exists(txt): project_folder = txt + from shared_utils.fastapi_server import validate_path_safety + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}") diff --git a/crazy_functions/知识库问答.py b/crazy_functions/知识库问答.py index f902ed09..e557628e 100644 --- a/crazy_functions/知识库问答.py +++ b/crazy_functions/知识库问答.py @@ -49,7 +49,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst file_manifest = [] spl = ["txt", "doc", "docx", "email", "epub", "html", "json", "md", "msg", "pdf", "ppt", "pptx", "rtf"] for sp in spl: - _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}') + _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}', chatbot=chatbot) file_manifest += file_manifest_tmp if len(file_manifest) == 0: diff --git a/crazy_functions/解析JupyterNotebook.py b/crazy_functions/解析JupyterNotebook.py index e7186aa9..320ba1f3 100644 --- a/crazy_functions/解析JupyterNotebook.py +++ b/crazy_functions/解析JupyterNotebook.py @@ -126,6 +126,8 @@ def 解析ipynb文件(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p import os if os.path.exists(txt): project_folder = txt + from shared_utils.fastapi_server import validate_path_safety + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' diff --git a/crazy_functions/读文章写摘要.py b/crazy_functions/读文章写摘要.py index 1bb0d325..5e167092 100644 --- a/crazy_functions/读文章写摘要.py +++ b/crazy_functions/读文章写摘要.py @@ -48,6 +48,8 @@ def 读文章写摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_ import glob, os if os.path.exists(txt): project_folder = txt + from shared_utils.fastapi_server import validate_path_safety + validate_path_safety(project_folder, chatbot.get_user()) else: if txt == "": txt = '空空如也的输入栏' report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}") diff --git a/shared_utils/fastapi_server.py b/shared_utils/fastapi_server.py index 2993c987..907f4d3f 100644 --- a/shared_utils/fastapi_server.py +++ b/shared_utils/fastapi_server.py @@ -51,7 +51,7 @@ def validate_path_safety(path_or_url, user): from toolbox import get_conf, default_user_name from toolbox import FriendlyException PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING') - sensitive_path = None + sensitive_path = None # 必须不能包含 '/',即不能是多级路径 path_or_url = os.path.relpath(path_or_url) if path_or_url.startswith(PATH_LOGGING): # 日志文件(按用户划分) sensitive_path = PATH_LOGGING diff --git a/toolbox.py b/toolbox.py index 15f4ff6d..bfe85a5a 100644 --- a/toolbox.py +++ b/toolbox.py @@ -499,6 +499,22 @@ def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False, om return tabs_list +def validate_file_size(files, max_size_mb=500): + """ + 验证文件大小是否在允许范围内。 + :param files: 文件的完整路径的列表 + :param max_size_mb: 最大文件大小,单位为MB(默认500MB) + :return: True 如果文件大小有效,否则抛出异常 + """ + # 获取文件大小(字节) + total_size = 0 + max_size_bytes = max_size_mb * 1024 * 1024 + for file in files: + total_size += os.path.getsize(file.name) + if total_size > max_size_bytes: + raise ValueError(f"File size exceeds the allowed limit of {max_size_mb} MB. " + f"Current size: {total_size / (1024 * 1024):.2f} MB") + return True def on_file_uploaded( request: gradio.Request, files:List[str], chatbot:ChatBotWithCookies, @@ -510,6 +526,7 @@ def on_file_uploaded( if len(files) == 0: return chatbot, txt + validate_file_size(files, max_size_mb=500) # 创建工作路径 user_name = default_user_name if not request.username else request.username time_tag = gen_time_str()