diff --git a/.gitignore b/.gitignore index 4fb8a7df..6d0e0cce 100644 --- a/.gitignore +++ b/.gitignore @@ -154,5 +154,5 @@ flagged request_llms/ChatGLM-6b-onnx-u8s8 .pre-commit-config.yaml themes/common.js.min.*.js -test* +test.html objdump* \ No newline at end of file diff --git a/crazy_functions/latex_fns/latex_pickle_io.py b/crazy_functions/latex_fns/latex_pickle_io.py index f08c78c1..451d735b 100644 --- a/crazy_functions/latex_fns/latex_pickle_io.py +++ b/crazy_functions/latex_fns/latex_pickle_io.py @@ -8,16 +8,20 @@ class SafeUnpickler(pickle.Unpickler): # 定义允许的安全类 safe_classes = { # 在这里添加其他安全的类 - 'latex_actions.LatexPaperFileGroup': LatexPaperFileGroup, - 'latex_actions.LatexPaperSplit' : LatexPaperSplit, + 'LatexPaperFileGroup': LatexPaperFileGroup, + 'LatexPaperSplit' : LatexPaperSplit, } return safe_classes def find_class(self, module, name): # 只允许特定的类进行反序列化 self.safe_classes = self.get_safe_classes() - if f'{module}.{name}' in self.safe_classes: - return self.safe_classes[f'{module}.{name}'] + match_class_name = None + for class_name in self.safe_classes.keys(): + if (class_name in f'{module}.{name}'): + match_class_name = class_name + if match_class_name is not None: + return self.safe_classes[match_class_name] # 如果尝试加载未授权的类,则抛出异常 raise pickle.UnpicklingError(f"Attempted to deserialize unauthorized class '{name}' from module '{module}'") diff --git a/tests/test_safe_pickle.py b/tests/test_safe_pickle.py new file mode 100644 index 00000000..01f69562 --- /dev/null +++ b/tests/test_safe_pickle.py @@ -0,0 +1,17 @@ +def validate_path(): + import os, sys + os.path.dirname(__file__) + root_dir_assume = os.path.abspath(os.path.dirname(__file__) + "/..") + os.chdir(root_dir_assume) + sys.path.append(root_dir_assume) +validate_path() # validate path so you can run from base directory + +from crazy_functions.latex_fns.latex_pickle_io import objdump, objload +from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit +pfg = LatexPaperFileGroup() +pfg.get_token_num = None +pfg.target = "target_elem" +x = objdump(pfg) +t = objload() + +print(t.target) \ No newline at end of file diff --git a/tests/test_save_chat_to_html.py b/tests/test_save_chat_to_html.py new file mode 100644 index 00000000..8f69a266 --- /dev/null +++ b/tests/test_save_chat_to_html.py @@ -0,0 +1,102 @@ +def validate_path(): + import os, sys + os.path.dirname(__file__) + root_dir_assume = os.path.abspath(os.path.dirname(__file__) + "/..") + os.chdir(root_dir_assume) + sys.path.append(root_dir_assume) +validate_path() # validate path so you can run from base directory + +def write_chat_to_file(chatbot, history=None, file_name=None): + """ + 将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。 + """ + import os + import time + from themes.theme import advanced_css + # debug + import pickle + # def objdump(obj, file="objdump.tmp"): + # with open(file, "wb+") as f: + # pickle.dump(obj, f) + # return + + def objload(file="objdump.tmp"): + import os + if not os.path.exists(file): + return + with open(file, "rb") as f: + return pickle.load(f) + # objdump((chatbot, history)) + chatbot, history = objload() + + with open("test.html", 'w', encoding='utf8') as f: + from textwrap import dedent + form = dedent(""" +