镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-07 15:06:48 +00:00
qw
这个提交包含在:
@@ -30,7 +30,7 @@ def 知识库问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
||||
)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
from .crazy_utils import try_install_deps
|
||||
try_install_deps(['zh_langchain==0.2.1'])
|
||||
try_install_deps(['zh_langchain==0.2.1', 'pypinyin'])
|
||||
|
||||
# < --------------------读取参数--------------- >
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
|
||||
@@ -157,7 +157,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
|
||||
try:
|
||||
import glob, os, time, subprocess
|
||||
subprocess.Popen(['pdflatex', '-version'])
|
||||
from .latex_utils import Latex精细分解与转化, 编译Latex
|
||||
from .latex_fns.latex_actions import Latex精细分解与转化, 编译Latex
|
||||
except Exception as e:
|
||||
chatbot.append([ f"解析项目: {txt}",
|
||||
f"尝试执行Latex指令失败。Latex没有安装, 或者不在环境变量PATH中。安装方法https://tug.org/texlive/。报错信息\n\n```\n\n{trimmed_format_exc()}\n\n```\n\n"])
|
||||
@@ -234,7 +234,7 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
|
||||
try:
|
||||
import glob, os, time, subprocess
|
||||
subprocess.Popen(['pdflatex', '-version'])
|
||||
from .latex_utils import Latex精细分解与转化, 编译Latex
|
||||
from .latex_fns.latex_actions import Latex精细分解与转化, 编译Latex
|
||||
except Exception as e:
|
||||
chatbot.append([ f"解析项目: {txt}",
|
||||
f"尝试执行Latex指令失败。Latex没有安装, 或者不在环境变量PATH中。安装方法https://tug.org/texlive/。报错信息\n\n```\n\n{trimmed_format_exc()}\n\n```\n\n"])
|
||||
|
||||
141
crazy_functions/chatglm微调工具.py
普通文件
141
crazy_functions/chatglm微调工具.py
普通文件
@@ -0,0 +1,141 @@
|
||||
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
import datetime, json
|
||||
|
||||
def fetch_items(list_of_items, batch_size):
|
||||
for i in range(0, len(list_of_items), batch_size):
|
||||
yield list_of_items[i:i + batch_size]
|
||||
|
||||
def string_to_options(arguments):
|
||||
import argparse
|
||||
import shlex
|
||||
|
||||
# Create an argparse.ArgumentParser instance
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Add command-line arguments
|
||||
parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
|
||||
parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
|
||||
parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
|
||||
parser.add_argument("--batch", type=int, help="System prompt", default=50)
|
||||
parser.add_argument("--pre_seq_len", type=int, help="pre_seq_len", default=50)
|
||||
parser.add_argument("--learning_rate", type=float, help="learning_rate", default=2e-2)
|
||||
parser.add_argument("--num_gpus", type=int, help="num_gpus", default=1)
|
||||
parser.add_argument("--json_dataset", type=str, help="json_dataset", default="")
|
||||
parser.add_argument("--ptuning_directory", type=str, help="ptuning_directory", default="")
|
||||
|
||||
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args(shlex.split(arguments))
|
||||
|
||||
return args
|
||||
|
||||
@CatchException
|
||||
def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
args = plugin_kwargs.get("advanced_arg", None)
|
||||
if args is None:
|
||||
chatbot.append(("没给定指令", "退出"))
|
||||
yield from update_ui(chatbot=chatbot, history=history); return
|
||||
else:
|
||||
arguments = string_to_options(arguments=args)
|
||||
|
||||
dat = []
|
||||
with open(txt, 'r', encoding='utf8') as f:
|
||||
for line in f.readlines():
|
||||
json_dat = json.loads(line)
|
||||
dat.append(json_dat["content"])
|
||||
|
||||
llm_kwargs['llm_model'] = arguments.llm_to_learn
|
||||
for batch in fetch_items(dat, arguments.batch):
|
||||
res = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=[f"{arguments.prompt_prefix}\n\n{b}" for b in (batch)],
|
||||
inputs_show_user_array=[f"Show Nothing" for _ in (batch)],
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[[] for _ in (batch)],
|
||||
sys_prompt_array=[arguments.system_prompt for _ in (batch)],
|
||||
max_workers=10 # OpenAI所允许的最大并行过载
|
||||
)
|
||||
|
||||
with open(txt+'.generated.json', 'a+', encoding='utf8') as f:
|
||||
for b, r in zip(batch, res[1::2]):
|
||||
f.write(json.dumps({"content":b, "summary":r}, ensure_ascii=False)+'\n')
|
||||
|
||||
promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
|
||||
return
|
||||
|
||||
|
||||
|
||||
@CatchException
|
||||
def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
import subprocess
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
args = plugin_kwargs.get("advanced_arg", None)
|
||||
if args is None:
|
||||
chatbot.append(("没给定指令", "退出"))
|
||||
yield from update_ui(chatbot=chatbot, history=history); return
|
||||
else:
|
||||
arguments = string_to_options(arguments=args)
|
||||
|
||||
|
||||
|
||||
pre_seq_len = arguments.pre_seq_len # 128
|
||||
learning_rate = arguments.learning_rate # 2e-2
|
||||
num_gpus = arguments.num_gpus # 1
|
||||
json_dataset = arguments.json_dataset # 't_code.json'
|
||||
ptuning_directory = arguments.ptuning_directory # '/home/hmp/ChatGLM2-6B/ptuning'
|
||||
|
||||
command = f"torchrun --standalone --nnodes=1 --nproc-per-node={num_gpus} main.py \
|
||||
--do_train \
|
||||
--train_file AdvertiseGen/{json_dataset} \
|
||||
--validation_file AdvertiseGen/{json_dataset} \
|
||||
--preprocessing_num_workers 20 \
|
||||
--prompt_column content \
|
||||
--response_column summary \
|
||||
--overwrite_cache \
|
||||
--model_name_or_path THUDM/chatglm2-6b \
|
||||
--output_dir output/clothgen-chatglm2-6b-pt-{pre_seq_len}-{learning_rate} \
|
||||
--overwrite_output_dir \
|
||||
--max_source_length 256 \
|
||||
--max_target_length 256 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--predict_with_generate \
|
||||
--max_steps 100 \
|
||||
--logging_steps 10 \
|
||||
--save_steps 20 \
|
||||
--learning_rate {learning_rate} \
|
||||
--pre_seq_len {pre_seq_len} \
|
||||
--quantization_bit 4"
|
||||
|
||||
process = subprocess.Popen(command, shell=True, cwd=ptuning_directory)
|
||||
try:
|
||||
process.communicate(timeout=3600*24)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
return
|
||||
@@ -1,231 +0,0 @@
|
||||
"""
|
||||
这是什么?
|
||||
这个文件用于函数插件的单元测试
|
||||
运行方法 python crazy_functions/crazy_functions_test.py
|
||||
"""
|
||||
|
||||
# ==============================================================================================================================
|
||||
|
||||
def validate_path():
|
||||
import os, sys
|
||||
dir_name = os.path.dirname(__file__)
|
||||
root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
|
||||
os.chdir(root_dir_assume)
|
||||
sys.path.append(root_dir_assume)
|
||||
validate_path() # validate path so you can run from base directory
|
||||
|
||||
# ==============================================================================================================================
|
||||
|
||||
from colorful import *
|
||||
from toolbox import get_conf, ChatBotWithCookies
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
|
||||
get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
|
||||
|
||||
llm_kwargs = {
|
||||
'api_key': API_KEY,
|
||||
'llm_model': LLM_MODEL,
|
||||
'top_p':1.0,
|
||||
'max_length': None,
|
||||
'temperature':1.0,
|
||||
}
|
||||
plugin_kwargs = { }
|
||||
chatbot = ChatBotWithCookies(llm_kwargs)
|
||||
history = []
|
||||
system_prompt = "Serve me as a writing and programming assistant."
|
||||
web_port = 1024
|
||||
|
||||
# ==============================================================================================================================
|
||||
|
||||
def silence_stdout(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_original_stdout = sys.stdout
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
for q in func(*args, **kwargs):
|
||||
sys.stdout = _original_stdout
|
||||
yield q
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
sys.stdout.close()
|
||||
sys.stdout = _original_stdout
|
||||
return wrapper
|
||||
|
||||
class CLI_Printer():
|
||||
def __init__(self) -> None:
|
||||
self.pre_buf = ""
|
||||
|
||||
def print(self, buf):
|
||||
bufp = ""
|
||||
for index, chat in enumerate(buf):
|
||||
a, b = chat
|
||||
bufp += sprint亮靛('[Me]:' + a) + '\n'
|
||||
bufp += '[GPT]:' + b
|
||||
if index < len(buf)-1:
|
||||
bufp += '\n'
|
||||
|
||||
if self.pre_buf!="" and bufp.startswith(self.pre_buf):
|
||||
print(bufp[len(self.pre_buf):], end='')
|
||||
else:
|
||||
print('\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'+bufp, end='')
|
||||
self.pre_buf = bufp
|
||||
return
|
||||
|
||||
cli_printer = CLI_Printer()
|
||||
# ==============================================================================================================================
|
||||
def test_解析一个Python项目():
|
||||
from crazy_functions.解析项目源代码 import 解析一个Python项目
|
||||
txt = "crazy_functions/test_project/python/dqn"
|
||||
for cookies, cb, hist, msg in 解析一个Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_解析一个Cpp项目():
|
||||
from crazy_functions.解析项目源代码 import 解析一个C项目
|
||||
txt = "crazy_functions/test_project/cpp/cppipc"
|
||||
for cookies, cb, hist, msg in 解析一个C项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_Latex英文润色():
|
||||
from crazy_functions.Latex全文润色 import Latex英文润色
|
||||
txt = "crazy_functions/test_project/latex/attention"
|
||||
for cookies, cb, hist, msg in Latex英文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_Markdown中译英():
|
||||
from crazy_functions.批量Markdown翻译 import Markdown中译英
|
||||
txt = "README.md"
|
||||
for cookies, cb, hist, msg in Markdown中译英(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_批量翻译PDF文档():
|
||||
from crazy_functions.批量翻译PDF文档_多线程 import 批量翻译PDF文档
|
||||
txt = "crazy_functions/test_project/pdf_and_word"
|
||||
for cookies, cb, hist, msg in 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_谷歌检索小助手():
|
||||
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
|
||||
txt = "https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=auto+reinforcement+learning&btnG="
|
||||
for cookies, cb, hist, msg in 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_总结word文档():
|
||||
from crazy_functions.总结word文档 import 总结word文档
|
||||
txt = "crazy_functions/test_project/pdf_and_word"
|
||||
for cookies, cb, hist, msg in 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_下载arxiv论文并翻译摘要():
|
||||
from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
|
||||
txt = "1812.10695"
|
||||
for cookies, cb, hist, msg in 下载arxiv论文并翻译摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_联网回答问题():
|
||||
from crazy_functions.联网的ChatGPT import 连接网络回答问题
|
||||
# txt = "谁是应急食品?"
|
||||
# >> '根据以上搜索结果可以得知,应急食品是“原神”游戏中的角色派蒙的外号。'
|
||||
# txt = "道路千万条,安全第一条。后面两句是?"
|
||||
# >> '行车不规范,亲人两行泪。'
|
||||
# txt = "You should have gone for the head. What does that mean?"
|
||||
# >> The phrase "You should have gone for the head" is a quote from the Marvel movies, Avengers: Infinity War and Avengers: Endgame. It was spoken by the character Thanos in Infinity War and by Thor in Endgame.
|
||||
txt = "AutoGPT是什么?"
|
||||
for cookies, cb, hist, msg in 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print("当前问答:", cb[-1][-1].replace("\n"," "))
|
||||
for i, it in enumerate(cb): print亮蓝(it[0]); print亮黄(it[1])
|
||||
|
||||
def test_解析ipynb文件():
|
||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||
txt = "crazy_functions/test_samples"
|
||||
for cookies, cb, hist, msg in 解析ipynb文件(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
|
||||
def test_数学动画生成manim():
|
||||
from crazy_functions.数学动画生成manim import 动画生成
|
||||
txt = "A ball split into 2, and then split into 4, and finally split into 8."
|
||||
for cookies, cb, hist, msg in 动画生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
|
||||
|
||||
def test_Markdown多语言():
|
||||
from crazy_functions.批量Markdown翻译 import Markdown翻译指定语言
|
||||
txt = "README.md"
|
||||
history = []
|
||||
for lang in ["English", "French", "Japanese", "Korean", "Russian", "Italian", "German", "Portuguese", "Arabic"]:
|
||||
plugin_kwargs = {"advanced_arg": lang}
|
||||
for cookies, cb, hist, msg in Markdown翻译指定语言(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
print(cb)
|
||||
|
||||
def test_Langchain知识库():
|
||||
from crazy_functions.Langchain知识库 import 知识库问答
|
||||
txt = "./"
|
||||
chatbot = ChatBotWithCookies(llm_kwargs)
|
||||
for cookies, cb, hist, msg in silence_stdout(知识库问答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
cli_printer.print(cb) # print(cb)
|
||||
|
||||
chatbot = ChatBotWithCookies(cookies)
|
||||
from crazy_functions.Langchain知识库 import 读取知识库作答
|
||||
txt = "What is the installation method?"
|
||||
for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
cli_printer.print(cb) # print(cb)
|
||||
|
||||
def test_Langchain知识库读取():
|
||||
from crazy_functions.Langchain知识库 import 读取知识库作答
|
||||
txt = "远程云服务器部署?"
|
||||
for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
cli_printer.print(cb) # print(cb)
|
||||
|
||||
def test_Latex():
|
||||
from crazy_functions.Latex输出PDF结果 import Latex英文纠错加PDF对比, Latex翻译中文并重新编译PDF
|
||||
|
||||
# txt = r"https://arxiv.org/abs/1706.03762"
|
||||
# txt = r"https://arxiv.org/abs/1902.03185"
|
||||
# txt = r"https://arxiv.org/abs/2305.18290"
|
||||
# txt = r"https://arxiv.org/abs/2305.17608"
|
||||
# txt = r"https://arxiv.org/abs/2211.16068" # ACE
|
||||
# txt = r"C:\Users\x\arxiv_cache\2211.16068\workfolder" # ACE
|
||||
# txt = r"https://arxiv.org/abs/2002.09253"
|
||||
# txt = r"https://arxiv.org/abs/2306.07831"
|
||||
# txt = r"https://arxiv.org/abs/2212.10156"
|
||||
# txt = r"https://arxiv.org/abs/2211.11559"
|
||||
# txt = r"https://arxiv.org/abs/2303.08774"
|
||||
txt = r"https://arxiv.org/abs/2303.12712"
|
||||
# txt = r"C:\Users\fuqingxu\arxiv_cache\2303.12712\workfolder"
|
||||
|
||||
|
||||
for cookies, cb, hist, msg in (Latex翻译中文并重新编译PDF)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
cli_printer.print(cb) # print(cb)
|
||||
|
||||
|
||||
|
||||
# txt = "2302.02948.tar"
|
||||
# print(txt)
|
||||
# main_tex, work_folder = Latex预处理(txt)
|
||||
# print('main tex:', main_tex)
|
||||
# res = 编译Latex(main_tex, work_folder)
|
||||
# # for cookies, cb, hist, msg in silence_stdout(编译Latex)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
# cli_printer.print(cb) # print(cb)
|
||||
|
||||
|
||||
|
||||
# test_解析一个Python项目()
|
||||
# test_Latex英文润色()
|
||||
# test_Markdown中译英()
|
||||
# test_批量翻译PDF文档()
|
||||
# test_谷歌检索小助手()
|
||||
# test_总结word文档()
|
||||
# test_下载arxiv论文并翻译摘要()
|
||||
# test_解析一个Cpp项目()
|
||||
# test_联网回答问题()
|
||||
# test_解析ipynb文件()
|
||||
# test_数学动画生成manim()
|
||||
# test_Langchain知识库()
|
||||
# test_Langchain知识库读取()
|
||||
if __name__ == "__main__":
|
||||
test_Latex()
|
||||
input("程序完成,回车退出。")
|
||||
print("退出。")
|
||||
@@ -130,6 +130,11 @@ def request_gpt_model_in_new_thread_with_ui_alive(
|
||||
yield from update_ui(chatbot=chatbot, history=[]) # 如果最后成功了,则删除报错信息
|
||||
return final_result
|
||||
|
||||
def can_multi_process(llm):
|
||||
if llm.startswith('gpt-'): return True
|
||||
if llm.startswith('api2d-'): return True
|
||||
if llm.startswith('azure-'): return True
|
||||
return False
|
||||
|
||||
def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array, inputs_show_user_array, llm_kwargs,
|
||||
@@ -175,7 +180,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
except: max_workers = 8
|
||||
if max_workers <= 0: max_workers = 3
|
||||
# 屏蔽掉 chatglm的多线程,可能会导致严重卡顿
|
||||
if not (llm_kwargs['llm_model'].startswith('gpt-') or llm_kwargs['llm_model'].startswith('api2d-')):
|
||||
if not can_multi_process(llm_kwargs['llm_model']):
|
||||
max_workers = 1
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
@@ -1,311 +1,16 @@
|
||||
from toolbox import update_ui, update_ui_lastest_msg # 刷新Gradio前端界面
|
||||
from toolbox import zip_folder, objdump, objload, promote_file_to_downloadzone
|
||||
from .latex_toolbox import PRESERVE, TRANSFORM
|
||||
from .latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
||||
from .latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
|
||||
from .latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
|
||||
|
||||
import os, shutil
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
pj = os.path.join
|
||||
|
||||
"""
|
||||
========================================================================
|
||||
Part One
|
||||
Latex segmentation with a binary mask (PRESERVE=0, TRANSFORM=1)
|
||||
========================================================================
|
||||
"""
|
||||
PRESERVE = 0
|
||||
TRANSFORM = 1
|
||||
|
||||
def set_forbidden_text(text, mask, pattern, flags=0):
|
||||
"""
|
||||
Add a preserve text area in this paper
|
||||
e.g. with pattern = r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}"
|
||||
you can mask out (mask = PRESERVE so that text become untouchable for GPT)
|
||||
everything between "\begin{equation}" and "\end{equation}"
|
||||
"""
|
||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
mask[res.span()[0]:res.span()[1]] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
||||
"""
|
||||
Move area out of preserve area (make text editable for GPT)
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\begin{abstract} blablablablablabla. \end{abstract}
|
||||
"""
|
||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
if not forbid_wrapper:
|
||||
mask[res.span()[0]:res.span()[1]] = TRANSFORM
|
||||
else:
|
||||
mask[res.regs[0][0]: res.regs[1][0]] = PRESERVE # '\\begin{abstract}'
|
||||
mask[res.regs[1][0]: res.regs[1][1]] = TRANSFORM # abstract
|
||||
mask[res.regs[1][1]: res.regs[0][1]] = PRESERVE # abstract
|
||||
return text, mask
|
||||
|
||||
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
|
||||
"""
|
||||
Add a preserve text area in this paper (text become untouchable for GPT).
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\caption{blablablablabla\texbf{blablabla}blablabla.}
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
brace_level = -1
|
||||
p = begin = end = res.regs[0][0]
|
||||
for _ in range(1024*16):
|
||||
if text[p] == '}' and brace_level == 0: break
|
||||
elif text[p] == '}': brace_level -= 1
|
||||
elif text[p] == '{': brace_level += 1
|
||||
p += 1
|
||||
end = p+1
|
||||
mask[begin:end] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wrapper=True):
|
||||
"""
|
||||
Move area out of preserve area (make text editable for GPT)
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\caption{blablablablabla\texbf{blablabla}blablabla.}
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
brace_level = 0
|
||||
p = begin = end = res.regs[1][0]
|
||||
for _ in range(1024*16):
|
||||
if text[p] == '}' and brace_level == 0: break
|
||||
elif text[p] == '}': brace_level -= 1
|
||||
elif text[p] == '{': brace_level += 1
|
||||
p += 1
|
||||
end = p
|
||||
mask[begin:end] = TRANSFORM
|
||||
if forbid_wrapper:
|
||||
mask[res.regs[0][0]:begin] = PRESERVE
|
||||
mask[end:res.regs[0][1]] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
|
||||
"""
|
||||
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
|
||||
Add it to preserve area
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
def search_with_line_limit(text, mask):
|
||||
for res in pattern_compile.finditer(text):
|
||||
cmd = res.group(1) # begin{what}
|
||||
this = res.group(2) # content between begin and end
|
||||
this_mask = mask[res.regs[2][0]:res.regs[2][1]]
|
||||
white_list = ['document', 'abstract', 'lemma', 'definition', 'sproof',
|
||||
'em', 'emph', 'textit', 'textbf', 'itemize', 'enumerate']
|
||||
if (cmd in white_list) or this.count('\n') >= limit_n_lines: # use a magical number 42
|
||||
this, this_mask = search_with_line_limit(this, this_mask)
|
||||
mask[res.regs[2][0]:res.regs[2][1]] = this_mask
|
||||
else:
|
||||
mask[res.regs[0][0]:res.regs[0][1]] = PRESERVE
|
||||
return text, mask
|
||||
return search_with_line_limit(text, mask)
|
||||
|
||||
class LinkedListNode():
|
||||
"""
|
||||
Linked List Node
|
||||
"""
|
||||
def __init__(self, string, preserve=True) -> None:
|
||||
self.string = string
|
||||
self.preserve = preserve
|
||||
self.next = None
|
||||
# self.begin_line = 0
|
||||
# self.begin_char = 0
|
||||
|
||||
def convert_to_linklist(text, mask):
|
||||
root = LinkedListNode("", preserve=True)
|
||||
current_node = root
|
||||
for c, m, i in zip(text, mask, range(len(text))):
|
||||
if (m==PRESERVE and current_node.preserve) \
|
||||
or (m==TRANSFORM and not current_node.preserve):
|
||||
# add
|
||||
current_node.string += c
|
||||
else:
|
||||
current_node.next = LinkedListNode(c, preserve=(m==PRESERVE))
|
||||
current_node = current_node.next
|
||||
return root
|
||||
"""
|
||||
========================================================================
|
||||
Latex Merge File
|
||||
========================================================================
|
||||
"""
|
||||
|
||||
def 寻找Latex主文件(file_manifest, mode):
|
||||
"""
|
||||
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
|
||||
P.S. 但愿没人把latex模板放在里面传进来 (6.25 加入判定latex模板的代码)
|
||||
"""
|
||||
canidates = []
|
||||
for texf in file_manifest:
|
||||
if os.path.basename(texf).startswith('merge'):
|
||||
continue
|
||||
with open(texf, 'r', encoding='utf8') as f:
|
||||
file_content = f.read()
|
||||
if r'\documentclass' in file_content:
|
||||
canidates.append(texf)
|
||||
else:
|
||||
continue
|
||||
|
||||
if len(canidates) == 0:
|
||||
raise RuntimeError('无法找到一个主Tex文件(包含documentclass关键字)')
|
||||
elif len(canidates) == 1:
|
||||
return canidates[0]
|
||||
else: # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
|
||||
canidates_score = []
|
||||
# 给出一些判定模板文档的词作为扣分项
|
||||
unexpected_words = ['\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers']
|
||||
expected_words = ['\input', '\ref', '\cite']
|
||||
for texf in canidates:
|
||||
canidates_score.append(0)
|
||||
with open(texf, 'r', encoding='utf8') as f:
|
||||
file_content = f.read()
|
||||
for uw in unexpected_words:
|
||||
if uw in file_content:
|
||||
canidates_score[-1] -= 1
|
||||
for uw in expected_words:
|
||||
if uw in file_content:
|
||||
canidates_score[-1] += 1
|
||||
select = np.argmax(canidates_score) # 取评分最高者返回
|
||||
return canidates[select]
|
||||
|
||||
def rm_comments(main_file):
|
||||
new_file_remove_comment_lines = []
|
||||
for l in main_file.splitlines():
|
||||
# 删除整行的空注释
|
||||
if l.lstrip().startswith("%"):
|
||||
pass
|
||||
else:
|
||||
new_file_remove_comment_lines.append(l)
|
||||
main_file = '\n'.join(new_file_remove_comment_lines)
|
||||
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
|
||||
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
|
||||
return main_file
|
||||
|
||||
def merge_tex_files_(project_foler, main_file, mode):
|
||||
"""
|
||||
Merge Tex project recrusively
|
||||
"""
|
||||
main_file = rm_comments(main_file)
|
||||
for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]):
|
||||
f = s.group(1)
|
||||
fp = os.path.join(project_foler, f)
|
||||
if os.path.exists(fp):
|
||||
# e.g., \input{srcs/07_appendix.tex}
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as fx:
|
||||
c = fx.read()
|
||||
else:
|
||||
# e.g., \input{srcs/07_appendix}
|
||||
with open(fp+'.tex', 'r', encoding='utf-8', errors='replace') as fx:
|
||||
c = fx.read()
|
||||
c = merge_tex_files_(project_foler, c, mode)
|
||||
main_file = main_file[:s.span()[0]] + c + main_file[s.span()[1]:]
|
||||
return main_file
|
||||
|
||||
def merge_tex_files(project_foler, main_file, mode):
|
||||
"""
|
||||
Merge Tex project recrusively
|
||||
P.S. 顺便把CTEX塞进去以支持中文
|
||||
P.S. 顺便把Latex的注释去除
|
||||
"""
|
||||
main_file = merge_tex_files_(project_foler, main_file, mode)
|
||||
main_file = rm_comments(main_file)
|
||||
|
||||
if mode == 'translate_zh':
|
||||
# find paper documentclass
|
||||
pattern = re.compile(r'\\documentclass.*\n')
|
||||
match = pattern.search(main_file)
|
||||
assert match is not None, "Cannot find documentclass statement!"
|
||||
position = match.end()
|
||||
add_ctex = '\\usepackage{ctex}\n'
|
||||
add_url = '\\usepackage{url}\n' if '{url}' not in main_file else ''
|
||||
main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
|
||||
# fontset=windows
|
||||
import platform
|
||||
main_file = re.sub(r"\\documentclass\[(.*?)\]{(.*?)}", r"\\documentclass[\1,fontset=windows,UTF8]{\2}",main_file)
|
||||
main_file = re.sub(r"\\documentclass{(.*?)}", r"\\documentclass[fontset=windows,UTF8]{\1}",main_file)
|
||||
# find paper abstract
|
||||
pattern_opt1 = re.compile(r'\\begin\{abstract\}.*\n')
|
||||
pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
|
||||
match_opt1 = pattern_opt1.search(main_file)
|
||||
match_opt2 = pattern_opt2.search(main_file)
|
||||
assert (match_opt1 is not None) or (match_opt2 is not None), "Cannot find paper abstract section!"
|
||||
return main_file
|
||||
|
||||
|
||||
|
||||
"""
|
||||
========================================================================
|
||||
Post process
|
||||
========================================================================
|
||||
"""
|
||||
def mod_inbraket(match):
|
||||
"""
|
||||
为啥chatgpt会把cite里面的逗号换成中文逗号呀
|
||||
"""
|
||||
# get the matched string
|
||||
cmd = match.group(1)
|
||||
str_to_modify = match.group(2)
|
||||
# modify the matched string
|
||||
str_to_modify = str_to_modify.replace(':', ':') # 前面是中文冒号,后面是英文冒号
|
||||
str_to_modify = str_to_modify.replace(',', ',') # 前面是中文逗号,后面是英文逗号
|
||||
# str_to_modify = 'BOOM'
|
||||
return "\\" + cmd + "{" + str_to_modify + "}"
|
||||
|
||||
def fix_content(final_tex, node_string):
|
||||
"""
|
||||
Fix common GPT errors to increase success rate
|
||||
"""
|
||||
final_tex = re.sub(r"(?<!\\)%", "\\%", final_tex)
|
||||
final_tex = re.sub(r"\\([a-z]{2,10})\ \{", r"\\\1{", string=final_tex)
|
||||
final_tex = re.sub(r"\\\ ([a-z]{2,10})\{", r"\\\1{", string=final_tex)
|
||||
final_tex = re.sub(r"\\([a-z]{2,10})\{([^\}]*?)\}", mod_inbraket, string=final_tex)
|
||||
|
||||
if "Traceback" in final_tex and "[Local Message]" in final_tex:
|
||||
final_tex = node_string # 出问题了,还原原文
|
||||
if node_string.count('\\begin') != final_tex.count('\\begin'):
|
||||
final_tex = node_string # 出问题了,还原原文
|
||||
if node_string.count('\_') > 0 and node_string.count('\_') > final_tex.count('\_'):
|
||||
# walk and replace any _ without \
|
||||
final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
|
||||
|
||||
def compute_brace_level(string):
|
||||
# this function count the number of { and }
|
||||
brace_level = 0
|
||||
for c in string:
|
||||
if c == "{": brace_level += 1
|
||||
elif c == "}": brace_level -= 1
|
||||
return brace_level
|
||||
def join_most(tex_t, tex_o):
|
||||
# this function join translated string and original string when something goes wrong
|
||||
p_t = 0
|
||||
p_o = 0
|
||||
def find_next(string, chars, begin):
|
||||
p = begin
|
||||
while p < len(string):
|
||||
if string[p] in chars: return p, string[p]
|
||||
p += 1
|
||||
return None, None
|
||||
while True:
|
||||
res1, char = find_next(tex_o, ['{','}'], p_o)
|
||||
if res1 is None: break
|
||||
res2, char = find_next(tex_t, [char], p_t)
|
||||
if res2 is None: break
|
||||
p_o = res1 + 1
|
||||
p_t = res2 + 1
|
||||
return tex_t[:p_t] + tex_o[p_o:]
|
||||
|
||||
if compute_brace_level(final_tex) != compute_brace_level(node_string):
|
||||
# 出问题了,还原部分原文,保证括号正确
|
||||
final_tex = join_most(final_tex, node_string)
|
||||
return final_tex
|
||||
|
||||
def split_subprocess(txt, project_folder, return_dict, opts):
|
||||
"""
|
||||
@@ -317,13 +22,14 @@ def split_subprocess(txt, project_folder, return_dict, opts):
|
||||
mask = np.zeros(len(txt), dtype=np.uint8) + TRANSFORM
|
||||
|
||||
# 吸收title与作者以上的部分
|
||||
text, mask = set_forbidden_text(text, mask, r"(.*?)\\maketitle", re.DOTALL)
|
||||
text, mask = set_forbidden_text(text, mask, r"^(.*?)\\maketitle", re.DOTALL)
|
||||
text, mask = set_forbidden_text(text, mask, r"^(.*?)\\begin{document}", re.DOTALL)
|
||||
# 吸收iffalse注释
|
||||
text, mask = set_forbidden_text(text, mask, r"\\iffalse(.*?)\\fi", re.DOTALL)
|
||||
# 吸收在42行以内的begin-end组合
|
||||
text, mask = set_forbidden_text_begin_end(text, mask, r"\\begin\{([a-z\*]*)\}(.*?)\\end\{\1\}", re.DOTALL, limit_n_lines=42)
|
||||
# 吸收匿名公式
|
||||
text, mask = set_forbidden_text(text, mask, [ r"\$\$(.*?)\$\$", r"\\\[.*?\\\]" ], re.DOTALL)
|
||||
text, mask = set_forbidden_text(text, mask, [ r"\$\$([^$]+)\$\$", r"\\\[.*?\\\]" ], re.DOTALL)
|
||||
# 吸收其他杂项
|
||||
text, mask = set_forbidden_text(text, mask, [ r"\\section\{(.*?)\}", r"\\section\*\{(.*?)\}", r"\\subsection\{(.*?)\}", r"\\subsubsection\{(.*?)\}" ])
|
||||
text, mask = set_forbidden_text(text, mask, [ r"\\bibliography\{(.*?)\}", r"\\bibliographystyle\{(.*?)\}" ])
|
||||
@@ -347,77 +53,9 @@ def split_subprocess(txt, project_folder, return_dict, opts):
|
||||
text, mask = reverse_forbidden_text(text, mask, r"\\begin\{abstract\}(.*?)\\end\{abstract\}", re.DOTALL, forbid_wrapper=True)
|
||||
root = convert_to_linklist(text, mask)
|
||||
|
||||
# 修复括号
|
||||
node = root
|
||||
while True:
|
||||
string = node.string
|
||||
if node.preserve:
|
||||
node = node.next
|
||||
if node is None: break
|
||||
continue
|
||||
def break_check(string):
|
||||
str_stack = [""] # (lv, index)
|
||||
for i, c in enumerate(string):
|
||||
if c == '{':
|
||||
str_stack.append('{')
|
||||
elif c == '}':
|
||||
if len(str_stack) == 1:
|
||||
print('stack fix')
|
||||
return i
|
||||
str_stack.pop(-1)
|
||||
else:
|
||||
str_stack[-1] += c
|
||||
return -1
|
||||
bp = break_check(string)
|
||||
# 最后一步处理,增强稳健性
|
||||
root = post_process(root)
|
||||
|
||||
if bp == -1:
|
||||
pass
|
||||
elif bp == 0:
|
||||
node.string = string[:1]
|
||||
q = LinkedListNode(string[1:], False)
|
||||
q.next = node.next
|
||||
node.next = q
|
||||
else:
|
||||
node.string = string[:bp]
|
||||
q = LinkedListNode(string[bp:], False)
|
||||
q.next = node.next
|
||||
node.next = q
|
||||
|
||||
node = node.next
|
||||
if node is None: break
|
||||
|
||||
# 屏蔽空行和太短的句子
|
||||
node = root
|
||||
while True:
|
||||
if len(node.string.strip('\n').strip(''))==0: node.preserve = True
|
||||
if len(node.string.strip('\n').strip(''))<42: node.preserve = True
|
||||
node = node.next
|
||||
if node is None: break
|
||||
node = root
|
||||
while True:
|
||||
if node.next and node.preserve and node.next.preserve:
|
||||
node.string += node.next.string
|
||||
node.next = node.next.next
|
||||
node = node.next
|
||||
if node is None: break
|
||||
|
||||
# 将前后断行符脱离
|
||||
node = root
|
||||
prev_node = None
|
||||
while True:
|
||||
if not node.preserve:
|
||||
lstriped_ = node.string.lstrip().lstrip('\n')
|
||||
if (prev_node is not None) and (prev_node.preserve) and (len(lstriped_)!=len(node.string)):
|
||||
prev_node.string += node.string[:-len(lstriped_)]
|
||||
node.string = lstriped_
|
||||
rstriped_ = node.string.rstrip().rstrip('\n')
|
||||
if (node.next is not None) and (node.next.preserve) and (len(rstriped_)!=len(node.string)):
|
||||
node.next.string = node.string[len(rstriped_):] + node.next.string
|
||||
node.string = rstriped_
|
||||
# =====
|
||||
prev_node = node
|
||||
node = node.next
|
||||
if node is None: break
|
||||
# 输出html调试文件,用红色标注处保留区(PRESERVE),用黑色标注转换区(TRANSFORM)
|
||||
with open(pj(project_folder, 'debug_log.html'), 'w', encoding='utf8') as f:
|
||||
segment_parts_for_gpt = []
|
||||
@@ -428,7 +66,7 @@ def split_subprocess(txt, project_folder, return_dict, opts):
|
||||
show_html = node.string.replace('\n','<br/>')
|
||||
if not node.preserve:
|
||||
segment_parts_for_gpt.append(node.string)
|
||||
f.write(f'<p style="color:black;">#{show_html}#</p>')
|
||||
f.write(f'<p style="color:black;">#{node.range}{show_html}#</p>')
|
||||
else:
|
||||
f.write(f'<p style="color:red;">{show_html}</p>')
|
||||
node = node.next
|
||||
@@ -439,8 +77,6 @@ def split_subprocess(txt, project_folder, return_dict, opts):
|
||||
return_dict['segment_parts_for_gpt'] = segment_parts_for_gpt
|
||||
return return_dict
|
||||
|
||||
|
||||
|
||||
class LatexPaperSplit():
|
||||
"""
|
||||
break down latex file to a linked list,
|
||||
@@ -455,18 +91,32 @@ class LatexPaperSplit():
|
||||
# 请您不要删除或修改这行警告,除非您是论文的原作者(如果您是论文原作者,欢迎加REAME中的QQ联系开发者)
|
||||
self.msg_declare = "为了防止大语言模型的意外谬误产生扩散影响,禁止移除或修改此警告。}}\\\\"
|
||||
|
||||
def merge_result(self, arr, mode, msg):
|
||||
|
||||
def merge_result(self, arr, mode, msg, buggy_lines=[], buggy_line_surgery_n_lines=10):
|
||||
"""
|
||||
Merge the result after the GPT process completed
|
||||
"""
|
||||
result_string = ""
|
||||
p = 0
|
||||
node_cnt = 0
|
||||
line_cnt = 0
|
||||
|
||||
for node in self.nodes:
|
||||
if node.preserve:
|
||||
line_cnt += node.string.count('\n')
|
||||
result_string += node.string
|
||||
else:
|
||||
result_string += fix_content(arr[p], node.string)
|
||||
p += 1
|
||||
translated_txt = fix_content(arr[node_cnt], node.string)
|
||||
begin_line = line_cnt
|
||||
end_line = line_cnt + translated_txt.count('\n')
|
||||
|
||||
# reverse translation if any error
|
||||
if any([begin_line-buggy_line_surgery_n_lines <= b_line <= end_line+buggy_line_surgery_n_lines for b_line in buggy_lines]):
|
||||
translated_txt = node.string
|
||||
|
||||
result_string += translated_txt
|
||||
node_cnt += 1
|
||||
line_cnt += translated_txt.count('\n')
|
||||
|
||||
if mode == 'translate_zh':
|
||||
pattern = re.compile(r'\\begin\{abstract\}.*\n')
|
||||
match = pattern.search(result_string)
|
||||
@@ -481,6 +131,7 @@ class LatexPaperSplit():
|
||||
result_string = result_string[:position] + self.msg + msg + self.msg_declare + result_string[position:]
|
||||
return result_string
|
||||
|
||||
|
||||
def split(self, txt, project_folder, opts):
|
||||
"""
|
||||
break down latex file to a linked list,
|
||||
@@ -502,7 +153,6 @@ class LatexPaperSplit():
|
||||
return self.sp
|
||||
|
||||
|
||||
|
||||
class LatexPaperFileGroup():
|
||||
"""
|
||||
use tokenizer to break down text according to max_token_limit
|
||||
@@ -530,7 +180,7 @@ class LatexPaperFileGroup():
|
||||
self.sp_file_index.append(index)
|
||||
self.sp_file_tag.append(self.file_paths[index])
|
||||
else:
|
||||
from .crazy_utils import breakdown_txt_to_satisfy_token_limit_for_pdf
|
||||
from ..crazy_utils import breakdown_txt_to_satisfy_token_limit_for_pdf
|
||||
segments = breakdown_txt_to_satisfy_token_limit_for_pdf(file_content, self.get_token_num, max_token_limit)
|
||||
for j, segment in enumerate(segments):
|
||||
self.sp_file_contents.append(segment)
|
||||
@@ -551,41 +201,14 @@ class LatexPaperFileGroup():
|
||||
f.write(res)
|
||||
return manifest
|
||||
|
||||
def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
|
||||
|
||||
# write html
|
||||
try:
|
||||
import shutil
|
||||
from .crazy_utils import construct_html
|
||||
from toolbox import gen_time_str
|
||||
ch = construct_html()
|
||||
orig = ""
|
||||
trans = ""
|
||||
final = []
|
||||
for c,r in zip(sp_file_contents, sp_file_result):
|
||||
final.append(c)
|
||||
final.append(r)
|
||||
for i, k in enumerate(final):
|
||||
if i%2==0:
|
||||
orig = k
|
||||
if i%2==1:
|
||||
trans = k
|
||||
ch.add_row(a=orig, b=trans)
|
||||
create_report_file_name = f"{gen_time_str()}.trans.html"
|
||||
ch.save_file(create_report_file_name)
|
||||
shutil.copyfile(pj('./gpt_log/', create_report_file_name), pj(project_folder, create_report_file_name))
|
||||
promote_file_to_downloadzone(file=f'./gpt_log/{create_report_file_name}', chatbot=chatbot)
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
print('writing html result failed:', trimmed_format_exc())
|
||||
|
||||
def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, mode='proofread', switch_prompt=None, opts=[]):
|
||||
import time, os, re
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from .latex_utils import LatexPaperFileGroup, merge_tex_files, LatexPaperSplit, 寻找Latex主文件
|
||||
from ..crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from .latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
||||
|
||||
# <-------- 寻找主tex文件 ---------->
|
||||
maintex = 寻找Latex主文件(file_manifest, mode)
|
||||
maintex = find_main_tex_file(file_manifest, mode)
|
||||
chatbot.append((f"定位主Latex文件", f'[Local Message] 分析结果:该项目的Latex主文件是{maintex}, 如果分析错误, 请立即终止程序, 删除或修改歧义文件, 然后重试。主程序即将开始, 请稍候。'))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
time.sleep(3)
|
||||
@@ -659,54 +282,51 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
|
||||
# <-------- 写出文件 ---------->
|
||||
msg = f"当前大语言模型: {llm_kwargs['llm_model']},当前语言模型温度设定: {llm_kwargs['temperature']}。"
|
||||
final_tex = lps.merge_result(pfg.file_result, mode, msg)
|
||||
objdump((lps, pfg.file_result, mode, msg), file=pj(project_folder,'merge_result.pkl'))
|
||||
|
||||
with open(project_folder + f'/merge_{mode}.tex', 'w', encoding='utf-8', errors='replace') as f:
|
||||
if mode != 'translate_zh' or "binary" in final_tex: f.write(final_tex)
|
||||
|
||||
|
||||
# <-------- 整理结果, 退出 ---------->
|
||||
chatbot.append((f"完成了吗?", 'GPT结果已输出, 正在编译PDF'))
|
||||
chatbot.append((f"完成了吗?", 'GPT结果已输出, 即将编译PDF'))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# <-------- 返回 ---------->
|
||||
return project_folder + f'/merge_{mode}.tex'
|
||||
|
||||
|
||||
|
||||
def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work_folder_modified):
|
||||
def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work_folder_modified, fixed_line=[]):
|
||||
try:
|
||||
with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
log = f.read()
|
||||
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
file_lines = f.readlines()
|
||||
import re
|
||||
buggy_lines = re.findall(tex_name+':([0-9]{1,5}):', log)
|
||||
buggy_lines = [int(l) for l in buggy_lines]
|
||||
buggy_lines = sorted(buggy_lines)
|
||||
print("removing lines that has errors", buggy_lines)
|
||||
file_lines.pop(buggy_lines[0]-1)
|
||||
buggy_line = buggy_lines[0]-1
|
||||
print("reversing tex line that has errors", buggy_line)
|
||||
|
||||
# 重组,逆转出错的段落
|
||||
if buggy_line not in fixed_line:
|
||||
fixed_line.append(buggy_line)
|
||||
|
||||
lps, file_result, mode, msg = objload(file=pj(work_folder_modified,'merge_result.pkl'))
|
||||
final_tex = lps.merge_result(file_result, mode, msg, buggy_lines=fixed_line, buggy_line_surgery_n_lines=5*n_fix)
|
||||
|
||||
with open(pj(work_folder_modified, f"{tex_name_pure}_fix_{n_fix}.tex"), 'w', encoding='utf-8', errors='replace') as f:
|
||||
f.writelines(file_lines)
|
||||
f.write(final_tex)
|
||||
|
||||
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
|
||||
except:
|
||||
print("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
|
||||
return False, -1, [-1]
|
||||
|
||||
def compile_latex_with_timeout(command, cwd, timeout=60):
|
||||
import subprocess
|
||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd)
|
||||
try:
|
||||
stdout, stderr = process.communicate(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
stdout, stderr = process.communicate()
|
||||
print("Process timed out!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_folder_original, work_folder_modified, work_folder, mode='default'):
|
||||
import os, time
|
||||
current_dir = os.getcwd()
|
||||
n_fix = 1
|
||||
fixed_line = []
|
||||
max_try = 32
|
||||
chatbot.append([f"正在编译PDF文档", f'编译已经开始。当前工作路径为{work_folder},如果程序停顿5分钟以上,请直接去该路径下取回翻译结果,或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history)
|
||||
chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面
|
||||
@@ -714,6 +334,10 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
|
||||
while True:
|
||||
import os
|
||||
may_exist_bbl = pj(work_folder_modified, f'merge.bbl')
|
||||
target_bbl = pj(work_folder_modified, f'{main_file_modified}.bbl')
|
||||
if os.path.exists(may_exist_bbl) and not os.path.exists(target_bbl):
|
||||
shutil.copyfile(may_exist_bbl, target_bbl)
|
||||
|
||||
# https://stackoverflow.com/questions/738755/dont-make-me-manually-abort-a-latex-compile-when-theres-an-error
|
||||
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译原始PDF ...', chatbot, history) # 刷新Gradio前端界面
|
||||
@@ -747,7 +371,6 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
|
||||
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
|
||||
|
||||
|
||||
# <---------- 检查结果 ----------->
|
||||
results_ = ""
|
||||
original_pdf_success = os.path.exists(pj(work_folder_original, f'{main_file_original}.pdf'))
|
||||
@@ -764,9 +387,19 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
if modified_pdf_success:
|
||||
yield from update_ui_lastest_msg(f'转化PDF编译已经成功, 即将退出 ...', chatbot, history) # 刷新Gradio前端界面
|
||||
result_pdf = pj(work_folder_modified, f'{main_file_modified}.pdf') # get pdf path
|
||||
origin_pdf = pj(work_folder_original, f'{main_file_original}.pdf') # get pdf path
|
||||
if os.path.exists(pj(work_folder, '..', 'translation')):
|
||||
shutil.copyfile(result_pdf, pj(work_folder, '..', 'translation', 'translate_zh.pdf'))
|
||||
promote_file_to_downloadzone(result_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
|
||||
# 将两个PDF拼接
|
||||
if original_pdf_success:
|
||||
try:
|
||||
from .latex_toolbox import merge_pdfs
|
||||
concat_pdf = pj(work_folder_modified, f'comparison.pdf')
|
||||
merge_pdfs(origin_pdf, result_pdf, concat_pdf)
|
||||
promote_file_to_downloadzone(concat_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
|
||||
except Exception as e:
|
||||
pass
|
||||
return True # 成功啦
|
||||
else:
|
||||
if n_fix>=max_try: break
|
||||
@@ -778,6 +411,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
tex_name_pure=f'{main_file_modified}',
|
||||
n_fix=n_fix,
|
||||
work_folder_modified=work_folder_modified,
|
||||
fixed_line=fixed_line
|
||||
)
|
||||
yield from update_ui_lastest_msg(f'由于最为关键的转化PDF编译失败, 将根据报错信息修正tex源文件并重试, 当前报错的latex代码处于第{buggy_lines}行 ...', chatbot, history) # 刷新Gradio前端界面
|
||||
if not can_retry: break
|
||||
@@ -785,4 +419,29 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
||||
return False # 失败啦
|
||||
|
||||
|
||||
|
||||
def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
|
||||
# write html
|
||||
try:
|
||||
import shutil
|
||||
from ..crazy_utils import construct_html
|
||||
from toolbox import gen_time_str
|
||||
ch = construct_html()
|
||||
orig = ""
|
||||
trans = ""
|
||||
final = []
|
||||
for c,r in zip(sp_file_contents, sp_file_result):
|
||||
final.append(c)
|
||||
final.append(r)
|
||||
for i, k in enumerate(final):
|
||||
if i%2==0:
|
||||
orig = k
|
||||
if i%2==1:
|
||||
trans = k
|
||||
ch.add_row(a=orig, b=trans)
|
||||
create_report_file_name = f"{gen_time_str()}.trans.html"
|
||||
ch.save_file(create_report_file_name)
|
||||
shutil.copyfile(pj('./gpt_log/', create_report_file_name), pj(project_folder, create_report_file_name))
|
||||
promote_file_to_downloadzone(file=f'./gpt_log/{create_report_file_name}', chatbot=chatbot)
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
print('writing html result failed:', trimmed_format_exc())
|
||||
@@ -0,0 +1,456 @@
|
||||
import os, shutil
|
||||
import re
|
||||
import numpy as np
|
||||
PRESERVE = 0
|
||||
TRANSFORM = 1
|
||||
|
||||
pj = os.path.join
|
||||
|
||||
class LinkedListNode():
|
||||
"""
|
||||
Linked List Node
|
||||
"""
|
||||
def __init__(self, string, preserve=True) -> None:
|
||||
self.string = string
|
||||
self.preserve = preserve
|
||||
self.next = None
|
||||
self.range = None
|
||||
# self.begin_line = 0
|
||||
# self.begin_char = 0
|
||||
|
||||
def convert_to_linklist(text, mask):
|
||||
root = LinkedListNode("", preserve=True)
|
||||
current_node = root
|
||||
for c, m, i in zip(text, mask, range(len(text))):
|
||||
if (m==PRESERVE and current_node.preserve) \
|
||||
or (m==TRANSFORM and not current_node.preserve):
|
||||
# add
|
||||
current_node.string += c
|
||||
else:
|
||||
current_node.next = LinkedListNode(c, preserve=(m==PRESERVE))
|
||||
current_node = current_node.next
|
||||
return root
|
||||
|
||||
def post_process(root):
|
||||
# 修复括号
|
||||
node = root
|
||||
while True:
|
||||
string = node.string
|
||||
if node.preserve:
|
||||
node = node.next
|
||||
if node is None: break
|
||||
continue
|
||||
def break_check(string):
|
||||
str_stack = [""] # (lv, index)
|
||||
for i, c in enumerate(string):
|
||||
if c == '{':
|
||||
str_stack.append('{')
|
||||
elif c == '}':
|
||||
if len(str_stack) == 1:
|
||||
print('stack fix')
|
||||
return i
|
||||
str_stack.pop(-1)
|
||||
else:
|
||||
str_stack[-1] += c
|
||||
return -1
|
||||
bp = break_check(string)
|
||||
|
||||
if bp == -1:
|
||||
pass
|
||||
elif bp == 0:
|
||||
node.string = string[:1]
|
||||
q = LinkedListNode(string[1:], False)
|
||||
q.next = node.next
|
||||
node.next = q
|
||||
else:
|
||||
node.string = string[:bp]
|
||||
q = LinkedListNode(string[bp:], False)
|
||||
q.next = node.next
|
||||
node.next = q
|
||||
|
||||
node = node.next
|
||||
if node is None: break
|
||||
|
||||
# 屏蔽空行和太短的句子
|
||||
node = root
|
||||
while True:
|
||||
if len(node.string.strip('\n').strip(''))==0: node.preserve = True
|
||||
if len(node.string.strip('\n').strip(''))<42: node.preserve = True
|
||||
node = node.next
|
||||
if node is None: break
|
||||
node = root
|
||||
while True:
|
||||
if node.next and node.preserve and node.next.preserve:
|
||||
node.string += node.next.string
|
||||
node.next = node.next.next
|
||||
node = node.next
|
||||
if node is None: break
|
||||
|
||||
# 将前后断行符脱离
|
||||
node = root
|
||||
prev_node = None
|
||||
while True:
|
||||
if not node.preserve:
|
||||
lstriped_ = node.string.lstrip().lstrip('\n')
|
||||
if (prev_node is not None) and (prev_node.preserve) and (len(lstriped_)!=len(node.string)):
|
||||
prev_node.string += node.string[:-len(lstriped_)]
|
||||
node.string = lstriped_
|
||||
rstriped_ = node.string.rstrip().rstrip('\n')
|
||||
if (node.next is not None) and (node.next.preserve) and (len(rstriped_)!=len(node.string)):
|
||||
node.next.string = node.string[len(rstriped_):] + node.next.string
|
||||
node.string = rstriped_
|
||||
# =====
|
||||
prev_node = node
|
||||
node = node.next
|
||||
if node is None: break
|
||||
|
||||
# 标注节点的行数范围
|
||||
node = root
|
||||
n_line = 0
|
||||
expansion = 2
|
||||
while True:
|
||||
n_l = node.string.count('\n')
|
||||
node.range = [n_line-expansion, n_line+n_l+expansion] # 失败时,扭转的范围
|
||||
n_line = n_line+n_l
|
||||
node = node.next
|
||||
if node is None: break
|
||||
return root
|
||||
|
||||
|
||||
"""
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
Latex segmentation with a binary mask (PRESERVE=0, TRANSFORM=1)
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
"""
|
||||
|
||||
|
||||
def set_forbidden_text(text, mask, pattern, flags=0):
|
||||
"""
|
||||
Add a preserve text area in this paper
|
||||
e.g. with pattern = r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}"
|
||||
you can mask out (mask = PRESERVE so that text become untouchable for GPT)
|
||||
everything between "\begin{equation}" and "\end{equation}"
|
||||
"""
|
||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
mask[res.span()[0]:res.span()[1]] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
||||
"""
|
||||
Move area out of preserve area (make text editable for GPT)
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\begin{abstract} blablablablablabla. \end{abstract}
|
||||
"""
|
||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
if not forbid_wrapper:
|
||||
mask[res.span()[0]:res.span()[1]] = TRANSFORM
|
||||
else:
|
||||
mask[res.regs[0][0]: res.regs[1][0]] = PRESERVE # '\\begin{abstract}'
|
||||
mask[res.regs[1][0]: res.regs[1][1]] = TRANSFORM # abstract
|
||||
mask[res.regs[1][1]: res.regs[0][1]] = PRESERVE # abstract
|
||||
return text, mask
|
||||
|
||||
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
|
||||
"""
|
||||
Add a preserve text area in this paper (text become untouchable for GPT).
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\caption{blablablablabla\texbf{blablabla}blablabla.}
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
brace_level = -1
|
||||
p = begin = end = res.regs[0][0]
|
||||
for _ in range(1024*16):
|
||||
if text[p] == '}' and brace_level == 0: break
|
||||
elif text[p] == '}': brace_level -= 1
|
||||
elif text[p] == '{': brace_level += 1
|
||||
p += 1
|
||||
end = p+1
|
||||
mask[begin:end] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wrapper=True):
|
||||
"""
|
||||
Move area out of preserve area (make text editable for GPT)
|
||||
count the number of the braces so as to catch compelete text area.
|
||||
e.g.
|
||||
\caption{blablablablabla\texbf{blablabla}blablabla.}
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
for res in pattern_compile.finditer(text):
|
||||
brace_level = 0
|
||||
p = begin = end = res.regs[1][0]
|
||||
for _ in range(1024*16):
|
||||
if text[p] == '}' and brace_level == 0: break
|
||||
elif text[p] == '}': brace_level -= 1
|
||||
elif text[p] == '{': brace_level += 1
|
||||
p += 1
|
||||
end = p
|
||||
mask[begin:end] = TRANSFORM
|
||||
if forbid_wrapper:
|
||||
mask[res.regs[0][0]:begin] = PRESERVE
|
||||
mask[end:res.regs[0][1]] = PRESERVE
|
||||
return text, mask
|
||||
|
||||
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
|
||||
"""
|
||||
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
|
||||
Add it to preserve area
|
||||
"""
|
||||
pattern_compile = re.compile(pattern, flags)
|
||||
def search_with_line_limit(text, mask):
|
||||
for res in pattern_compile.finditer(text):
|
||||
cmd = res.group(1) # begin{what}
|
||||
this = res.group(2) # content between begin and end
|
||||
this_mask = mask[res.regs[2][0]:res.regs[2][1]]
|
||||
white_list = ['document', 'abstract', 'lemma', 'definition', 'sproof',
|
||||
'em', 'emph', 'textit', 'textbf', 'itemize', 'enumerate']
|
||||
if (cmd in white_list) or this.count('\n') >= limit_n_lines: # use a magical number 42
|
||||
this, this_mask = search_with_line_limit(this, this_mask)
|
||||
mask[res.regs[2][0]:res.regs[2][1]] = this_mask
|
||||
else:
|
||||
mask[res.regs[0][0]:res.regs[0][1]] = PRESERVE
|
||||
return text, mask
|
||||
return search_with_line_limit(text, mask)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
Latex Merge File
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
"""
|
||||
|
||||
def find_main_tex_file(file_manifest, mode):
|
||||
"""
|
||||
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
|
||||
P.S. 但愿没人把latex模板放在里面传进来 (6.25 加入判定latex模板的代码)
|
||||
"""
|
||||
canidates = []
|
||||
for texf in file_manifest:
|
||||
if os.path.basename(texf).startswith('merge'):
|
||||
continue
|
||||
with open(texf, 'r', encoding='utf8', errors='ignore') as f:
|
||||
file_content = f.read()
|
||||
if r'\documentclass' in file_content:
|
||||
canidates.append(texf)
|
||||
else:
|
||||
continue
|
||||
|
||||
if len(canidates) == 0:
|
||||
raise RuntimeError('无法找到一个主Tex文件(包含documentclass关键字)')
|
||||
elif len(canidates) == 1:
|
||||
return canidates[0]
|
||||
else: # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
|
||||
canidates_score = []
|
||||
# 给出一些判定模板文档的词作为扣分项
|
||||
unexpected_words = ['\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers']
|
||||
expected_words = ['\input', '\ref', '\cite']
|
||||
for texf in canidates:
|
||||
canidates_score.append(0)
|
||||
with open(texf, 'r', encoding='utf8', errors='ignore') as f:
|
||||
file_content = f.read()
|
||||
for uw in unexpected_words:
|
||||
if uw in file_content:
|
||||
canidates_score[-1] -= 1
|
||||
for uw in expected_words:
|
||||
if uw in file_content:
|
||||
canidates_score[-1] += 1
|
||||
select = np.argmax(canidates_score) # 取评分最高者返回
|
||||
return canidates[select]
|
||||
|
||||
def rm_comments(main_file):
|
||||
new_file_remove_comment_lines = []
|
||||
for l in main_file.splitlines():
|
||||
# 删除整行的空注释
|
||||
if l.lstrip().startswith("%"):
|
||||
pass
|
||||
else:
|
||||
new_file_remove_comment_lines.append(l)
|
||||
main_file = '\n'.join(new_file_remove_comment_lines)
|
||||
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
|
||||
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
|
||||
return main_file
|
||||
|
||||
def find_tex_file_ignore_case(fp):
|
||||
dir_name = os.path.dirname(fp)
|
||||
base_name = os.path.basename(fp)
|
||||
if not base_name.endswith('.tex'): base_name+='.tex'
|
||||
if os.path.exists(pj(dir_name, base_name)): return pj(dir_name, base_name)
|
||||
# go case in-sensitive
|
||||
import glob
|
||||
for f in glob.glob(dir_name+'/*.tex'):
|
||||
base_name_s = os.path.basename(fp)
|
||||
if base_name_s.lower() == base_name.lower(): return f
|
||||
return None
|
||||
|
||||
def merge_tex_files_(project_foler, main_file, mode):
|
||||
"""
|
||||
Merge Tex project recrusively
|
||||
"""
|
||||
main_file = rm_comments(main_file)
|
||||
for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]):
|
||||
f = s.group(1)
|
||||
fp = os.path.join(project_foler, f)
|
||||
fp = find_tex_file_ignore_case(fp)
|
||||
if fp:
|
||||
with open(fp, 'r', encoding='utf-8', errors='replace') as fx: c = fx.read()
|
||||
else:
|
||||
raise RuntimeError(f'找不到{fp},Tex源文件缺失!')
|
||||
c = merge_tex_files_(project_foler, c, mode)
|
||||
main_file = main_file[:s.span()[0]] + c + main_file[s.span()[1]:]
|
||||
return main_file
|
||||
|
||||
def merge_tex_files(project_foler, main_file, mode):
|
||||
"""
|
||||
Merge Tex project recrusively
|
||||
P.S. 顺便把CTEX塞进去以支持中文
|
||||
P.S. 顺便把Latex的注释去除
|
||||
"""
|
||||
main_file = merge_tex_files_(project_foler, main_file, mode)
|
||||
main_file = rm_comments(main_file)
|
||||
|
||||
if mode == 'translate_zh':
|
||||
# find paper documentclass
|
||||
pattern = re.compile(r'\\documentclass.*\n')
|
||||
match = pattern.search(main_file)
|
||||
assert match is not None, "Cannot find documentclass statement!"
|
||||
position = match.end()
|
||||
add_ctex = '\\usepackage{ctex}\n'
|
||||
add_url = '\\usepackage{url}\n' if '{url}' not in main_file else ''
|
||||
main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
|
||||
# fontset=windows
|
||||
import platform
|
||||
main_file = re.sub(r"\\documentclass\[(.*?)\]{(.*?)}", r"\\documentclass[\1,fontset=windows,UTF8]{\2}",main_file)
|
||||
main_file = re.sub(r"\\documentclass{(.*?)}", r"\\documentclass[fontset=windows,UTF8]{\1}",main_file)
|
||||
# find paper abstract
|
||||
pattern_opt1 = re.compile(r'\\begin\{abstract\}.*\n')
|
||||
pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
|
||||
match_opt1 = pattern_opt1.search(main_file)
|
||||
match_opt2 = pattern_opt2.search(main_file)
|
||||
assert (match_opt1 is not None) or (match_opt2 is not None), "Cannot find paper abstract section!"
|
||||
return main_file
|
||||
|
||||
|
||||
"""
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
Post process
|
||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
"""
|
||||
def mod_inbraket(match):
|
||||
"""
|
||||
为啥chatgpt会把cite里面的逗号换成中文逗号呀
|
||||
"""
|
||||
# get the matched string
|
||||
cmd = match.group(1)
|
||||
str_to_modify = match.group(2)
|
||||
# modify the matched string
|
||||
str_to_modify = str_to_modify.replace(':', ':') # 前面是中文冒号,后面是英文冒号
|
||||
str_to_modify = str_to_modify.replace(',', ',') # 前面是中文逗号,后面是英文逗号
|
||||
# str_to_modify = 'BOOM'
|
||||
return "\\" + cmd + "{" + str_to_modify + "}"
|
||||
|
||||
def fix_content(final_tex, node_string):
|
||||
"""
|
||||
Fix common GPT errors to increase success rate
|
||||
"""
|
||||
final_tex = re.sub(r"(?<!\\)%", "\\%", final_tex)
|
||||
final_tex = re.sub(r"\\([a-z]{2,10})\ \{", r"\\\1{", string=final_tex)
|
||||
final_tex = re.sub(r"\\\ ([a-z]{2,10})\{", r"\\\1{", string=final_tex)
|
||||
final_tex = re.sub(r"\\([a-z]{2,10})\{([^\}]*?)\}", mod_inbraket, string=final_tex)
|
||||
|
||||
if "Traceback" in final_tex and "[Local Message]" in final_tex:
|
||||
final_tex = node_string # 出问题了,还原原文
|
||||
if node_string.count('\\begin') != final_tex.count('\\begin'):
|
||||
final_tex = node_string # 出问题了,还原原文
|
||||
if node_string.count('\_') > 0 and node_string.count('\_') > final_tex.count('\_'):
|
||||
# walk and replace any _ without \
|
||||
final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
|
||||
|
||||
def compute_brace_level(string):
|
||||
# this function count the number of { and }
|
||||
brace_level = 0
|
||||
for c in string:
|
||||
if c == "{": brace_level += 1
|
||||
elif c == "}": brace_level -= 1
|
||||
return brace_level
|
||||
def join_most(tex_t, tex_o):
|
||||
# this function join translated string and original string when something goes wrong
|
||||
p_t = 0
|
||||
p_o = 0
|
||||
def find_next(string, chars, begin):
|
||||
p = begin
|
||||
while p < len(string):
|
||||
if string[p] in chars: return p, string[p]
|
||||
p += 1
|
||||
return None, None
|
||||
while True:
|
||||
res1, char = find_next(tex_o, ['{','}'], p_o)
|
||||
if res1 is None: break
|
||||
res2, char = find_next(tex_t, [char], p_t)
|
||||
if res2 is None: break
|
||||
p_o = res1 + 1
|
||||
p_t = res2 + 1
|
||||
return tex_t[:p_t] + tex_o[p_o:]
|
||||
|
||||
if compute_brace_level(final_tex) != compute_brace_level(node_string):
|
||||
# 出问题了,还原部分原文,保证括号正确
|
||||
final_tex = join_most(final_tex, node_string)
|
||||
return final_tex
|
||||
|
||||
def compile_latex_with_timeout(command, cwd, timeout=60):
|
||||
import subprocess
|
||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd)
|
||||
try:
|
||||
stdout, stderr = process.communicate(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
stdout, stderr = process.communicate()
|
||||
print("Process timed out!")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def merge_pdfs(pdf1_path, pdf2_path, output_path):
|
||||
import PyPDF2
|
||||
Percent = 0.8
|
||||
# Open the first PDF file
|
||||
with open(pdf1_path, 'rb') as pdf1_file:
|
||||
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
|
||||
# Open the second PDF file
|
||||
with open(pdf2_path, 'rb') as pdf2_file:
|
||||
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
|
||||
# Create a new PDF file to store the merged pages
|
||||
output_writer = PyPDF2.PdfFileWriter()
|
||||
# Determine the number of pages in each PDF file
|
||||
num_pages = max(pdf1_reader.numPages, pdf2_reader.numPages)
|
||||
# Merge the pages from the two PDF files
|
||||
for page_num in range(num_pages):
|
||||
# Add the page from the first PDF file
|
||||
if page_num < pdf1_reader.numPages:
|
||||
page1 = pdf1_reader.getPage(page_num)
|
||||
else:
|
||||
page1 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||
# Add the page from the second PDF file
|
||||
if page_num < pdf2_reader.numPages:
|
||||
page2 = pdf2_reader.getPage(page_num)
|
||||
else:
|
||||
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||
# Create a new empty page with double width
|
||||
new_page = PyPDF2.PageObject.createBlankPage(
|
||||
width = int(int(page1.mediaBox.getWidth()) + int(page2.mediaBox.getWidth()) * Percent),
|
||||
height = max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight())
|
||||
)
|
||||
new_page.mergeTranslatedPage(page1, 0, 0)
|
||||
new_page.mergeTranslatedPage(page2, int(int(page1.mediaBox.getWidth())-int(page2.mediaBox.getWidth())* (1-Percent)), 0)
|
||||
output_writer.addPage(new_page)
|
||||
# Save the merged PDF file
|
||||
with open(output_path, 'wb') as output_file:
|
||||
output_writer.write(output_file)
|
||||
@@ -0,0 +1,130 @@
|
||||
import time, threading, json
|
||||
|
||||
|
||||
class AliyunASR():
|
||||
|
||||
def test_on_sentence_begin(self, message, *args):
|
||||
# print("test_on_sentence_begin:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_sentence_end(self, message, *args):
|
||||
# print("test_on_sentence_end:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_sentence = message['payload']['result']
|
||||
self.event_on_entence_end.set()
|
||||
print(self.parsed_sentence)
|
||||
|
||||
def test_on_start(self, message, *args):
|
||||
# print("test_on_start:{}".format(message))
|
||||
pass
|
||||
|
||||
def test_on_error(self, message, *args):
|
||||
print("on_error args=>{}".format(args))
|
||||
pass
|
||||
|
||||
def test_on_close(self, *args):
|
||||
self.aliyun_service_ok = False
|
||||
pass
|
||||
|
||||
def test_on_result_chg(self, message, *args):
|
||||
# print("test_on_chg:{}".format(message))
|
||||
message = json.loads(message)
|
||||
self.parsed_text = message['payload']['result']
|
||||
self.event_on_result_chg.set()
|
||||
|
||||
def test_on_completed(self, message, *args):
|
||||
# print("on_completed:args=>{} message=>{}".format(args, message))
|
||||
pass
|
||||
|
||||
|
||||
def audio_convertion_thread(self, uuid):
|
||||
# 在一个异步线程中采集音频
|
||||
import nls # pip install git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git
|
||||
import tempfile
|
||||
from scipy import io
|
||||
from toolbox import get_conf
|
||||
from .audio_io import change_sample_rate
|
||||
from .audio_io import RealtimeAudioDistribution
|
||||
NEW_SAMPLERATE = 16000
|
||||
rad = RealtimeAudioDistribution()
|
||||
rad.clean_up()
|
||||
temp_folder = tempfile.gettempdir()
|
||||
TOKEN, APPKEY = get_conf('ALIYUN_TOKEN', 'ALIYUN_APPKEY')
|
||||
if len(TOKEN) == 0:
|
||||
TOKEN = self.get_token()
|
||||
self.aliyun_service_ok = True
|
||||
URL="wss://nls-gateway.aliyuncs.com/ws/v1"
|
||||
sr = nls.NlsSpeechTranscriber(
|
||||
url=URL,
|
||||
token=TOKEN,
|
||||
appkey=APPKEY,
|
||||
on_sentence_begin=self.test_on_sentence_begin,
|
||||
on_sentence_end=self.test_on_sentence_end,
|
||||
on_start=self.test_on_start,
|
||||
on_result_changed=self.test_on_result_chg,
|
||||
on_completed=self.test_on_completed,
|
||||
on_error=self.test_on_error,
|
||||
on_close=self.test_on_close,
|
||||
callback_args=[uuid.hex]
|
||||
)
|
||||
|
||||
r = sr.start(aformat="pcm",
|
||||
enable_intermediate_result=True,
|
||||
enable_punctuation_prediction=True,
|
||||
enable_inverse_text_normalization=True)
|
||||
|
||||
while not self.stop:
|
||||
# time.sleep(self.capture_interval)
|
||||
audio = rad.read(uuid.hex)
|
||||
if audio is not None:
|
||||
# convert to pcm file
|
||||
temp_file = f'{temp_folder}/{uuid.hex}.pcm' #
|
||||
dsdata = change_sample_rate(audio, rad.rate, NEW_SAMPLERATE) # 48000 --> 16000
|
||||
io.wavfile.write(temp_file, NEW_SAMPLERATE, dsdata)
|
||||
# read pcm binary
|
||||
with open(temp_file, "rb") as f: data = f.read()
|
||||
# print('audio len:', len(audio), '\t ds len:', len(dsdata), '\t need n send:', len(data)//640)
|
||||
slices = zip(*(iter(data),) * 640) # 640个字节为一组
|
||||
for i in slices: sr.send_audio(bytes(i))
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
if not self.aliyun_service_ok:
|
||||
self.stop = True
|
||||
self.stop_msg = 'Aliyun音频服务异常,请检查ALIYUN_TOKEN和ALIYUN_APPKEY是否过期。'
|
||||
r = sr.stop()
|
||||
|
||||
def get_token(self):
|
||||
from toolbox import get_conf
|
||||
import json
|
||||
from aliyunsdkcore.request import CommonRequest
|
||||
from aliyunsdkcore.client import AcsClient
|
||||
AccessKey_ID, AccessKey_secret = get_conf('ALIYUN_ACCESSKEY', 'ALIYUN_SECRET')
|
||||
|
||||
# 创建AcsClient实例
|
||||
client = AcsClient(
|
||||
AccessKey_ID,
|
||||
AccessKey_secret,
|
||||
"cn-shanghai"
|
||||
)
|
||||
|
||||
# 创建request,并设置参数。
|
||||
request = CommonRequest()
|
||||
request.set_method('POST')
|
||||
request.set_domain('nls-meta.cn-shanghai.aliyuncs.com')
|
||||
request.set_version('2019-02-28')
|
||||
request.set_action_name('CreateToken')
|
||||
|
||||
try:
|
||||
response = client.do_action_with_exception(request)
|
||||
print(response)
|
||||
jss = json.loads(response)
|
||||
if 'Token' in jss and 'Id' in jss['Token']:
|
||||
token = jss['Token']['Id']
|
||||
expireTime = jss['Token']['ExpireTime']
|
||||
print("token = " + token)
|
||||
print("expireTime = " + str(expireTime))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return token
|
||||
@@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
from scipy import interpolate
|
||||
|
||||
def Singleton(cls):
|
||||
_instance = {}
|
||||
|
||||
def _singleton(*args, **kargs):
|
||||
if cls not in _instance:
|
||||
_instance[cls] = cls(*args, **kargs)
|
||||
return _instance[cls]
|
||||
|
||||
return _singleton
|
||||
|
||||
|
||||
@Singleton
|
||||
class RealtimeAudioDistribution():
|
||||
def __init__(self) -> None:
|
||||
self.data = {}
|
||||
self.max_len = 1024*1024
|
||||
self.rate = 48000 # 只读,每秒采样数量
|
||||
|
||||
def clean_up(self):
|
||||
self.data = {}
|
||||
|
||||
def feed(self, uuid, audio):
|
||||
self.rate, audio_ = audio
|
||||
# print('feed', len(audio_), audio_[-25:])
|
||||
if uuid not in self.data:
|
||||
self.data[uuid] = audio_
|
||||
else:
|
||||
new_arr = np.concatenate((self.data[uuid], audio_))
|
||||
if len(new_arr) > self.max_len: new_arr = new_arr[-self.max_len:]
|
||||
self.data[uuid] = new_arr
|
||||
|
||||
def read(self, uuid):
|
||||
if uuid in self.data:
|
||||
res = self.data.pop(uuid)
|
||||
print('\r read-', len(res), '-', max(res), end='', flush=True)
|
||||
else:
|
||||
res = None
|
||||
return res
|
||||
|
||||
def change_sample_rate(audio, old_sr, new_sr):
|
||||
duration = audio.shape[0] / old_sr
|
||||
|
||||
time_old = np.linspace(0, duration, audio.shape[0])
|
||||
time_new = np.linspace(0, duration, int(audio.shape[0] * new_sr / old_sr))
|
||||
|
||||
interpolator = interpolate.interp1d(time_old, audio.T)
|
||||
new_audio = interpolator(time_new).T
|
||||
return new_audio.astype(np.int16)
|
||||
@@ -144,11 +144,11 @@ def 下载arxiv论文并翻译摘要(txt, llm_kwargs, plugin_kwargs, chatbot, hi
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import pdfminer, bs4
|
||||
import bs4
|
||||
except:
|
||||
report_execption(chatbot, history,
|
||||
a = f"解析项目: {txt}",
|
||||
b = f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade pdfminer beautifulsoup4```。")
|
||||
b = f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade beautifulsoup4```。")
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
|
||||
63
crazy_functions/交互功能函数模板.py
普通文件
63
crazy_functions/交互功能函数模板.py
普通文件
@@ -0,0 +1,63 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
|
||||
@CatchException
|
||||
def 交互功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数, 如温度和top_p等, 一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数, 如温度和top_p等, 一般原样传递下去就行
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append(("这是什么功能?", "交互功能函数模板。在执行完成之后, 可以将自身的状态存储到cookie中, 等待用户的再次调用。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
state = chatbot._cookies.get('plugin_state_0001', None) # 初始化插件状态
|
||||
|
||||
if state is None:
|
||||
chatbot._cookies['lock_plugin'] = 'crazy_functions.交互功能函数模板->交互功能模板函数' # 赋予插件锁定 锁定插件回调路径,当下一次用户提交时,会直接转到该函数
|
||||
chatbot._cookies['plugin_state_0001'] = 'wait_user_keyword' # 赋予插件状态
|
||||
|
||||
chatbot.append(("第一次调用:", "请输入关键词, 我将为您查找相关壁纸, 建议使用英文单词, 插件锁定中,请直接提交即可。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
if state == 'wait_user_keyword':
|
||||
chatbot._cookies['lock_plugin'] = None # 解除插件锁定,避免遗忘导致死锁
|
||||
chatbot._cookies['plugin_state_0001'] = None # 解除插件状态,避免遗忘导致死锁
|
||||
|
||||
# 解除插件锁定
|
||||
chatbot.append((f"获取关键词:{txt}", ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
page_return = get_image_page_by_keyword(txt)
|
||||
inputs=inputs_show_user=f"Extract all image urls in this html page, pick the first 5 images and show them with markdown format: \n\n {page_return}"
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=inputs, inputs_show_user=inputs_show_user,
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
|
||||
sys_prompt="When you want to show an image, use markdown format. e.g. . If there are no image url provided, answer 'no image url provided'"
|
||||
)
|
||||
chatbot[-1] = [chatbot[-1][0], gpt_say]
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------------
|
||||
|
||||
def get_image_page_by_keyword(keyword):
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
response = requests.get(f'https://wallhaven.cc/search?q={keyword}', timeout=2)
|
||||
res = "image urls: \n"
|
||||
for image_element in BeautifulSoup(response.content, 'html.parser').findAll("img"):
|
||||
try:
|
||||
res += image_element["data-src"]
|
||||
res += "\n"
|
||||
except:
|
||||
pass
|
||||
return res
|
||||
31
crazy_functions/命令行助手.py
普通文件
31
crazy_functions/命令行助手.py
普通文件
@@ -0,0 +1,31 @@
|
||||
from toolbox import CatchException, update_ui, gen_time_str
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import input_clipping
|
||||
import copy, json
|
||||
|
||||
@CatchException
|
||||
def 命令行助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
"""
|
||||
txt 输入栏用户输入的文本, 例如需要翻译的一段话, 再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数, 如温度和top_p等, 一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数, 暂时没有用武之地
|
||||
chatbot 聊天显示框的句柄, 用于显示给用户
|
||||
history 聊天历史, 前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
# 清空历史, 以免输入溢出
|
||||
history = []
|
||||
|
||||
# 输入
|
||||
i_say = "请写bash命令实现以下功能:" + txt
|
||||
# 开始
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say, inputs_show_user=txt,
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
|
||||
sys_prompt="你是一个Linux大师级用户。注意,当我要求你写bash命令时,尽可能地仅用一行命令解决我的要求。"
|
||||
)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
||||
|
||||
|
||||
|
||||
@@ -27,8 +27,10 @@ def gen_image(llm_kwargs, prompt, resolution="256x256"):
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data, proxies=proxies)
|
||||
print(response.content)
|
||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||
|
||||
try:
|
||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||
except:
|
||||
raise RuntimeError(response.content.decode())
|
||||
# 文件保存到本地
|
||||
r = requests.get(image_url, proxies=proxies)
|
||||
file_path = 'gpt_log/image_gen/'
|
||||
@@ -53,7 +55,7 @@ def 图片生成(prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append(("这是什么功能?", "[Local Message] 生成图像, 请先把模型切换至gpt-xxxx或者api2d-xxxx。如果中文效果不理想, 尝试Prompt。正在处理中 ....."))
|
||||
chatbot.append(("这是什么功能?", "[Local Message] 生成图像, 请先把模型切换至gpt-*或者api2d-*。如果中文效果不理想, 请尝试英文Prompt。正在处理中 ....."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
resolution = plugin_kwargs.get("advanced_arg", '256x256')
|
||||
|
||||
@@ -12,7 +12,7 @@ def write_chat_to_file(chatbot, history=None, file_name=None):
|
||||
file_name = 'chatGPT对话历史' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.html'
|
||||
os.makedirs('./gpt_log/', exist_ok=True)
|
||||
with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
|
||||
from theme import advanced_css
|
||||
from themes.theme import advanced_css
|
||||
f.write(f'<!DOCTYPE html><head><meta charset="utf-8"><title>对话历史</title><style>{advanced_css}</style></head>')
|
||||
for i, contents in enumerate(chatbot):
|
||||
for j, content in enumerate(contents):
|
||||
|
||||
@@ -14,17 +14,19 @@ def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot
|
||||
doc = Document(fp)
|
||||
file_content = "\n".join([para.text for para in doc.paragraphs])
|
||||
else:
|
||||
import win32com.client
|
||||
word = win32com.client.Dispatch("Word.Application")
|
||||
word.visible = False
|
||||
# 打开文件
|
||||
print('fp', os.getcwd())
|
||||
doc = word.Documents.Open(os.getcwd() + '/' + fp)
|
||||
# file_content = doc.Content.Text
|
||||
doc = word.ActiveDocument
|
||||
file_content = doc.Range().Text
|
||||
doc.Close()
|
||||
word.Quit()
|
||||
try:
|
||||
import win32com.client
|
||||
word = win32com.client.Dispatch("Word.Application")
|
||||
word.visible = False
|
||||
# 打开文件
|
||||
doc = word.Documents.Open(os.getcwd() + '/' + fp)
|
||||
# file_content = doc.Content.Text
|
||||
doc = word.ActiveDocument
|
||||
file_content = doc.Range().Text
|
||||
doc.Close()
|
||||
word.Quit()
|
||||
except:
|
||||
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
||||
|
||||
print(file_content)
|
||||
# private_upload里面的文件名在解压zip后容易出现乱码(rar和7z格式正常),故可以只分析文章内容,不输入文件名
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from toolbox import update_ui, trimmed_format_exc, gen_time_str
|
||||
from toolbox import CatchException, report_execption, write_results_to_file
|
||||
import glob, time, os, re
|
||||
from toolbox import update_ui, trimmed_format_exc, gen_time_str, disable_auto_promotion
|
||||
from toolbox import CatchException, report_execption, write_history_to_file
|
||||
from toolbox import promote_file_to_downloadzone, get_log_folder
|
||||
fast_debug = False
|
||||
|
||||
class PaperFileGroup():
|
||||
@@ -42,13 +44,13 @@ class PaperFileGroup():
|
||||
def write_result(self, language):
|
||||
manifest = []
|
||||
for path, res in zip(self.file_paths, self.file_result):
|
||||
with open(path + f'.{gen_time_str()}.{language}.md', 'w', encoding='utf8') as f:
|
||||
manifest.append(path + f'.{gen_time_str()}.{language}.md')
|
||||
dst_file = os.path.join(get_log_folder(), f'{gen_time_str()}.md')
|
||||
with open(dst_file, 'w', encoding='utf8') as f:
|
||||
manifest.append(dst_file)
|
||||
f.write(res)
|
||||
return manifest
|
||||
|
||||
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
||||
import time, os, re
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
|
||||
# <-------- 读取Markdown文件,删除其中的所有注释 ---------->
|
||||
@@ -102,28 +104,38 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
||||
print(trimmed_format_exc())
|
||||
|
||||
# <-------- 整理结果,退出 ---------->
|
||||
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
|
||||
res = write_results_to_file(gpt_response_collection, file_name=create_report_file_name)
|
||||
create_report_file_name = gen_time_str() + f"-chatgpt.md"
|
||||
res = write_history_to_file(gpt_response_collection, file_basename=create_report_file_name)
|
||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||
history = gpt_response_collection
|
||||
chatbot.append((f"{fp}完成了吗?", res))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
def get_files_from_everything(txt):
|
||||
import glob, os
|
||||
|
||||
def get_files_from_everything(txt, preference=''):
|
||||
if txt == "": return False, None, None
|
||||
success = True
|
||||
if txt.startswith('http'):
|
||||
# 网络的远程文件
|
||||
txt = txt.replace("https://github.com/", "https://raw.githubusercontent.com/")
|
||||
txt = txt.replace("/blob/", "/")
|
||||
import requests
|
||||
from toolbox import get_conf
|
||||
proxies, = get_conf('proxies')
|
||||
# 网络的远程文件
|
||||
if preference == 'Github':
|
||||
print('正在从github下载资源 ...')
|
||||
if not txt.endswith('.md'):
|
||||
# Make a request to the GitHub API to retrieve the repository information
|
||||
url = txt.replace("https://github.com/", "https://api.github.com/repos/") + '/readme'
|
||||
response = requests.get(url, proxies=proxies)
|
||||
txt = response.json()['download_url']
|
||||
else:
|
||||
txt = txt.replace("https://github.com/", "https://raw.githubusercontent.com/")
|
||||
txt = txt.replace("/blob/", "/")
|
||||
|
||||
r = requests.get(txt, proxies=proxies)
|
||||
with open('./gpt_log/temp.md', 'wb+') as f: f.write(r.content)
|
||||
project_folder = './gpt_log/'
|
||||
file_manifest = ['./gpt_log/temp.md']
|
||||
download_local = f'{get_log_folder(plugin_name="批量Markdown翻译")}/raw-readme-{gen_time_str()}.md'
|
||||
project_folder = f'{get_log_folder(plugin_name="批量Markdown翻译")}'
|
||||
with open(download_local, 'wb+') as f: f.write(r.content)
|
||||
file_manifest = [download_local]
|
||||
elif txt.endswith('.md'):
|
||||
# 直接给定文件
|
||||
file_manifest = [txt]
|
||||
@@ -145,11 +157,11 @@ def Markdown英译中(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
|
||||
"函数插件功能?",
|
||||
"对整个Markdown项目进行翻译。函数插件贡献者: Binary-Husky"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
disable_auto_promotion(chatbot)
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import tiktoken
|
||||
import glob, os
|
||||
except:
|
||||
report_execption(chatbot, history,
|
||||
a=f"解析项目: {txt}",
|
||||
@@ -158,7 +170,7 @@ def Markdown英译中(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
|
||||
return
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
|
||||
success, file_manifest, project_folder = get_files_from_everything(txt)
|
||||
success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github")
|
||||
|
||||
if not success:
|
||||
# 什么都没有
|
||||
@@ -185,11 +197,11 @@ def Markdown中译英(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
|
||||
"函数插件功能?",
|
||||
"对整个Markdown项目进行翻译。函数插件贡献者: Binary-Husky"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
disable_auto_promotion(chatbot)
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import tiktoken
|
||||
import glob, os
|
||||
except:
|
||||
report_execption(chatbot, history,
|
||||
a=f"解析项目: {txt}",
|
||||
@@ -218,11 +230,11 @@ def Markdown翻译指定语言(txt, llm_kwargs, plugin_kwargs, chatbot, history,
|
||||
"函数插件功能?",
|
||||
"对整个Markdown项目进行翻译。函数插件贡献者: Binary-Husky"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
disable_auto_promotion(chatbot)
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import tiktoken
|
||||
import glob, os
|
||||
except:
|
||||
report_execption(chatbot, history,
|
||||
a=f"解析项目: {txt}",
|
||||
|
||||
@@ -1,121 +1,107 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import update_ui, promote_file_to_downloadzone, gen_time_str
|
||||
from toolbox import CatchException, report_execption, write_results_to_file
|
||||
import re
|
||||
import unicodedata
|
||||
fast_debug = False
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import read_and_clean_pdf_text
|
||||
from .crazy_utils import input_clipping
|
||||
|
||||
def is_paragraph_break(match):
|
||||
"""
|
||||
根据给定的匹配结果来判断换行符是否表示段落分隔。
|
||||
如果换行符前为句子结束标志(句号,感叹号,问号),且下一个字符为大写字母,则换行符更有可能表示段落分隔。
|
||||
也可以根据之前的内容长度来判断段落是否已经足够长。
|
||||
"""
|
||||
prev_char, next_char = match.groups()
|
||||
|
||||
# 句子结束标志
|
||||
sentence_endings = ".!?"
|
||||
|
||||
# 设定一个最小段落长度阈值
|
||||
min_paragraph_length = 140
|
||||
|
||||
if prev_char in sentence_endings and next_char.isupper() and len(match.string[:match.start(1)]) > min_paragraph_length:
|
||||
return "\n\n"
|
||||
else:
|
||||
return " "
|
||||
|
||||
def normalize_text(text):
|
||||
"""
|
||||
通过把连字(ligatures)等文本特殊符号转换为其基本形式来对文本进行归一化处理。
|
||||
例如,将连字 "fi" 转换为 "f" 和 "i"。
|
||||
"""
|
||||
# 对文本进行归一化处理,分解连字
|
||||
normalized_text = unicodedata.normalize("NFKD", text)
|
||||
|
||||
# 替换其他特殊字符
|
||||
cleaned_text = re.sub(r'[^\x00-\x7F]+', '', normalized_text)
|
||||
|
||||
return cleaned_text
|
||||
|
||||
def clean_text(raw_text):
|
||||
"""
|
||||
对从 PDF 提取出的原始文本进行清洗和格式化处理。
|
||||
1. 对原始文本进行归一化处理。
|
||||
2. 替换跨行的连词
|
||||
3. 根据 heuristic 规则判断换行符是否是段落分隔,并相应地进行替换
|
||||
"""
|
||||
# 对文本进行归一化处理
|
||||
normalized_text = normalize_text(raw_text)
|
||||
|
||||
# 替换跨行的连词
|
||||
text = re.sub(r'(\w+-\n\w+)', lambda m: m.group(1).replace('-\n', ''), normalized_text)
|
||||
|
||||
# 根据前后相邻字符的特点,找到原文本中的换行符
|
||||
newlines = re.compile(r'(\S)\n(\S)')
|
||||
|
||||
# 根据 heuristic 规则,用空格或段落分隔符替换原换行符
|
||||
final_text = re.sub(newlines, lambda m: m.group(1) + is_paragraph_break(m) + m.group(2), text)
|
||||
|
||||
return final_text.strip()
|
||||
|
||||
def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, glob, os, fitz
|
||||
print('begin analysis on:', file_manifest)
|
||||
for index, fp in enumerate(file_manifest):
|
||||
with fitz.open(fp) as doc:
|
||||
file_content = ""
|
||||
for page in doc:
|
||||
file_content += page.get_text()
|
||||
file_content = clean_text(file_content)
|
||||
print(file_content)
|
||||
file_write_buffer = []
|
||||
for file_name in file_manifest:
|
||||
print('begin analysis on:', file_name)
|
||||
############################## <第 0 步,切割PDF> ##################################
|
||||
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
||||
# 的长度必须小于 2500 个 Token
|
||||
file_content, page_one = read_and_clean_pdf_text(file_name) # (尝试)按照章节切割PDF
|
||||
file_content = file_content.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
|
||||
page_one = str(page_one).encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
|
||||
|
||||
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
||||
|
||||
prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
|
||||
i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
|
||||
i_say_show_user = prefix + f'[{index}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
|
||||
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
from .crazy_utils import breakdown_txt_to_satisfy_token_limit_for_pdf
|
||||
from request_llm.bridge_all import model_info
|
||||
enc = model_info["gpt-3.5-turbo"]['tokenizer']
|
||||
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||
paper_fragments = breakdown_txt_to_satisfy_token_limit_for_pdf(
|
||||
txt=file_content, get_token_fn=get_token_num, limit=TOKEN_LIMIT_PER_FRAGMENT)
|
||||
page_one_fragments = breakdown_txt_to_satisfy_token_limit_for_pdf(
|
||||
txt=str(page_one), get_token_fn=get_token_num, limit=TOKEN_LIMIT_PER_FRAGMENT//4)
|
||||
# 为了更好的效果,我们剥离Introduction之后的部分(如果有)
|
||||
paper_meta = page_one_fragments[0].split('introduction')[0].split('Introduction')[0].split('INTRODUCTION')[0]
|
||||
|
||||
############################## <第 1 步,从摘要中提取高价值信息,放到history中> ##################################
|
||||
final_results = []
|
||||
final_results.append(paper_meta)
|
||||
|
||||
if not fast_debug:
|
||||
msg = '正常'
|
||||
# ** gpt request **
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say,
|
||||
inputs_show_user=i_say_show_user,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=[],
|
||||
sys_prompt="总结文章。"
|
||||
) # 带超时倒计时
|
||||
|
||||
############################## <第 2 步,迭代地历遍整个文章,提取精炼信息> ##################################
|
||||
i_say_show_user = f'首先你在中文语境下通读整篇论文。'; gpt_say = "[Local Message] 收到。" # 用户提示
|
||||
chatbot.append([i_say_show_user, gpt_say]); yield from update_ui(chatbot=chatbot, history=[]) # 更新UI
|
||||
|
||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
||||
history.append(i_say_show_user); history.append(gpt_say)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
if not fast_debug: time.sleep(2)
|
||||
iteration_results = []
|
||||
last_iteration_result = paper_meta # 初始值是摘要
|
||||
MAX_WORD_TOTAL = 4096 * 0.7
|
||||
n_fragment = len(paper_fragments)
|
||||
if n_fragment >= 20: print('文章极长,不能达到预期效果')
|
||||
for i in range(n_fragment):
|
||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} Chinese characters: {paper_fragments[i]}"
|
||||
i_say_show_user = f"[{i+1}/{n_fragment}] Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} Chinese characters: {paper_fragments[i][:200]}"
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(i_say, i_say_show_user, # i_say=真正给chatgpt的提问, i_say_show_user=给用户看的提问
|
||||
llm_kwargs, chatbot,
|
||||
history=["The main idea of the previous section is?", last_iteration_result], # 迭代上一次的结果
|
||||
sys_prompt="Extract the main idea of this section with Chinese." # 提示
|
||||
)
|
||||
iteration_results.append(gpt_say)
|
||||
last_iteration_result = gpt_say
|
||||
|
||||
all_file = ', '.join([os.path.relpath(fp, project_folder) for index, fp in enumerate(file_manifest)])
|
||||
i_say = f'根据以上你自己的分析,对全文进行概括,用学术性语言写一段中文摘要,然后再写一段英文摘要(包括{all_file})。'
|
||||
chatbot.append((i_say, "[Local Message] waiting gpt response."))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
if not fast_debug:
|
||||
msg = '正常'
|
||||
# ** gpt request **
|
||||
############################## <第 3 步,整理history,提取总结> ##################################
|
||||
final_results.extend(iteration_results)
|
||||
final_results.append(f'Please conclude this paper discussed above。')
|
||||
# This prompt is from https://github.com/kaixindelele/ChatPaper/blob/main/chat_paper.py
|
||||
NUM_OF_WORD = 1000
|
||||
i_say = """
|
||||
1. Mark the title of the paper (with Chinese translation)
|
||||
2. list all the authors' names (use English)
|
||||
3. mark the first author's affiliation (output Chinese translation only)
|
||||
4. mark the keywords of this article (use English)
|
||||
5. link to the paper, Github code link (if available, fill in Github:None if not)
|
||||
6. summarize according to the following four points.Be sure to use Chinese answers (proper nouns need to be marked in English)
|
||||
- (1):What is the research background of this article?
|
||||
- (2):What are the past methods? What are the problems with them? Is the approach well motivated?
|
||||
- (3):What is the research methodology proposed in this paper?
|
||||
- (4):On what task and what performance is achieved by the methods in this paper? Can the performance support their goals?
|
||||
Follow the format of the output that follows:
|
||||
1. Title: xxx\n\n
|
||||
2. Authors: xxx\n\n
|
||||
3. Affiliation: xxx\n\n
|
||||
4. Keywords: xxx\n\n
|
||||
5. Urls: xxx or xxx , xxx \n\n
|
||||
6. Summary: \n\n
|
||||
- (1):xxx;\n
|
||||
- (2):xxx;\n
|
||||
- (3):xxx;\n
|
||||
- (4):xxx.\n\n
|
||||
Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible,
|
||||
do not have too much repetitive information, numerical values using the original numbers.
|
||||
"""
|
||||
# This prompt is from https://github.com/kaixindelele/ChatPaper/blob/main/chat_paper.py
|
||||
file_write_buffer.extend(final_results)
|
||||
i_say, final_results = input_clipping(i_say, final_results, max_token_limit=2000)
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say,
|
||||
inputs_show_user=i_say,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
sys_prompt="总结文章。"
|
||||
) # 带超时倒计时
|
||||
inputs=i_say, inputs_show_user='开始最终总结',
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=final_results,
|
||||
sys_prompt= f"Extract the main idea of this paper with less than {NUM_OF_WORD} Chinese characters"
|
||||
)
|
||||
final_results.append(gpt_say)
|
||||
file_write_buffer.extend([i_say, gpt_say])
|
||||
############################## <第 4 步,设置一个token上限> ##################################
|
||||
_, final_results = input_clipping("", final_results, max_token_limit=3200)
|
||||
yield from update_ui(chatbot=chatbot, history=final_results) # 注意这里的历史记录被替代了
|
||||
|
||||
chatbot[-1] = (i_say, gpt_say)
|
||||
history.append(i_say); history.append(gpt_say)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
res = write_results_to_file(history)
|
||||
chatbot.append(("完成了吗?", res))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
res = write_results_to_file(file_write_buffer, file_name=gen_time_str())
|
||||
promote_file_to_downloadzone(res.split('\t')[-1], chatbot=chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=final_results) # 刷新界面
|
||||
|
||||
|
||||
@CatchException
|
||||
@@ -151,10 +137,7 @@ def 批量总结PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
return
|
||||
|
||||
# 搜索需要处理的文件清单
|
||||
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.pdf', recursive=True)] # + \
|
||||
# [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)] + \
|
||||
# [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
|
||||
# [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
|
||||
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.pdf', recursive=True)]
|
||||
|
||||
# 如果没找到任何文件
|
||||
if len(file_manifest) == 0:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from toolbox import CatchException, report_execption, write_results_to_file
|
||||
from toolbox import update_ui
|
||||
from toolbox import update_ui, promote_file_to_downloadzone
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from .crazy_utils import read_and_clean_pdf_text
|
||||
@@ -147,23 +147,14 @@ def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot,
|
||||
print('writing html result failed:', trimmed_format_exc())
|
||||
|
||||
# 准备文件的下载
|
||||
import shutil
|
||||
for pdf_path in generated_conclusion_files:
|
||||
# 重命名文件
|
||||
rename_file = f'./gpt_log/翻译-{os.path.basename(pdf_path)}'
|
||||
if os.path.exists(rename_file):
|
||||
os.remove(rename_file)
|
||||
shutil.copyfile(pdf_path, rename_file)
|
||||
if os.path.exists(pdf_path):
|
||||
os.remove(pdf_path)
|
||||
rename_file = f'翻译-{os.path.basename(pdf_path)}'
|
||||
promote_file_to_downloadzone(pdf_path, rename_file=rename_file, chatbot=chatbot)
|
||||
for html_path in generated_html_files:
|
||||
# 重命名文件
|
||||
rename_file = f'./gpt_log/翻译-{os.path.basename(html_path)}'
|
||||
if os.path.exists(rename_file):
|
||||
os.remove(rename_file)
|
||||
shutil.copyfile(html_path, rename_file)
|
||||
if os.path.exists(html_path):
|
||||
os.remove(html_path)
|
||||
rename_file = f'翻译-{os.path.basename(html_path)}'
|
||||
promote_file_to_downloadzone(html_path, rename_file=rename_file, chatbot=chatbot)
|
||||
chatbot.append(("给出输出文件清单", str(generated_conclusion_files + generated_html_files)))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
@@ -1,87 +1,70 @@
|
||||
from toolbox import CatchException, update_ui, gen_time_str
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from .crazy_utils import input_clipping
|
||||
import copy, json
|
||||
|
||||
|
||||
prompt = """
|
||||
I have to achieve some functionalities by calling one of the functions below.
|
||||
Your job is to find the correct funtion to use to satisfy my requirement,
|
||||
and then write python code to call this function with correct parameters.
|
||||
|
||||
These are functions you are allowed to choose from:
|
||||
1.
|
||||
功能描述: 总结音视频内容
|
||||
调用函数: ConcludeAudioContent(txt, llm_kwargs)
|
||||
参数说明:
|
||||
txt: 音频文件的路径
|
||||
llm_kwargs: 模型参数, 永远给定None
|
||||
2.
|
||||
功能描述: 将每次对话记录写入Markdown格式的文件中
|
||||
调用函数: WriteMarkdown()
|
||||
3.
|
||||
功能描述: 将指定目录下的PDF文件从英文翻译成中文
|
||||
调用函数: BatchTranslatePDFDocuments_MultiThreaded(txt, llm_kwargs)
|
||||
参数说明:
|
||||
txt: PDF文件所在的路径
|
||||
llm_kwargs: 模型参数, 永远给定None
|
||||
4.
|
||||
功能描述: 根据文本使用GPT模型生成相应的图像
|
||||
调用函数: ImageGeneration(txt, llm_kwargs)
|
||||
参数说明:
|
||||
txt: 图像生成所用到的提示文本
|
||||
llm_kwargs: 模型参数, 永远给定None
|
||||
5.
|
||||
功能描述: 对输入的word文档进行摘要生成
|
||||
调用函数: SummarizingWordDocuments(input_path, output_path)
|
||||
参数说明:
|
||||
input_path: 待处理的word文档路径
|
||||
output_path: 摘要生成后的文档路径
|
||||
|
||||
|
||||
You should always anwser with following format:
|
||||
----------------
|
||||
Code:
|
||||
```
|
||||
class AutoAcademic(object):
|
||||
def __init__(self):
|
||||
self.selected_function = "FILL_CORRECT_FUNCTION_HERE" # e.g., "GenerateImage"
|
||||
self.txt = "FILL_MAIN_PARAMETER_HERE" # e.g., "荷叶上的蜻蜓"
|
||||
self.llm_kwargs = None
|
||||
```
|
||||
Explanation:
|
||||
只有GenerateImage和生成图像相关, 因此选择GenerateImage函数。
|
||||
----------------
|
||||
|
||||
Now, this is my requirement:
|
||||
|
||||
"""
|
||||
def get_fn_lib():
|
||||
return {
|
||||
"BatchTranslatePDFDocuments_MultiThreaded": ("crazy_functions.批量翻译PDF文档_多线程", "批量翻译PDF文档"),
|
||||
"SummarizingWordDocuments": ("crazy_functions.总结word文档", "总结word文档"),
|
||||
"ImageGeneration": ("crazy_functions.图片生成", "图片生成"),
|
||||
"TranslateMarkdownFromEnglishToChinese": ("crazy_functions.批量Markdown翻译", "Markdown中译英"),
|
||||
"SummaryAudioVideo": ("crazy_functions.总结音视频", "总结音视频"),
|
||||
"BatchTranslatePDFDocuments_MultiThreaded": {
|
||||
"module": "crazy_functions.批量翻译PDF文档_多线程",
|
||||
"function": "批量翻译PDF文档",
|
||||
"description": "Translate PDF Documents",
|
||||
"arg_1_description": "A path containing pdf files.",
|
||||
},
|
||||
"SummarizingWordDocuments": {
|
||||
"module": "crazy_functions.总结word文档",
|
||||
"function": "总结word文档",
|
||||
"description": "Summarize Word Documents",
|
||||
"arg_1_description": "A path containing Word files.",
|
||||
},
|
||||
"ImageGeneration": {
|
||||
"module": "crazy_functions.图片生成",
|
||||
"function": "图片生成",
|
||||
"description": "Generate a image that satisfies some description.",
|
||||
"arg_1_description": "Descriptions about the image to be generated.",
|
||||
},
|
||||
"TranslateMarkdownFromEnglishToChinese": {
|
||||
"module": "crazy_functions.批量Markdown翻译",
|
||||
"function": "Markdown中译英",
|
||||
"description": "Translate Markdown Documents from English to Chinese.",
|
||||
"arg_1_description": "A path containing Markdown files.",
|
||||
},
|
||||
"SummaryAudioVideo": {
|
||||
"module": "crazy_functions.总结音视频",
|
||||
"function": "总结音视频",
|
||||
"description": "Get text from a piece of audio and summarize this audio.",
|
||||
"arg_1_description": "A path containing audio files.",
|
||||
},
|
||||
}
|
||||
|
||||
functions = [
|
||||
{
|
||||
"name": k,
|
||||
"description": v['description'],
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"plugin_arg_1": {
|
||||
"type": "string",
|
||||
"description": v['arg_1_description'],
|
||||
},
|
||||
},
|
||||
"required": ["plugin_arg_1"],
|
||||
},
|
||||
} for k, v in get_fn_lib().items()
|
||||
]
|
||||
|
||||
def inspect_dependency(chatbot, history):
|
||||
return True
|
||||
|
||||
def eval_code(code, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
import subprocess, sys, os, shutil, importlib
|
||||
|
||||
with open('gpt_log/void_terminal_runtime.py', 'w', encoding='utf8') as f:
|
||||
f.write(code)
|
||||
|
||||
import importlib
|
||||
try:
|
||||
AutoAcademic = getattr(importlib.import_module('gpt_log.void_terminal_runtime', 'AutoAcademic'), 'AutoAcademic')
|
||||
# importlib.reload(AutoAcademic)
|
||||
auto_dict = AutoAcademic()
|
||||
selected_function = auto_dict.selected_function
|
||||
txt = auto_dict.txt
|
||||
fp, fn = get_fn_lib()[selected_function]
|
||||
tmp = get_fn_lib()[code['name']]
|
||||
fp, fn = tmp['module'], tmp['function']
|
||||
fn_plugin = getattr(importlib.import_module(fp, fn), fn)
|
||||
yield from fn_plugin(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port)
|
||||
arg = json.loads(code['arguments'])['plugin_arg_1']
|
||||
yield from fn_plugin(arg, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port)
|
||||
except:
|
||||
from toolbox import trimmed_format_exc
|
||||
chatbot.append(["执行错误", f"\n```\n{trimmed_format_exc()}\n```\n"])
|
||||
@@ -110,22 +93,27 @@ def 终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_
|
||||
history = []
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append(["函数插件功能?", "根据自然语言执行插件命令, 作者: binary-husky, 插件初始化中 ..."])
|
||||
chatbot.append(["虚空终端插件的功能?", "根据自然语言的描述, 执行任意插件的命令."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# # 尝试导入依赖, 如果缺少依赖, 则给出安装建议
|
||||
# dep_ok = yield from inspect_dependency(chatbot=chatbot, history=history) # 刷新界面
|
||||
# if not dep_ok: return
|
||||
|
||||
# 输入
|
||||
i_say = prompt + txt
|
||||
i_say = txt
|
||||
# 开始
|
||||
llm_kwargs_function_call = copy.deepcopy(llm_kwargs)
|
||||
llm_kwargs_function_call['llm_model'] = 'gpt-call-fn' # 修改调用函数
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say, inputs_show_user=txt,
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
|
||||
sys_prompt=""
|
||||
llm_kwargs=llm_kwargs_function_call, chatbot=chatbot, history=[],
|
||||
sys_prompt=functions
|
||||
)
|
||||
|
||||
# 将代码转为动画
|
||||
code = get_code_block(gpt_say)
|
||||
yield from eval_code(code, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port)
|
||||
res = json.loads(gpt_say)['choices'][0]
|
||||
if res['finish_reason'] == 'function_call':
|
||||
code = json.loads(gpt_say)['choices'][0]
|
||||
yield from eval_code(code['message']['function_call'], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port)
|
||||
else:
|
||||
chatbot.append(["无法调用相关功能", res])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ def 同时问询(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数,用于灵活调整复杂功能的各种参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
@@ -35,19 +35,21 @@ def 同时问询_指定模型(txt, llm_kwargs, plugin_kwargs, chatbot, history,
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数,用于灵活调整复杂功能的各种参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
web_port 当前软件运行的端口号
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append((txt, "正在同时咨询ChatGPT和ChatGLM……"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||
|
||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||
# llm_kwargs['llm_model'] = 'chatglm&gpt-3.5-turbo&api2d-gpt-3.5-turbo' # 支持任意数量的llm接口,用&符号分隔
|
||||
llm_kwargs['llm_model'] = plugin_kwargs.get("advanced_arg", 'chatglm&gpt-3.5-turbo') # 'chatglm&gpt-3.5-turbo' # 支持任意数量的llm接口,用&符号分隔
|
||||
|
||||
chatbot.append((txt, f"正在同时咨询{llm_kwargs['llm_model']}"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=txt, inputs_show_user=txt,
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
||||
|
||||
195
crazy_functions/语音助手.py
普通文件
195
crazy_functions/语音助手.py
普通文件
@@ -0,0 +1,195 @@
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, get_conf, markdown_convertion
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from request_llm.bridge_all import predict_no_ui_long_connection
|
||||
import threading, time
|
||||
import numpy as np
|
||||
from .live_audio.aliyunASR import AliyunASR
|
||||
import json
|
||||
|
||||
class WatchDog():
|
||||
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
|
||||
self.last_feed = None
|
||||
self.timeout = timeout
|
||||
self.bark_fn = bark_fn
|
||||
self.interval = interval
|
||||
self.msg = msg
|
||||
self.kill_dog = False
|
||||
|
||||
def watch(self):
|
||||
while True:
|
||||
if self.kill_dog: break
|
||||
if time.time() - self.last_feed > self.timeout:
|
||||
if len(self.msg) > 0: print(self.msg)
|
||||
self.bark_fn()
|
||||
break
|
||||
time.sleep(self.interval)
|
||||
|
||||
def begin_watch(self):
|
||||
self.last_feed = time.time()
|
||||
th = threading.Thread(target=self.watch)
|
||||
th.daemon = True
|
||||
th.start()
|
||||
|
||||
def feed(self):
|
||||
self.last_feed = time.time()
|
||||
|
||||
def chatbot2history(chatbot):
|
||||
history = []
|
||||
for c in chatbot:
|
||||
for q in c:
|
||||
if q not in ["[请讲话]", "[等待GPT响应]", "[正在等您说完问题]"]:
|
||||
history.append(q.strip('<div class="markdown-body">').strip('</div>').strip('<p>').strip('</p>'))
|
||||
return history
|
||||
|
||||
class AsyncGptTask():
|
||||
def __init__(self) -> None:
|
||||
self.observe_future = []
|
||||
self.observe_future_chatbot_index = []
|
||||
|
||||
def gpt_thread_worker(self, i_say, llm_kwargs, history, sys_prompt, observe_window, index):
|
||||
try:
|
||||
MAX_TOKEN_ALLO = 2560
|
||||
i_say, history = input_clipping(i_say, history, max_token_limit=MAX_TOKEN_ALLO)
|
||||
gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt,
|
||||
observe_window=observe_window[index], console_slience=True)
|
||||
except ConnectionAbortedError as token_exceed_err:
|
||||
print('至少一个线程任务Token溢出而失败', e)
|
||||
except Exception as e:
|
||||
print('至少一个线程任务意外失败', e)
|
||||
|
||||
def add_async_gpt_task(self, i_say, chatbot_index, llm_kwargs, history, system_prompt):
|
||||
self.observe_future.append([""])
|
||||
self.observe_future_chatbot_index.append(chatbot_index)
|
||||
cur_index = len(self.observe_future)-1
|
||||
th_new = threading.Thread(target=self.gpt_thread_worker, args=(i_say, llm_kwargs, history, system_prompt, self.observe_future, cur_index))
|
||||
th_new.daemon = True
|
||||
th_new.start()
|
||||
|
||||
def update_chatbot(self, chatbot):
|
||||
for of, ofci in zip(self.observe_future, self.observe_future_chatbot_index):
|
||||
try:
|
||||
chatbot[ofci] = list(chatbot[ofci])
|
||||
chatbot[ofci][1] = markdown_convertion(of[0])
|
||||
except:
|
||||
self.observe_future = []
|
||||
self.observe_future_chatbot_index = []
|
||||
return chatbot
|
||||
|
||||
class InterviewAssistant(AliyunASR):
|
||||
def __init__(self):
|
||||
self.capture_interval = 0.5 # second
|
||||
self.stop = False
|
||||
self.parsed_text = ""
|
||||
self.parsed_sentence = ""
|
||||
self.buffered_sentence = ""
|
||||
self.event_on_result_chg = threading.Event()
|
||||
self.event_on_entence_end = threading.Event()
|
||||
self.event_on_commit_question = threading.Event()
|
||||
|
||||
def __del__(self):
|
||||
self.stop = True
|
||||
self.stop_msg = ""
|
||||
self.commit_wd.kill_dog = True
|
||||
self.plugin_wd.kill_dog = True
|
||||
|
||||
def init(self, chatbot):
|
||||
# 初始化音频采集线程
|
||||
self.captured_audio = np.array([])
|
||||
self.keep_latest_n_second = 10
|
||||
self.commit_after_pause_n_second = 2.0
|
||||
self.ready_audio_flagment = None
|
||||
self.stop = False
|
||||
self.plugin_wd = WatchDog(timeout=5, bark_fn=self.__del__, msg="程序终止")
|
||||
self.aut = threading.Thread(target=self.audio_convertion_thread, args=(chatbot._cookies['uuid'],))
|
||||
self.aut.daemon = True
|
||||
self.aut.start()
|
||||
# th2 = threading.Thread(target=self.audio2txt_thread, args=(chatbot._cookies['uuid'],))
|
||||
# th2.daemon = True
|
||||
# th2.start()
|
||||
|
||||
def no_audio_for_a_while(self):
|
||||
if len(self.buffered_sentence) < 7: # 如果一句话小于7个字,暂不提交
|
||||
self.commit_wd.begin_watch()
|
||||
else:
|
||||
self.event_on_commit_question.set()
|
||||
|
||||
def begin(self, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
# main plugin function
|
||||
self.init(chatbot)
|
||||
chatbot.append(["[请讲话]", "[正在等您说完问题]"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.plugin_wd.begin_watch()
|
||||
self.agt = AsyncGptTask()
|
||||
self.commit_wd = WatchDog(timeout=self.commit_after_pause_n_second, bark_fn=self.no_audio_for_a_while, interval=0.2)
|
||||
self.commit_wd.begin_watch()
|
||||
|
||||
while not self.stop:
|
||||
self.event_on_result_chg.wait(timeout=0.25) # run once every 0.25 second
|
||||
chatbot = self.agt.update_chatbot(chatbot) # 将子线程的gpt结果写入chatbot
|
||||
history = chatbot2history(chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.plugin_wd.feed()
|
||||
|
||||
if self.event_on_result_chg.is_set():
|
||||
# update audio decode result
|
||||
self.event_on_result_chg.clear()
|
||||
chatbot[-1] = list(chatbot[-1])
|
||||
chatbot[-1][0] = self.buffered_sentence + self.parsed_text
|
||||
history = chatbot2history(chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
self.commit_wd.feed()
|
||||
|
||||
if self.event_on_entence_end.is_set():
|
||||
# called when a sentence has ended
|
||||
self.event_on_entence_end.clear()
|
||||
self.parsed_text = self.parsed_sentence
|
||||
self.buffered_sentence += self.parsed_sentence
|
||||
|
||||
if self.event_on_commit_question.is_set():
|
||||
# called when a question should be commited
|
||||
self.event_on_commit_question.clear()
|
||||
if len(self.buffered_sentence) == 0: raise RuntimeError
|
||||
|
||||
self.commit_wd.begin_watch()
|
||||
chatbot[-1] = list(chatbot[-1])
|
||||
chatbot[-1] = [self.buffered_sentence, "[等待GPT响应]"]
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
# add gpt task 创建子线程请求gpt,避免线程阻塞
|
||||
history = chatbot2history(chatbot)
|
||||
self.agt.add_async_gpt_task(self.buffered_sentence, len(chatbot)-1, llm_kwargs, history, system_prompt)
|
||||
|
||||
self.buffered_sentence = ""
|
||||
chatbot.append(["[请讲话]", "[正在等您说完问题]"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
if len(self.stop_msg) != 0:
|
||||
raise RuntimeError(self.stop_msg)
|
||||
|
||||
|
||||
|
||||
@CatchException
|
||||
def 语音助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
# pip install -U openai-whisper
|
||||
chatbot.append(["对话助手函数插件:使用时,双手离开鼠标键盘吧", "音频助手, 正在听您讲话(点击“停止”键可终止程序)..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
try:
|
||||
import nls
|
||||
from scipy import io
|
||||
except:
|
||||
chatbot.append(["导入依赖失败", "使用该模块需要额外依赖, 安装方法:```pip install --upgrade aliyun-python-sdk-core==2.13.3 pyOpenSSL scipy git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git```"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
APPKEY = get_conf('ALIYUN_APPKEY')
|
||||
if APPKEY == "":
|
||||
chatbot.append(["导入依赖失败", "没有阿里云语音识别APPKEY和TOKEN, 详情见https://help.aliyun.com/document_detail/450255.html"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
ia = InterviewAssistant()
|
||||
yield from ia.begin(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
|
||||
@@ -104,7 +104,7 @@ def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
meta_paper_info_list = meta_paper_info_list[batchsize:]
|
||||
|
||||
chatbot.append(["状态?",
|
||||
"已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write an academic \"Related Works\" section about \"你搜索的研究领域\" for me."])
|
||||
"已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
|
||||
msg = '正常'
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||
res = write_results_to_file(history)
|
||||
|
||||
28
crazy_functions/辅助回答.py
普通文件
28
crazy_functions/辅助回答.py
普通文件
@@ -0,0 +1,28 @@
|
||||
# encoding: utf-8
|
||||
# @Time : 2023/4/19
|
||||
# @Author : Spike
|
||||
# @Descr :
|
||||
from toolbox import update_ui
|
||||
from toolbox import CatchException, report_execption, write_results_to_file
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
|
||||
@CatchException
|
||||
def 猜你想问(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
if txt:
|
||||
show_say = txt
|
||||
prompt = txt+'\n回答完问题后,再列出用户可能提出的三个问题。'
|
||||
else:
|
||||
prompt = history[-1]+"\n分析上述回答,再列出用户可能提出的三个问题。"
|
||||
show_say = '分析上述回答,再列出用户可能提出的三个问题。'
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=show_say,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
sys_prompt=system_prompt
|
||||
)
|
||||
chatbot[-1] = (show_say, gpt_say)
|
||||
history.extend([show_say, gpt_say])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
@@ -1,13 +1,12 @@
|
||||
from toolbox import CatchException, update_ui
|
||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
import datetime, re
|
||||
|
||||
import datetime
|
||||
@CatchException
|
||||
def 高阶功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
plugin_kwargs 插件模型的参数,暂时没有用武之地
|
||||
plugin_kwargs 插件模型的参数,用于灵活调整复杂功能的各种参数
|
||||
chatbot 聊天显示框的句柄,用于显示给用户
|
||||
history 聊天历史,前情提要
|
||||
system_prompt 给gpt的静默提醒
|
||||
@@ -19,34 +18,12 @@ def 高阶功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
||||
for i in range(5):
|
||||
currentMonth = (datetime.date.today() + datetime.timedelta(days=i)).month
|
||||
currentDay = (datetime.date.today() + datetime.timedelta(days=i)).day
|
||||
i_say = f'历史中哪些事件发生在{currentMonth}月{currentDay}日?用中文列举两条,然后分别给出描述事件的两个英文单词。' + '当你给出关键词时,使用以下json格式:{"KeyWords":[EnglishKeyWord1,EnglishKeyWord2]}。'
|
||||
i_say = f'历史中哪些事件发生在{currentMonth}月{currentDay}日?列举两条并发送相关图片。发送图片时,请使用Markdown,将Unsplash API中的PUT_YOUR_QUERY_HERE替换成描述该事件的一个最重要的单词。'
|
||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=i_say, inputs_show_user=i_say,
|
||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
|
||||
sys_prompt='输出格式示例:1908年,美国消防救援事业发展的“美国消防协会”成立。关键词:{"KeyWords":["Fire","American"]}。'
|
||||
sys_prompt="当你想发送一张照片时,请使用Markdown, 并且不要有反斜线, 不要用代码块。使用 Unsplash API (https://source.unsplash.com/1280x720/? < PUT_YOUR_QUERY_HERE >)。"
|
||||
)
|
||||
gpt_say = get_images(gpt_say)
|
||||
chatbot[-1] = (i_say, gpt_say)
|
||||
history.append(i_say);history.append(gpt_say)
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
||||
|
||||
|
||||
def get_images(gpt_say):
|
||||
def get_image_by_keyword(keyword):
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
response = requests.get(f'https://wallhaven.cc/search?q={keyword}', timeout=2)
|
||||
for image_element in BeautifulSoup(response.content, 'html.parser').findAll("img"):
|
||||
if "data-src" in image_element: break
|
||||
return image_element["data-src"]
|
||||
|
||||
for keywords in re.findall('{"KeyWords":\[(.*?)\]}', gpt_say):
|
||||
keywords = [n.strip('"') for n in keywords.split(',')]
|
||||
try:
|
||||
description = keywords[0]
|
||||
url = get_image_by_keyword(keywords[0])
|
||||
img_tag = f"\n\n"
|
||||
gpt_say += img_tag
|
||||
except:
|
||||
continue
|
||||
return gpt_say
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
||||
在新工单中引用
屏蔽一个用户