add context clip policy

This commit is contained in:
binary-husky
2025-06-03 00:51:18 +08:00
Parent 3ed1b0320e
Commit 171e8a2744
3 files changed, with 319 additions and 58 deletions


@@ -37,6 +37,9 @@ from shared_utils.handle_upload import html_local_file
 from shared_utils.handle_upload import html_local_img
 from shared_utils.handle_upload import file_manifest_filter_type
 from shared_utils.handle_upload import extract_archive
+from shared_utils.context_clip_policy import clip_history
+from shared_utils.context_clip_policy import auto_context_clip_each_message
+from shared_utils.context_clip_policy import auto_context_clip_search_optimal
 from typing import List
 pj = os.path.join
 default_user_name = "default_user"
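With this hunk, clip_history comes from the new shared_utils.context_clip_policy module instead of being defined in this file. A minimal usage sketch of the re-exported helper, assuming the tokenizer is resolved through request_llms.bridge_all.model_info as elsewhere in the project; the wrapper function, its name, and its defaults below are illustrative and not part of this commit:

# Illustrative only: how the re-exported clip_history is typically called.
# The model_info tokenizer lookup is an assumption based on the project's
# existing bridge; clip_for_model and its defaults are hypothetical.
from request_llms.bridge_all import model_info
from shared_utils.context_clip_policy import clip_history

def clip_for_model(inputs, history, llm_model='gpt-3.5-turbo', max_token_limit=4096):
    tokenizer = model_info[llm_model]['tokenizer']  # assumed lookup path
    return clip_history(inputs=inputs, history=history,
                        tokenizer=tokenizer, max_token_limit=max_token_limit)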
@@ -133,6 +136,9 @@ def ArgsGeneralWrapper(f):
             if len(args) == 0:  # plugin channel
                 yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
             else:  # chat channel, or basic function channel
+                # basic chat channel, or basic function channel
+                if get_conf('AUTO_CONTEXT_CLIP_ENABLE'):
+                    txt_passon, history = auto_context_clip(txt_passon, history)
                 yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
         else:
             # handle the locked state of special plugins in a few rare cases
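The new hook is gated on an AUTO_CONTEXT_CLIP_ENABLE configuration key read through get_conf. A hedged sketch of the matching config.py entry follows; only the key name appears in this diff, so the default value is an assumption:

# Hypothetical config.py entry -- only the key name AUTO_CONTEXT_CLIP_ENABLE
# is visible in this diff; the default value shown here is assumed.
AUTO_CONTEXT_CLIP_ENABLE = False  # opt in to automatic context clipping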
@@ -729,66 +735,14 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth
-def clip_history(inputs, history, tokenizer, max_token_limit):
-    """
-    reduce the length of history by clipping.
-    this function searches for the longest entries to clip, little by little,
-    until the number of tokens in the history is reduced under the threshold.
-    """
-    import numpy as np
-    from request_llms.bridge_all import model_info
-
-    def get_token_num(txt):
-        return len(tokenizer.encode(txt, disallowed_special=()))
-
-    input_token_num = get_token_num(inputs)
-
-    if max_token_limit < 5000:
-        output_token_expect = 256  # 4k & 2k models
-    elif max_token_limit < 9000:
-        output_token_expect = 512  # 8k models
-    else:
-        output_token_expect = 1024  # 16k & 32k models
-
-    if input_token_num < max_token_limit * 3 / 4:
-        # when the input takes less than 3/4 of the token limit, clip as follows:
-        # 1. reserve room for the input
-        max_token_limit = max_token_limit - input_token_num
-        # 2. reserve room for the output
-        max_token_limit = max_token_limit - output_token_expect
-        # 3. if the remaining budget is too small, clear the history entirely
-        if max_token_limit < output_token_expect:
-            history = []
-            return history
-    else:
-        # when the input takes more than 3/4 of the token limit, clear the history entirely
-        history = []
-        return history
-
-    everything = [""]
-    everything.extend(history)
-    n_token = get_token_num("\n".join(everything))
-    everything_token = [get_token_num(e) for e in everything]
-
-    # granularity of each truncation step
-    delta = max(everything_token) // 16
-
-    while n_token > max_token_limit:
-        where = np.argmax(everything_token)
-        encoded = tokenizer.encode(everything[where], disallowed_special=())
-        clipped_encoded = encoded[: len(encoded) - delta]
-        everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 to remove the may-be illegal char
-        everything_token[where] = get_token_num(everything[where])
-        n_token = get_token_num("\n".join(everything))
-
-    history = everything[1:]
-    return history
+def auto_context_clip(current, history, policy='search_optimal'):
+    if policy == 'each_message':
+        return auto_context_clip_each_message(current, history)
+    elif policy == 'search_optimal':
+        return auto_context_clip_search_optimal(current, history)
+    else:
+        raise RuntimeError(f"未知的自动上下文裁剪策略: {policy}")  # unknown auto context clip policy