Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 22:46:48 +00:00

Comparing commits: master-4.0 ... frontier (16 commits)
| Author | SHA1 | Commit Date |
|---|---|---|
| | 171e8a2744 | |
| | 3ed1b0320e | |
| | c6412a8d73 | |
| | c598e20f0e | |
| | 7af6994f7b | |
| | aab62aea39 | |
| | 31e3ffd997 | |
| | 1acd2bf292 | |
| | 5e0f327237 | |
| | 6a6eba5f16 | |
| | 722a055879 | |
| | 8254930495 | |
| | ca1ab57f5d | |
| | e20177cb7d | |
| | 6bd410582b | |
| | 4fe638ffa8 | |
config.py (11 changed lines)

```diff
@@ -354,6 +354,17 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
+# Used by the internet-search component to tidy search results into clean Markdown
+JINA_API_KEY = ""
+
+# Whether to clip the context length automatically (off by default)
+AUTO_CONTEXT_CLIP_ENABLE = False
+# Target token length for context clipping (clipping starts once this length is exceeded)
+AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN = 30*1000
+# Unconditionally drop rounds beyond this count
+AUTO_CONTEXT_MAX_ROUND = 64
+# When clipping, the largest share of AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN that the x-th most recent message may keep
+AUTO_CONTEXT_MAX_CLIP_RATIO = [0.80, 0.60, 0.45, 0.25, 0.20, 0.18, 0.16, 0.14, 0.12, 0.10, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01]


 """
 --------------- Notes on how configuration options relate ---------------
```
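Taken together, these options drive the new proactive context clipping. A minimal sketch of how they cooperate (a hypothetical illustration, not code from this commit; `should_clip` and its argument are stand-ins):

```python
# Hypothetical illustration of how the new config options interact.
AUTO_CONTEXT_CLIP_ENABLE = True                 # master switch
AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN = 30*1000   # clip once context exceeds ~30k tokens
AUTO_CONTEXT_MAX_ROUND = 64                     # hard cap on the number of rounds kept

def should_clip(context_token_counts: list) -> bool:
    # Clipping triggers only when the switch is on and the running total
    # of all context entries exceeds the token threshold.
    return AUTO_CONTEXT_CLIP_ENABLE and \
        sum(context_token_counts) > AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN
```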
```diff
@@ -172,7 +172,7 @@ def 载入对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
         user_request: request info of the current user (IP address, etc.)
     """
     from crazy_functions.crazy_utils import get_files_from_everything
-    success, file_manifest, _ = get_files_from_everything(txt, type='.html')
+    success, file_manifest, _ = get_files_from_everything(txt, type='.html', chatbot=chatbot)

     if not success:
         if txt == "": txt = '空空如也的输入栏'
```
```diff
@@ -1,3 +1,4 @@
+from shared_utils.fastapi_server import validate_path_safety
 from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
 from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
 from loguru import logger
@@ -155,6 +156,7 @@ def Latex英文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
     import glob, os
     if os.path.exists(txt):
         project_folder = txt
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
@@ -193,6 +195,7 @@ def Latex中文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
     import glob, os
     if os.path.exists(txt):
         project_folder = txt
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
@@ -229,6 +232,7 @@ def Latex英文纠错(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
     import glob, os
     if os.path.exists(txt):
         project_folder = txt
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
```
```diff
@@ -1,5 +1,6 @@
 import glob, shutil, os, re
 from loguru import logger
+from shared_utils.fastapi_server import validate_path_safety
 from toolbox import update_ui, trimmed_format_exc, gen_time_str
 from toolbox import CatchException, report_exception, get_log_folder
 from toolbox import write_history_to_file, promote_file_to_downloadzone
@@ -118,7 +119,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
     yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI


-def get_files_from_everything(txt, preference=''):
+def get_files_from_everything(txt, preference='', chatbox=None):
     if txt == "": return False, None, None
     success = True
     if txt.startswith('http'):
@@ -146,9 +147,11 @@ def get_files_from_everything(txt, preference=''):
         # a file given directly
         file_manifest = [txt]
         project_folder = os.path.dirname(txt)
+        validate_path_safety(project_folder, chatbot.get_user())
     elif os.path.exists(txt):
         # a local path: search it recursively
         project_folder = txt
+        validate_path_safety(project_folder, chatbot.get_user())
         file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.md', recursive=True)]
     else:
         project_folder = None
@@ -177,7 +180,7 @@ def Markdown英译中(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
         return
     history = []  # clear history to avoid input overflow

-    success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github")
+    success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github", chatbox=chatbot)

     if not success:
         # nothing at all
```
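Note a naming inconsistency in this file: the new keyword parameter is spelled `chatbox`, while the added calls reference `chatbot.get_user()`, so the local-path branches would raise a NameError unless `chatbot` resolves in an enclosing scope. URL inputs avoid the issue because the `http` branch never reaches those lines. A sketch of the call as the commit writes it (the caller context is assumed):

```python
# Hypothetical caller, mirroring Markdown英译中: a GitHub URL goes down the
# 'http' branch, which never touches the mis-scoped validate_path_safety lines.
success, file_manifest, project_folder = get_files_from_everything(
    "https://github.com/binary-husky/gpt_academic", preference="Github", chatbox=chatbot
)
```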
```diff
@@ -26,7 +26,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst

     # clear history to avoid input overflow
     history = []
-    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
+    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot)

     # check the input; if no input was given, exit directly
     if (not success) and txt == "": txt = '空空如也的输入栏。提示:请先上传文件(把PDF文件拖入对话)。'
```
```diff
@@ -1,6 +1,7 @@
 import os
 import threading
 from loguru import logger
+from shared_utils.fastapi_server import validate_path_safety
 from shared_utils.char_visual_effect import scrolling_visual_effect
 from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton

@@ -539,7 +540,7 @@ def read_and_clean_pdf_text(fp):
     return meta_txt, page_one_meta


-def get_files_from_everything(txt, type): # type='.md'
+def get_files_from_everything(txt, type, chatbot=None): # type='.md'
     """
     This function collects all files of the given type (e.g. .md) under the given
     directory; for files on the web, it can fetch them as well.
     The parameters and return values are described below:
@@ -551,6 +552,7 @@ def get_files_from_everything(txt, type): # type='.md'
     - file_manifest: a list of the absolute paths of all files whose suffix matches the given type.
     - project_folder: a string with the folder containing the files; for files fetched from the web, the path of a temporary folder.
     Detailed comments have been added to this function; please confirm whether they meet your needs.
+    - chatbot: the cookie-carrying Chatbot class, a foundation for more powerful features
     """
     import glob, os

@@ -573,9 +575,13 @@ def get_files_from_everything(txt, type): # type='.md'
         # a file given directly
         file_manifest = [txt]
         project_folder = os.path.dirname(txt)
+        if chatbot is not None:
+            validate_path_safety(project_folder, chatbot.get_user())
     elif os.path.exists(txt):
         # a local path: search it recursively
         project_folder = txt
+        if chatbot is not None:
+            validate_path_safety(project_folder, chatbot.get_user())
         file_manifest = [f for f in glob.glob(f'{project_folder}/**/*'+type, recursive=True)]
         if len(file_manifest) == 0:
             success = False
```
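The guarded form used here — validate only when a chatbot is supplied — keeps older call sites working while new ones opt in. A minimal sketch of the pattern, assuming `validate_path_safety` raises on an unsafe path:

```python
import os
from shared_utils.fastapi_server import validate_path_safety

def resolve_project_folder(txt: str, chatbot=None) -> str:
    # Sketch of the pattern added in this commit: the safety check is
    # skipped entirely when no chatbot (hence no user identity) is available.
    project_folder = txt if os.path.isdir(txt) else os.path.dirname(txt)
    if chatbot is not None:
        validate_path_safety(project_folder, chatbot.get_user())  # raises if unsafe
    return project_folder
```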
```diff
@@ -242,9 +242,7 @@ def 解析PDF_DOC2X_单文件(
     extract_archive(file_path=this_file_path, dest_dir=ex_folder)

     # edit markdown files
-    success, file_manifest, project_folder = get_files_from_everything(
-        ex_folder, type=".md"
-    )
+    success, file_manifest, project_folder = get_files_from_everything(ex_folder, type='.md', chatbot=chatbot)
     for generated_fp in file_manifest:
         # fix some formula issues
         with open(generated_fp, "r", encoding="utf8") as f:
```
```diff
@@ -27,10 +27,10 @@ def extract_text_from_files(txt, chatbot, history):
     return False, final_result, page_one, file_manifest, exception  # if the input area is not a file, return its content directly

     # look for files in the input-area content
-    file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf')
-    file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md')
-    file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx')
-    file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc')
+    file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf', chatbot=chatbot)
+    file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md', chatbot=chatbot)
+    file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx', chatbot=chatbot)
+    file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc', chatbot=chatbot)

     if file_doc:
         exception = "word"
```
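Since `extract_text_from_files` probes one extension at a time, the same change has to be repeated per type; a loop over the extensions would express it more compactly (an editorial sketch, not what the commit does):

```python
# Editorial sketch: gather manifests per extension in one pass.
manifests = {
    ext: get_files_from_everything(txt, ext, chatbot=chatbot)
    for ext in ('.pdf', '.md', '.docx', '.doc')
}
file_pdf, pdf_manifest, folder_pdf = manifests['.pdf']
```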
```diff
@@ -104,6 +104,8 @@ def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pr
     # check the input; if no input was given, exit directly
     if os.path.exists(txt):
         project_folder = txt
+        from shared_utils.fastapi_server import validate_path_safety
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
```
```diff
@@ -61,7 +61,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     history = []

     from crazy_functions.crazy_utils import get_files_from_everything
-    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
+    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot)
     if len(file_manifest) > 0:
         # try to import dependencies; if any are missing, suggest how to install them
         try:
@@ -73,7 +73,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
             b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade nougat-ocr tiktoken```。")
         yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
         return
-    success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd')
+    success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd', chatbot=chatbot)
     success = success or success_mmd
     file_manifest += file_manifest_mmd
     chatbot.append(["文件列表:", ", ".join([e.split('/')[-1] for e in file_manifest])]);
```
```diff
@@ -87,6 +87,8 @@ def 理解PDF文档内容标准文件输入(txt, llm_kwargs, plugin_kwargs, chat
     # check the input; if no input was given, exit directly
     if os.path.exists(txt):
         project_folder = txt
+        from shared_utils.fastapi_server import validate_path_safety
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "":
             txt = '空空如也的输入栏'
```
```diff
@@ -39,6 +39,8 @@ def 批量生成函数注释(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
     import glob, os
     if os.path.exists(txt):
         project_folder = txt
+        from shared_utils.fastapi_server import validate_path_safety
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
```
```diff
@@ -49,7 +49,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     file_manifest = []
     spl = ["txt", "doc", "docx", "email", "epub", "html", "json", "md", "msg", "pdf", "ppt", "pptx", "rtf"]
     for sp in spl:
-        _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}')
+        _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}', chatbot=chatbot)
         file_manifest += file_manifest_tmp

     if len(file_manifest) == 0:
```
```diff
@@ -126,6 +126,8 @@ def 解析ipynb文件(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
     import os
     if os.path.exists(txt):
         project_folder = txt
+        from shared_utils.fastapi_server import validate_path_safety
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "":
             txt = '空空如也的输入栏'
```
```diff
@@ -48,6 +48,8 @@ def 读文章写摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
     import glob, os
     if os.path.exists(txt):
         project_folder = txt
+        from shared_utils.fastapi_server import validate_path_safety
+        validate_path_safety(project_folder, chatbot.get_user())
     else:
         if txt == "": txt = '空空如也的输入栏'
         report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
```
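The same few-line guard now appears in every plugin that accepts a local path. Factoring it into a helper would remove the repetition; a sketch follows (a hypothetical helper, not part of this commit):

```python
import os
from shared_utils.fastapi_server import validate_path_safety
from toolbox import report_exception

def checked_project_folder(txt, chatbot, history):
    # Hypothetical helper capturing the pattern repeated across plugins:
    # accept an existing path only after a per-user safety check.
    if os.path.exists(txt):
        validate_path_safety(txt, chatbot.get_user())
        return txt
    if txt == "":
        txt = 'the input field is empty'
    report_exception(chatbot, history, a=f"Parse project: {txt}",
                     b=f"Local project not found or access denied: {txt}")
    return None
```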
shared_utils/context_clip_policy.py (new file, +296 lines)
```python
import copy
from shared_utils.config_loader import get_conf


def get_token_num(txt, tokenizer):
    return len(tokenizer.encode(txt, disallowed_special=()))


def get_model_info():
    from request_llms.bridge_all import model_info
    return model_info


def clip_history(inputs, history, tokenizer, max_token_limit):
    """
    Reduce the length of the history by clipping.
    This function searches for the longest entries and clips them, little by little,
    until the token count of the history drops below the threshold.

    Passively triggered clipping.
    """
    import numpy as np

    input_token_num = get_token_num(inputs, tokenizer)

    if max_token_limit < 5000:
        output_token_expect = 256  # 4k & 2k models
    elif max_token_limit < 9000:
        output_token_expect = 512  # 8k models
    else:
        output_token_expect = 1024  # 16k & 32k models

    if input_token_num < max_token_limit * 3 / 4:
        # When the input takes less than 3/4 of the limit, clip as follows:
        # 1. reserve room for the input
        max_token_limit = max_token_limit - input_token_num
        # 2. reserve room for the output
        max_token_limit = max_token_limit - output_token_expect
        # 3. if too little room remains, clear the history outright
        if max_token_limit < output_token_expect:
            history = []
            return history
    else:
        # When the input takes more than 3/4 of the limit, clear the history outright
        history = []
        return history

    everything = [""]
    everything.extend(history)
    n_token = get_token_num("\n".join(everything), tokenizer)
    everything_token = [get_token_num(e, tokenizer) for e in everything]

    # granularity of each truncation step
    delta = max(everything_token) // 16

    while n_token > max_token_limit:
        where = np.argmax(everything_token)
        encoded = tokenizer.encode(everything[where], disallowed_special=())
        clipped_encoded = encoded[: len(encoded) - delta]
        everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 to drop a possibly broken trailing char
        everything_token[where] = get_token_num(everything[where], tokenizer)
        n_token = get_token_num("\n".join(everything), tokenizer)

    history = everything[1:]
    return history


def auto_context_clip_each_message(current, history):
    """
    clip_history above is passively triggered; this function clips proactively.
    """
    context = history + [current]
    trigger_clip_token_len = get_conf('AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN')
    model_info = get_model_info()
    tokenizer = model_info['gpt-4']['tokenizer']
    # Keep only the most recent rounds (capped by AUTO_CONTEXT_MAX_ROUND) regardless of
    # token length, so that counting tokens does not take excessively long
    max_round = get_conf('AUTO_CONTEXT_MAX_ROUND')
    char_len = sum([len(h) for h in context])
    if char_len < trigger_clip_token_len * 2:
        # no clipping needed
        history = context[:-1]
        current = context[-1]
        return current, history
    if len(context) > max_round:
        context = context[-max_round:]
    # token length of each history entry
    context_token_num = [get_token_num(h, tokenizer) for h in context]
    context_token_num_old = copy.copy(context_token_num)
    total_token_num = total_token_num_old = sum(context_token_num)
    if total_token_num < trigger_clip_token_len:
        # no clipping needed
        history = context[:-1]
        current = context[-1]
        return current, history
    clip_token_len = trigger_clip_token_len * 0.85
    # longer entries are clipped first; older entries are clipped first
    max_clip_ratio: list[float] = get_conf('AUTO_CONTEXT_MAX_CLIP_RATIO')
    max_clip_ratio = list(reversed(max_clip_ratio))
    if len(context) > len(max_clip_ratio):
        # give up the oldest context
        context = context[-len(max_clip_ratio):]
        context_token_num = context_token_num[-len(max_clip_ratio):]
    if len(context) < len(max_clip_ratio):
        # match the lengths of the two arrays
        max_clip_ratio = max_clip_ratio[-len(context):]

    # compute the clipping rank
    clip_prior_weight = [(token_num / clip_token_len + (len(context) - index) * 0.1) for index, token_num in enumerate(context_token_num)]
    # sorted indices of context_token_num, from largest weight to smallest
    sorted_index = sorted(range(len(context_token_num)), key=lambda k: clip_prior_weight[k], reverse=True)

    # pre-compute the space yielded
    for index in sorted_index:
        print('index', index, f'current total {total_token_num}, target {clip_token_len}')
        if total_token_num < clip_token_len:
            # no need to clip
            break
        # room left to clip
        clip_room_left = total_token_num - clip_token_len
        # the allowed share for this entry
        allowed_token_num_this_entry = max_clip_ratio[index] * clip_token_len
        if context_token_num[index] < allowed_token_num_this_entry:
            print('index', index, '[allowed] before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
            continue

        token_to_clip = context_token_num[index] - allowed_token_num_this_entry
        if token_to_clip * 0.85 > clip_room_left:
            print('index', index, '[careful clip] token_to_clip', token_to_clip, 'clip_room_left', clip_room_left)
            token_to_clip = clip_room_left

        token_percent_to_clip = token_to_clip / context_token_num[index]
        char_percent_to_clip = token_percent_to_clip
        text_this_entry = context[index]
        char_num_to_clip = int(len(text_this_entry) * char_percent_to_clip)
        if char_num_to_clip < 500:
            # skip entries where fewer than 500 characters would be clipped
            print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
            continue
        char_num_to_clip += 200  # clip a little extra
        char_to_preserve = len(text_this_entry) - char_num_to_clip
        _half = int(char_to_preserve / 2)
        # first half + ... (content clipped because token overflows) ... + second half
        text_this_entry_clip = text_this_entry[:_half] + \
            " ... (content clipped because token overflows) ... " \
            + text_this_entry[-_half:]
        context[index] = text_this_entry_clip
        post_clip_token_cnt = get_token_num(text_this_entry_clip, tokenizer)
        print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry, 'after', post_clip_token_cnt)
        context_token_num[index] = post_clip_token_cnt
        total_token_num = sum(context_token_num)
    context_token_num_final = [get_token_num(h, tokenizer) for h in context]
    print('context_token_num_old', context_token_num_old)
    print('context_token_num_final', context_token_num_final)
    print('token change from', total_token_num_old, 'to', sum(context_token_num_final), 'target', clip_token_len)
    history = context[:-1]
    current = context[-1]
    return current, history


def auto_context_clip_search_optimal(current, history, promote_latest_long_message=False):
    """
    current: the current message
    history: the list of history messages
    promote_latest_long_message: whether to give extra weight to the last long message, to avoid over-clipping it

    Proactively triggered clipping.
    """
    context = history + [current]
    trigger_clip_token_len = get_conf('AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN')
    model_info = get_model_info()
    tokenizer = model_info['gpt-4']['tokenizer']
    # Keep only the most recent rounds (capped by AUTO_CONTEXT_MAX_ROUND) regardless of
    # token length, so that counting tokens does not take excessively long
    max_round = get_conf('AUTO_CONTEXT_MAX_ROUND')
    char_len = sum([len(h) for h in context])
    if char_len < trigger_clip_token_len:
        # no clipping needed
        history = context[:-1]
        current = context[-1]
        return current, history
    if len(context) > max_round:
        context = context[-max_round:]
    # token length of each history entry
    context_token_num = [get_token_num(h, tokenizer) for h in context]
    context_token_num_old = copy.copy(context_token_num)
    total_token_num = total_token_num_old = sum(context_token_num)
    if total_token_num < trigger_clip_token_len:
        # no clipping needed
        history = context[:-1]
        current = context[-1]
        return current, history
    clip_token_len = trigger_clip_token_len * 0.90
    max_clip_ratio: list[float] = get_conf('AUTO_CONTEXT_MAX_CLIP_RATIO')
    max_clip_ratio = list(reversed(max_clip_ratio))
    if len(context) > len(max_clip_ratio):
        # give up the oldest context
        context = context[-len(max_clip_ratio):]
        context_token_num = context_token_num[-len(max_clip_ratio):]
    if len(context) < len(max_clip_ratio):
        # match the lengths of the two arrays
        max_clip_ratio = max_clip_ratio[-len(context):]

    _scale = _scale_init = 1.25
    token_percent_arr = [token_num / clip_token_len for token_num in context_token_num]

    # promote the last long message, avoid clipping it too much
    if promote_latest_long_message:
        promote_weight_constant = 1.6
        promote_index = -1
        threshold = 0.50
        for index, token_percent in enumerate(token_percent_arr):
            if token_percent > threshold:
                promote_index = index
        if promote_index >= 0:
            max_clip_ratio[promote_index] = promote_weight_constant

    # search for the largest scale whose total allowance still fits the budget
    max_clip_ratio_arr = max_clip_ratio
    step = 0.05
    for i in range(int(_scale_init / step) - 1):
        _take = 0
        for max_clip, token_r in zip(max_clip_ratio_arr, token_percent_arr):
            _take += min(max_clip * _scale, token_r)
        if _take < 1.0:
            break
        _scale -= 0.05

    eps = 0.05
    max_clip_ratio = [_scale * max_clip + eps for max_clip in max_clip_ratio_arr]

    # compute the clipping rank
    clip_prior_weight = [token_r / max_clip for max_clip, token_r in zip(max_clip_ratio_arr, token_percent_arr)]
    sorted_index = sorted(range(len(context_token_num)), key=lambda k: clip_prior_weight[k], reverse=True)

    for index in sorted_index:
        if total_token_num < clip_token_len:
            # no need to clip
            break
        # room left to clip
        clip_room_left = total_token_num - clip_token_len
        # the allowed share for this entry
        allowed_token_num_this_entry = max_clip_ratio[index] * clip_token_len
        if context_token_num[index] < allowed_token_num_this_entry:
            continue

        token_to_clip = context_token_num[index] - allowed_token_num_this_entry
        if token_to_clip * 0.85 > clip_room_left:
            token_to_clip = clip_room_left

        token_percent_to_clip = token_to_clip / context_token_num[index]
        char_percent_to_clip = token_percent_to_clip
        text_this_entry = context[index]
        char_num_to_clip = int(len(text_this_entry) * char_percent_to_clip)
        if char_num_to_clip < 500:
            # skip entries where fewer than 500 characters would be clipped
            continue
        eps = 200
        char_num_to_clip = char_num_to_clip + eps  # clip a little extra
        char_to_preserve = len(text_this_entry) - char_num_to_clip
        _half = int(char_to_preserve / 2)
        # first half + ... (content clipped because token overflows) ... + second half
        text_this_entry_clip = text_this_entry[:_half] + \
            " ... (content clipped because token overflows) ... " \
            + text_this_entry[-_half:]
        context[index] = text_this_entry_clip
        post_clip_token_cnt = get_token_num(text_this_entry_clip, tokenizer)
        context_token_num[index] = post_clip_token_cnt
        total_token_num = sum(context_token_num)
    history = context[:-1]
    current = context[-1]
    return current, history
```
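To see the policy in action without the full app, one can drive `auto_context_clip_search_optimal` with a stand-in tokenizer. This is a test sketch: the whitespace "tokenizer" and the monkey-patched `get_model_info` are assumptions for illustration, and it presumes the repository config (with the new AUTO_CONTEXT_* keys) is loadable:

```python
# Test sketch with a fake whitespace "tokenizer"; in the app the real
# tiktoken tokenizer comes from request_llms.bridge_all.model_info.
class FakeTok:
    def encode(self, txt, disallowed_special=()): return txt.split()
    def decode(self, toks): return " ".join(toks)

import shared_utils.context_clip_policy as ccp
ccp.get_model_info = lambda: {'gpt-4': {'tokenizer': FakeTok()}}  # monkey-patch for the test

history = ["word " * 40000, "short answer"]   # one oversized entry (~40k "tokens")
current = "next question"
current, history = ccp.auto_context_clip_search_optimal(current, history)
# The oversized entry gets its middle replaced by the overflow marker.
assert "content clipped because token overflows" in history[0]
```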
```diff
@@ -51,7 +51,7 @@ def validate_path_safety(path_or_url, user):
     from toolbox import get_conf, default_user_name
     from toolbox import FriendlyException
     PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
-    sensitive_path = None
+    sensitive_path = None  # must not contain '/', i.e. must not be a multi-level path
     path_or_url = os.path.relpath(path_or_url)
     if path_or_url.startswith(PATH_LOGGING):  # log files (split per user)
         sensitive_path = PATH_LOGGING
```
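The sharpened comment matters because the check compares a relative path against the sensitive prefix; a multi-level sensitive path would defeat the `startswith` logic. A sketch of the comparison, assuming `PATH_LOGGING` is a single directory name such as `gpt_log`:

```python
import os

PATH_LOGGING = "gpt_log"  # assumed single-level, per the new comment

def is_in_logging_area(path_or_url: str) -> bool:
    # Normalize first, then test the prefix, mirroring validate_path_safety.
    rel = os.path.relpath(path_or_url)
    return rel.startswith(PATH_LOGGING)

print(is_in_logging_area("./gpt_log/default_user/x.txt"))  # True
```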
```diff
@@ -4,7 +4,6 @@ from functools import wraps, lru_cache
 from shared_utils.advanced_markdown_format import format_io
 from shared_utils.config_loader import get_conf as get_conf

-
 pj = os.path.join
 default_user_name = 'default_user'

@@ -12,11 +11,13 @@ default_user_name = 'default_user'
 openai_regex = re.compile(
     r"sk-[a-zA-Z0-9_-]{48}$|" +
     r"sk-[a-zA-Z0-9_-]{92}$|" +
-    r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
-    r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
-    r"sk-proj-[a-zA-Z0-9_-]{156}$|"+ # adjusted: newer API keys have a different length, hence this regex change
+    r"sk-proj-[a-zA-Z0-9_-]{48}$|" +
+    r"sk-proj-[a-zA-Z0-9_-]{124}$|" +
+    r"sk-proj-[a-zA-Z0-9_-]{156}$|" + # adjusted: newer API keys have a different length, hence this regex change
     r"sess-[a-zA-Z0-9]{40}$"
 )


 def is_openai_api_key(key):
     CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
     if len(CUSTOM_API_KEY_PATTERN) != 0:
@@ -27,7 +28,7 @@ def is_openai_api_key(key):


 def is_azure_api_key(key):
-    API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
+    API_MATCH_AZURE = re.match(r"^[a-zA-Z0-9]{32}$|^[a-zA-Z0-9]{84}", key)
     return bool(API_MATCH_AZURE)


@@ -35,10 +36,12 @@ def is_api2d_key(key):
     API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
     return bool(API_MATCH_API2D)


+def is_openroute_api_key(key):
+    API_MATCH_OPENROUTE = re.match(r"sk-or-v1-[a-zA-Z0-9]{64}$", key)
+    return bool(API_MATCH_OPENROUTE)


 def is_cohere_api_key(key):
     API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{40}$", key)
     return bool(API_MATCH_AZURE)

@@ -109,7 +112,7 @@ def select_api_key(keys, llm_model):
     if llm_model.startswith('cohere-'):
         for k in key_list:
             if is_cohere_api_key(k): avail_key_list.append(k)

+    if llm_model.startswith('openrouter-'):
+        for k in key_list:
+            if is_openroute_api_key(k): avail_key_list.append(k)

@@ -117,7 +120,7 @@ def select_api_key(keys, llm_model):
     if len(avail_key_list) == 0:
         raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源(左上角更换模型菜单中可切换openai,azure,claude,cohere等请求源)。")

-    api_key = random.choice(avail_key_list) # random load balancing
+    api_key = random.choice(avail_key_list)  # random load balancing
     return api_key


@@ -133,5 +136,5 @@ def select_api_key_for_embed_models(keys, llm_model):
     if len(avail_key_list) == 0:
         raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源。")

-    api_key = random.choice(avail_key_list) # random load balancing
+    api_key = random.choice(avail_key_list)  # random load balancing
     return api_key
```
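One detail worth noting: in the widened Azure pattern the second alternative `^[a-zA-Z0-9]{84}` carries no trailing `$`, so any alphanumeric string of 84 or more characters passes. A quick check (illustrative only):

```python
import re

def is_azure_api_key(key):
    # Pattern as added in this commit: the 84-char branch is unanchored at the end.
    return bool(re.match(r"^[a-zA-Z0-9]{32}$|^[a-zA-Z0-9]{84}", key))

print(is_azure_api_key("a" * 32))   # True (exactly 32)
print(is_azure_api_key("a" * 84))   # True (exactly 84)
print(is_azure_api_key("a" * 200))  # also True: no '$' after the {84} branch
```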
toolbox.py (87 changed lines)
```diff
@@ -37,6 +37,9 @@ from shared_utils.handle_upload import html_local_file
 from shared_utils.handle_upload import html_local_img
 from shared_utils.handle_upload import file_manifest_filter_type
 from shared_utils.handle_upload import extract_archive
+from shared_utils.context_clip_policy import clip_history
+from shared_utils.context_clip_policy import auto_context_clip_each_message
+from shared_utils.context_clip_policy import auto_context_clip_search_optimal
 from typing import List
 pj = os.path.join
 default_user_name = "default_user"
```
```diff
@@ -133,6 +136,9 @@ def ArgsGeneralWrapper(f):
         if len(args) == 0:  # plugin channel
             yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
         else:  # chat channel, or basic-function channel
+            # basic chat channel, or basic-function channel
+            if get_conf('AUTO_CONTEXT_CLIP_ENABLE'):
+                txt_passon, history = auto_context_clip(txt_passon, history)
             yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
     else:
         # handle the locked state of special plugins in rare cases
```
```diff
@@ -499,6 +505,22 @@ def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False, om

     return tabs_list


+def validate_file_size(files, max_size_mb=500):
+    """
+    Check that the combined file size stays within the allowed range.
+    :param files: a list of full file paths
+    :param max_size_mb: the maximum file size in MB (500 MB by default)
+    :return: True if the size is valid, otherwise an exception is raised
+    """
+    # accumulate the file sizes (in bytes)
+    total_size = 0
+    max_size_bytes = max_size_mb * 1024 * 1024
+    for file in files:
+        total_size += os.path.getsize(file.name)
+    if total_size > max_size_bytes:
+        raise ValueError(f"File size exceeds the allowed limit of {max_size_mb} MB. "
+                         f"Current size: {total_size / (1024 * 1024):.2f} MB")
+    return True

 def on_file_uploaded(
     request: gradio.Request, files:List[str], chatbot:ChatBotWithCookies,
@@ -510,6 +532,7 @@ def on_file_uploaded(
     if len(files) == 0:
         return chatbot, txt

+    validate_file_size(files, max_size_mb=500)
     # create the working directory
     user_name = default_user_name if not request.username else request.username
     time_tag = gen_time_str()
```
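The new guard runs before any upload handling; note that it reads `file.name`, i.e. it expects gradio's uploaded-file wrappers rather than plain path strings. A usage sketch (the wrapper objects are stand-ins, and the paths are assumed to exist on disk):

```python
from types import SimpleNamespace

# Stand-ins for gradio's uploaded-file objects, which expose a .name path.
files = [SimpleNamespace(name="/tmp/upload_1.pdf"), SimpleNamespace(name="/tmp/upload_2.pdf")]

try:
    validate_file_size(files, max_size_mb=500)  # raises ValueError when the total exceeds 500 MB
except ValueError as e:
    print(e)  # e.g. "File size exceeds the allowed limit of 500 MB. Current size: 512.00 MB"
```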
```diff
@@ -712,66 +735,14 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth


-def clip_history(inputs, history, tokenizer, max_token_limit):
-    """
-    reduce the length of history by clipping.
-    this function search for the longest entries to clip, little by little,
-    until the number of token of history is reduced under threshold.
-    """
-    import numpy as np
-    from request_llms.bridge_all import model_info
-
-    def get_token_num(txt):
-        return len(tokenizer.encode(txt, disallowed_special=()))
-
-    input_token_num = get_token_num(inputs)
-
-    if max_token_limit < 5000:
-        output_token_expect = 256  # 4k & 2k models
-    elif max_token_limit < 9000:
-        output_token_expect = 512  # 8k models
-    else:
-        output_token_expect = 1024  # 16k & 32k models
-
-    if input_token_num < max_token_limit * 3 / 4:
-        # when the input takes less than 3/4 of the limit, clip as follows:
-        # 1. reserve room for the input
-        max_token_limit = max_token_limit - input_token_num
-        # 2. reserve room for the output
-        max_token_limit = max_token_limit - output_token_expect
-        # 3. if too little room remains, clear the history outright
-        if max_token_limit < output_token_expect:
-            history = []
-            return history
-    else:
-        # when the input takes more than 3/4 of the limit, clear the history outright
-        history = []
-        return history
-
-    everything = [""]
-    everything.extend(history)
-    n_token = get_token_num("\n".join(everything))
-    everything_token = [get_token_num(e) for e in everything]
-
-    # granularity of each truncation step
-    delta = max(everything_token) // 16
-
-    while n_token > max_token_limit:
-        where = np.argmax(everything_token)
-        encoded = tokenizer.encode(everything[where], disallowed_special=())
-        clipped_encoded = encoded[: len(encoded) - delta]
-        everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 to remove the may-be illegal char
-        everything_token[where] = get_token_num(everything[where])
-        n_token = get_token_num("\n".join(everything))
-
-    history = everything[1:]
-    return history
+def auto_context_clip(current, history, policy='search_optimal'):
+    if policy == 'each_message':
+        return auto_context_clip_each_message(current, history)
+    elif policy == 'search_optimal':
+        return auto_context_clip_search_optimal(current, history)
+    else:
+        raise RuntimeError(f"未知的自动上下文裁剪策略: {policy}。")
```
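With the wrapper in place, the chat channel clips transparently whenever `AUTO_CONTEXT_CLIP_ENABLE` is set: `clip_history` moves to shared_utils/context_clip_policy.py unchanged, and `auto_context_clip` dispatches on the policy name. A sketch of the call as it now happens inside `ArgsGeneralWrapper`, condensed from the hunks above:

```python
from toolbox import auto_context_clip  # dispatches on the policy name

# 'search_optimal' is the default policy; 'each_message' is the simpler variant.
txt_passon, history = auto_context_clip(txt_passon, history, policy='search_optimal')
```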