From 15cc08505f7fbffb4ef1a9027f74e5e94b7a8947 Mon Sep 17 00:00:00 2001 From: binary-husky Date: Wed, 19 Jun 2024 11:59:47 +0000 Subject: [PATCH] resolve safe pickle err --- .gitignore | 2 +- crazy_functions/latex_fns/latex_pickle_io.py | 12 ++- tests/test_safe_pickle.py | 17 ++++ tests/test_save_chat_to_html.py | 102 +++++++++++++++++++ 4 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 tests/test_safe_pickle.py create mode 100644 tests/test_save_chat_to_html.py diff --git a/.gitignore b/.gitignore index 4fb8a7df..6d0e0cce 100644 --- a/.gitignore +++ b/.gitignore @@ -154,5 +154,5 @@ flagged request_llms/ChatGLM-6b-onnx-u8s8 .pre-commit-config.yaml themes/common.js.min.*.js -test* +test.html objdump* \ No newline at end of file diff --git a/crazy_functions/latex_fns/latex_pickle_io.py b/crazy_functions/latex_fns/latex_pickle_io.py index f08c78c1..451d735b 100644 --- a/crazy_functions/latex_fns/latex_pickle_io.py +++ b/crazy_functions/latex_fns/latex_pickle_io.py @@ -8,16 +8,20 @@ class SafeUnpickler(pickle.Unpickler): # 定义允许的安全类 safe_classes = { # 在这里添加其他安全的类 - 'latex_actions.LatexPaperFileGroup': LatexPaperFileGroup, - 'latex_actions.LatexPaperSplit' : LatexPaperSplit, + 'LatexPaperFileGroup': LatexPaperFileGroup, + 'LatexPaperSplit' : LatexPaperSplit, } return safe_classes def find_class(self, module, name): # 只允许特定的类进行反序列化 self.safe_classes = self.get_safe_classes() - if f'{module}.{name}' in self.safe_classes: - return self.safe_classes[f'{module}.{name}'] + match_class_name = None + for class_name in self.safe_classes.keys(): + if (class_name in f'{module}.{name}'): + match_class_name = class_name + if match_class_name is not None: + return self.safe_classes[match_class_name] # 如果尝试加载未授权的类,则抛出异常 raise pickle.UnpicklingError(f"Attempted to deserialize unauthorized class '{name}' from module '{module}'") diff --git a/tests/test_safe_pickle.py b/tests/test_safe_pickle.py new file mode 100644 index 00000000..01f69562 --- /dev/null +++ b/tests/test_safe_pickle.py @@ -0,0 +1,17 @@ +def validate_path(): + import os, sys + os.path.dirname(__file__) + root_dir_assume = os.path.abspath(os.path.dirname(__file__) + "/..") + os.chdir(root_dir_assume) + sys.path.append(root_dir_assume) +validate_path() # validate path so you can run from base directory + +from crazy_functions.latex_fns.latex_pickle_io import objdump, objload +from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit +pfg = LatexPaperFileGroup() +pfg.get_token_num = None +pfg.target = "target_elem" +x = objdump(pfg) +t = objload() + +print(t.target) \ No newline at end of file diff --git a/tests/test_save_chat_to_html.py b/tests/test_save_chat_to_html.py new file mode 100644 index 00000000..8f69a266 --- /dev/null +++ b/tests/test_save_chat_to_html.py @@ -0,0 +1,102 @@ +def validate_path(): + import os, sys + os.path.dirname(__file__) + root_dir_assume = os.path.abspath(os.path.dirname(__file__) + "/..") + os.chdir(root_dir_assume) + sys.path.append(root_dir_assume) +validate_path() # validate path so you can run from base directory + +def write_chat_to_file(chatbot, history=None, file_name=None): + """ + 将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。 + """ + import os + import time + from themes.theme import advanced_css + # debug + import pickle + # def objdump(obj, file="objdump.tmp"): + # with open(file, "wb+") as f: + # pickle.dump(obj, f) + # return + + def objload(file="objdump.tmp"): + import os + if not os.path.exists(file): + return + with open(file, "rb") as f: + return pickle.load(f) + # objdump((chatbot, history)) + chatbot, history = objload() + + with open("test.html", 'w', encoding='utf8') as f: + from textwrap import dedent + form = dedent(""" + 对话存档 + +
+
+
+ {CHAT_PREVIEW} +
+
+
对话(原始数据)
+ {HISTORY_PREVIEW} +
+
+
+ + """) + + qa_from = dedent(""" +
+
{QUESTION}
+
+
{ANSWER}
+
+ """) + + history_from = dedent(""" +
+
{ENTRY}
+
+ """) + CHAT_PREVIEW_BUF = "" + for i, contents in enumerate(chatbot): + question, answer = contents[0], contents[1] + if question is None: question = "" + try: question = str(question) + except: question = "" + if answer is None: answer = "" + try: answer = str(answer) + except: answer = "" + CHAT_PREVIEW_BUF += qa_from.format(QUESTION=question, ANSWER=answer) + + HISTORY_PREVIEW_BUF = "" + for h in history: + HISTORY_PREVIEW_BUF += history_from.format(ENTRY=h) + html_content = form.format(CHAT_PREVIEW=CHAT_PREVIEW_BUF, HISTORY_PREVIEW=HISTORY_PREVIEW_BUF, CSS=advanced_css) + + + from bs4 import BeautifulSoup + soup = BeautifulSoup(html_content, 'lxml') + + # 提取QaBox信息 + qa_box_list = [] + qa_boxes = soup.find_all("div", class_="QaBox") + for box in qa_boxes: + question = box.find("div", class_="Question").get_text(strip=False) + answer = box.find("div", class_="Answer").get_text(strip=False) + qa_box_list.append({"Question": question, "Answer": answer}) + + # 提取historyBox信息 + history_box_list = [] + history_boxes = soup.find_all("div", class_="historyBox") + for box in history_boxes: + entry = box.find("div", class_="entry").get_text(strip=False) + history_box_list.append(entry) + + print('') + + +write_chat_to_file(None, None, None) \ No newline at end of file