镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 06:26:47 +00:00
更多模型切换
这个提交包含在:
179
toolbox.py
179
toolbox.py
@@ -1,13 +1,10 @@
|
||||
import markdown
|
||||
import mdtex2html
|
||||
import threading
|
||||
import importlib
|
||||
import traceback
|
||||
import inspect
|
||||
import re
|
||||
from latex2mathml.converter import convert as tex2mathml
|
||||
from functools import wraps, lru_cache
|
||||
|
||||
############################### 插件输入输出接驳区 #######################################
|
||||
class ChatBotWithCookies(list):
|
||||
def __init__(self, cookie):
|
||||
@@ -25,9 +22,10 @@ class ChatBotWithCookies(list):
|
||||
|
||||
def ArgsGeneralWrapper(f):
|
||||
"""
|
||||
装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
|
||||
装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
|
||||
"""
|
||||
def decorated(cookies, max_length, llm_model, txt, txt2, top_p, temperature, chatbot, history, system_prompt, *args):
|
||||
from request_llm.bridge_all import model_info
|
||||
txt_passon = txt
|
||||
if txt == "" and txt2 != "": txt_passon = txt2
|
||||
# 引入一个有cookie的chatbot
|
||||
@@ -38,6 +36,7 @@ def ArgsGeneralWrapper(f):
|
||||
llm_kwargs = {
|
||||
'api_key': cookies['api_key'],
|
||||
'llm_model': llm_model,
|
||||
'endpoint': model_info[llm_model]['endpoint'],
|
||||
'top_p':top_p,
|
||||
'max_length': max_length,
|
||||
'temperature':temperature,
|
||||
@@ -56,69 +55,10 @@ def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
|
||||
"""
|
||||
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
|
||||
yield chatbot.get_cookies(), chatbot, history, msg
|
||||
############################### ################## #######################################
|
||||
##########################################################################################
|
||||
|
||||
def get_reduce_token_percent(text):
    """
    Estimate how far a prompt must shrink after a "context length exceeded" error.

    * This function is slated for deprecation.

    The first two "<number> tokens" occurrences in `text` are read, in order,
    as the model's token limit and the token count of the rejected request.

    Args:
        text: The error message, e.g. "maximum context length is 4097 tokens.
            However, your messages resulted in 4870 tokens".

    Returns:
        A tuple `(ratio, overflow)`. `ratio` is a float strictly between 0 and 1:
        the limit (minus a safety margin) divided by the current token count.
        `overflow` is the number of tokens above that reduced limit, as a string.
        When the message cannot be parsed or the ratio falls outside (0, 1),
        the fallback `(0.5, '不详')` is returned.
    """
    EXCEED_ALLO = 500  # leave some headroom, otherwise the reply runs out of room
    fallback = (0.5, '不详')
    try:
        match = re.findall(r"(\d+)\s+tokens\b", text)
        max_limit = float(match[0]) - EXCEED_ALLO
        current_tokens = float(match[1])
        ratio = max_limit / current_tokens
    except (IndexError, TypeError, ValueError, ZeroDivisionError):
        # Fewer than two numbers found, non-string input, or a zero token count.
        return fallback
    # Explicit range check rather than `assert`, which is removed under `python -O`.
    if not 0 < ratio < 1:
        return fallback
    return ratio, str(int(current_tokens - max_limit))
|
||||
|
||||
|
||||
|
||||
def write_results_to_file(history, file_name=None):
    """
    Write the conversation `history` to a Markdown file under ./gpt_log/.

    Entries at even positions are written as level-2 headings ("## ..."),
    entries at odd positions as plain paragraphs. Non-string entries are
    converted with `str()`; an entry whose conversion fails is skipped.

    Args:
        history: Iterable of conversation entries.
        file_name: Name of the file inside ./gpt_log/. When None, a name is
            generated from the current local time.

    Returns:
        A message string containing the absolute path of the written file.
        The same message is also printed.
    """
    import os
    import time
    if file_name is None:
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        file_name = 'chatGPT分析报告' + timestamp + '.md'
    os.makedirs('./gpt_log/', exist_ok=True)
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
        f.write('# chatGPT 分析报告\n')
        for i, content in enumerate(history):
            # Best-effort conversion: the trigger of the original bug is unknown,
            # so an entry that cannot be stringified is skipped, not fatal.
            if not isinstance(content, str):
                try:
                    content = str(content)
                except Exception:
                    continue
            if i % 2 == 0:
                f.write('## ')
            f.write(content)
            f.write('\n\n')
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
    print(res)
    return res
|
||||
|
||||
|
||||
def regular_txt_to_markdown(text):
    """
    Convert plain text into Markdown-style paragraphs.

    Every newline is doubled so each line becomes its own paragraph, then two
    passes collapse triple newlines back down to double newlines.
    """
    doubled = text.replace('\n', '\n\n')
    # Exactly two collapsing passes, same as the original sequence of replaces.
    for _ in range(2):
        doubled = doubled.replace('\n\n\n', '\n\n')
    return doubled
|
||||
|
||||
|
||||
def CatchException(f):
|
||||
"""
|
||||
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
|
||||
装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
|
||||
"""
|
||||
@wraps(f)
|
||||
def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
|
||||
@@ -155,9 +95,70 @@ def HotReload(f):
|
||||
return decorated
|
||||
|
||||
|
||||
####################################### 其他小工具 #####################################
|
||||
|
||||
def get_reduce_token_percent(text):
    """
    Estimate how far a prompt must shrink after a "context length exceeded" error.

    * This function is slated for deprecation.

    The first two "<number> tokens" occurrences in `text` are read, in order,
    as the model's token limit and the token count of the rejected request.

    Args:
        text: The error message, e.g. "maximum context length is 4097 tokens.
            However, your messages resulted in 4870 tokens".

    Returns:
        A tuple `(ratio, overflow)`. `ratio` is a float strictly between 0 and 1:
        the limit (minus a safety margin) divided by the current token count.
        `overflow` is the number of tokens above that reduced limit, as a string.
        When the message cannot be parsed or the ratio falls outside (0, 1),
        the fallback `(0.5, '不详')` is returned.
    """
    EXCEED_ALLO = 500  # leave some headroom, otherwise the reply runs out of room
    fallback = (0.5, '不详')
    try:
        match = re.findall(r"(\d+)\s+tokens\b", text)
        max_limit = float(match[0]) - EXCEED_ALLO
        current_tokens = float(match[1])
        ratio = max_limit / current_tokens
    except (IndexError, TypeError, ValueError, ZeroDivisionError):
        # Fewer than two numbers found, non-string input, or a zero token count.
        return fallback
    # Explicit range check rather than `assert`, which is removed under `python -O`.
    if not 0 < ratio < 1:
        return fallback
    return ratio, str(int(current_tokens - max_limit))
|
||||
|
||||
|
||||
|
||||
def write_results_to_file(history, file_name=None):
    """
    Write the conversation `history` to a Markdown file under ./gpt_log/.

    Entries at even positions are written as level-2 headings ("## ..."),
    entries at odd positions as plain paragraphs. Non-string entries are
    converted with `str()`; an entry whose conversion fails is skipped.

    Args:
        history: Iterable of conversation entries.
        file_name: Name of the file inside ./gpt_log/. When None, a name is
            generated from the current local time.

    Returns:
        A message string containing the absolute path of the written file.
        The same message is also printed.
    """
    import os
    import time
    if file_name is None:
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        file_name = 'chatGPT分析报告' + timestamp + '.md'
    os.makedirs('./gpt_log/', exist_ok=True)
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
        f.write('# chatGPT 分析报告\n')
        for i, content in enumerate(history):
            # Best-effort conversion: the trigger of the original bug is unknown,
            # so an entry that cannot be stringified is skipped, not fatal.
            if not isinstance(content, str):
                try:
                    content = str(content)
                except Exception:
                    continue
            if i % 2 == 0:
                f.write('## ')
            f.write(content)
            f.write('\n\n')
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
    print(res)
    return res
|
||||
|
||||
|
||||
def regular_txt_to_markdown(text):
    """
    Convert plain text into Markdown-style paragraphs.

    Every newline is doubled so each line becomes its own paragraph, then two
    passes collapse triple newlines back down to double newlines.
    """
    doubled = text.replace('\n', '\n\n')
    # Exactly two collapsing passes, same as the original sequence of replaces.
    for _ in range(2):
        doubled = doubled.replace('\n\n\n', '\n\n')
    return doubled
|
||||
|
||||
|
||||
|
||||
|
||||
def report_execption(chatbot, history, a, b):
|
||||
"""
|
||||
向chatbot中添加错误信息
|
||||
向chatbot中添加错误信息
|
||||
"""
|
||||
chatbot.append((a, b))
|
||||
history.append(a)
|
||||
@@ -166,7 +167,7 @@ def report_execption(chatbot, history, a, b):
|
||||
|
||||
def text_divide_paragraph(text):
|
||||
"""
|
||||
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
|
||||
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
|
||||
"""
|
||||
if '```' in text:
|
||||
# careful input
|
||||
@@ -182,7 +183,7 @@ def text_divide_paragraph(text):
|
||||
|
||||
def markdown_convertion(txt):
|
||||
"""
|
||||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||||
"""
|
||||
pre = '<div class="markdown-body">'
|
||||
suf = '</div>'
|
||||
@@ -274,7 +275,7 @@ def close_up_code_segment_during_stream(gpt_reply):
|
||||
|
||||
def format_io(self, y):
|
||||
"""
|
||||
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
|
||||
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
|
||||
"""
|
||||
if y is None or y == []:
|
||||
return []
|
||||
@@ -290,7 +291,7 @@ def format_io(self, y):
|
||||
|
||||
def find_free_port():
|
||||
"""
|
||||
返回当前系统中可用的未使用端口。
|
||||
返回当前系统中可用的未使用端口。
|
||||
"""
|
||||
import socket
|
||||
from contextlib import closing
|
||||
@@ -410,9 +411,43 @@ def on_report_generated(files, chatbot):
|
||||
return report_files, chatbot
|
||||
|
||||
def is_openai_api_key(key):
    """
    Tell whether `key` is shaped like an OpenAI API key.

    A well-formed key is "sk-" followed by 48 upper/lower-case ASCII letters
    or digits.

    Args:
        key: Candidate key string.

    Returns:
        bool: True when `key` matches the pattern, False otherwise.
    """
    # Single return: collapse the match object to a plain bool so callers get
    # True/False, not a re.Match or None.
    return bool(re.match(r"sk-[a-zA-Z0-9]{48}$", key))
|
||||
|
||||
def is_api2d_key(key):
    """
    Tell whether `key` is shaped like an API2D key.

    The check is purely structural: the key must start with "fk" and be
    exactly 41 characters long.

    Args:
        key: Candidate key string.

    Returns:
        bool: True when both conditions hold, False otherwise.
    """
    # The condition is already a bool; no if/else around literal True/False needed.
    return key.startswith('fk') and len(key) == 41
|
||||
|
||||
def is_any_api_key(key):
    """
    Tell whether `key` holds at least one recognised API key.

    `key` may be a single key or several keys joined by commas. A single key is
    accepted when either `is_openai_api_key` or `is_api2d_key` accepts it; a
    comma-separated list is accepted when any of its parts is.
    """
    # Guard clause: a lone key is checked directly against both validators.
    if ',' not in key:
        return is_openai_api_key(key) or is_api2d_key(key)
    # Each part is comma-free, so the recursion goes exactly one level deep.
    return any(is_any_api_key(part) for part in key.split(','))
|
||||
|
||||
|
||||
def select_api_key(keys, llm_model):
    """
    Pick one API key from `keys` that is usable with `llm_model`.

    Args:
        keys: One key, or several keys joined by commas.
        llm_model: Model name. Names starting with "gpt-" use keys accepted by
            `is_openai_api_key`; names starting with "api2d-" use keys accepted
            by `is_api2d_key`.

    Returns:
        One matching key, chosen at random.

    Raises:
        RuntimeError: If no key in `keys` is usable with `llm_model`.
    """
    import random

    candidates = keys.split(',')
    # Model-name prefix -> validator for the keys that backend accepts.
    validators = (
        ('gpt-', is_openai_api_key),
        ('api2d-', is_api2d_key),
    )

    usable = []
    for prefix, accepts in validators:
        if llm_model.startswith(prefix):
            usable.extend(k for k in candidates if accepts(k))

    if not usable:
        raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。")

    # Random choice spreads load across the supplied keys.
    return random.choice(usable)
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def read_single_conf_with_lru_cache(arg):
|
||||
@@ -423,7 +458,7 @@ def read_single_conf_with_lru_cache(arg):
|
||||
r = getattr(importlib.import_module('config'), arg)
|
||||
# 在读取API_KEY时,检查一下是不是忘了改config
|
||||
if arg == 'API_KEY':
|
||||
if is_openai_api_key(r):
|
||||
if is_any_api_key(r):
|
||||
print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
|
||||
else:
|
||||
print亮红( "[API_KEY] 正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
|
||||
|
||||
在新工单中引用
屏蔽一个用户