Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 14:36:48 +00:00
add context clip policy
toolbox.py (70 lines changed)
@@ -37,6 +37,9 @@ from shared_utils.handle_upload import html_local_file
 from shared_utils.handle_upload import html_local_img
 from shared_utils.handle_upload import file_manifest_filter_type
 from shared_utils.handle_upload import extract_archive
+from shared_utils.context_clip_policy import clip_history
+from shared_utils.context_clip_policy import auto_context_clip_each_message
+from shared_utils.context_clip_policy import auto_context_clip_search_optimal
 from typing import List
 pj = os.path.join
 default_user_name = "default_user"
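These three imports replace clipping logic that previously lived in toolbox.py itself; the new module shared_utils/context_clip_policy.py is not shown in this diff. Below is a rough sketch of the interface the imports imply, with signatures inferred only from how the functions are called elsewhere in this commit; the bodies are placeholders, not the real implementations.

# shared_utils/context_clip_policy.py -- interface sketch only, inferred from call sites in this commit
from typing import List, Tuple

def clip_history(inputs: str, history: List[str], tokenizer, max_token_limit: int) -> List[str]:
    """Trim history entries until the whole prompt fits under max_token_limit (moved out of toolbox.py)."""
    ...

def auto_context_clip_each_message(current: str, history: List[str]) -> Tuple[str, List[str]]:
    """Clip each message independently (placeholder signature inferred from the dispatcher below)."""
    ...

def auto_context_clip_search_optimal(current: str, history: List[str]) -> Tuple[str, List[str]]:
    """Search for an optimal clipping of the whole context (placeholder signature inferred from the dispatcher below)."""
    ...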
@@ -133,6 +136,9 @@ def ArgsGeneralWrapper(f):
             if len(args) == 0:  # plugin channel
                 yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
             else:  # chat channel, or basic-function channel
+                # basic chat channel, or basic-function channel
+                if get_conf('AUTO_CONTEXT_CLIP_ENABLE'):
+                    txt_passon, history = auto_context_clip(txt_passon, history)
                 yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
         else:
             # handle the locked state of special plugins in rare cases
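With this change, context clipping runs only on the chat / basic-function path, and only when the new AUTO_CONTEXT_CLIP_ENABLE option is truthy (in gpt_academic, get_conf typically resolves such keys from config.py, config_private.py, or environment variables). The following is a self-contained sketch of that gating pattern; get_conf and auto_context_clip keep the names from the diff, but their bodies are illustrative stand-ins so the example runs outside the project.

from typing import List, Tuple

# Stand-in for the project's get_conf(); the real one resolves keys from the project config.
_FAKE_CONF = {"AUTO_CONTEXT_CLIP_ENABLE": True}
def get_conf(key: str):
    return _FAKE_CONF[key]

# Stand-in for the dispatcher added in this commit: here it just truncates each history entry to 200 chars.
def auto_context_clip(current: str, history: List[str]) -> Tuple[str, List[str]]:
    return current, [h[:200] for h in history]

def prepare_request(txt_passon: str, history: List[str]) -> Tuple[str, List[str]]:
    # Mirrors the new branch in ArgsGeneralWrapper: clip only when the feature flag is on.
    if get_conf('AUTO_CONTEXT_CLIP_ENABLE'):
        txt_passon, history = auto_context_clip(txt_passon, history)
    return txt_passon, history

if __name__ == "__main__":
    txt, hist = prepare_request("new question", ["x" * 1000, "short answer"])
    print(len(hist[0]))  # 200 -- long history entries were clipped before the request is sent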
@@ -729,66 +735,14 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth


-def clip_history(inputs, history, tokenizer, max_token_limit):
-    """
-    reduce the length of history by clipping.
-    this function searches for the longest entries to clip, little by little,
-    until the number of tokens in the history is reduced under the threshold.
-    """
-    import numpy as np
-    from request_llms.bridge_all import model_info
-
-    def get_token_num(txt):
-        return len(tokenizer.encode(txt, disallowed_special=()))
-
-    input_token_num = get_token_num(inputs)
-
-    if max_token_limit < 5000:
-        output_token_expect = 256   # 4k & 2k models
-    elif max_token_limit < 9000:
-        output_token_expect = 512   # 8k models
-    else:
-        output_token_expect = 1024  # 16k & 32k models
-
-    if input_token_num < max_token_limit * 3 / 4:
-        # when the input takes less than 3/4 of the limit, clip as follows:
-        # 1. reserve room for the input
-        max_token_limit = max_token_limit - input_token_num
-        # 2. reserve room for the output
-        max_token_limit = max_token_limit - output_token_expect
-        # 3. if the remaining budget is too small, clear the history outright
-        if max_token_limit < output_token_expect:
-            history = []
-            return history
-    else:
-        # when the input takes more than 3/4 of the limit, clear the history outright
-        history = []
-        return history
-
-    everything = [""]
-    everything.extend(history)
-    n_token = get_token_num("\n".join(everything))
-    everything_token = [get_token_num(e) for e in everything]
-
-    # granularity of truncation
-    delta = max(everything_token) // 16
-
-    while n_token > max_token_limit:
-        where = np.argmax(everything_token)
-        encoded = tokenizer.encode(everything[where], disallowed_special=())
-        clipped_encoded = encoded[: len(encoded) - delta]
-        everything[where] = tokenizer.decode(clipped_encoded)[
-            :-1
-        ]  # -1 to remove the may-be illegal char
-        everything_token[where] = get_token_num(everything[where])
-        n_token = get_token_num("\n".join(everything))
-
-    history = everything[1:]
-    return history
+def auto_context_clip(current, history, policy='search_optimal'):
+    if policy == 'each_message':
+        return auto_context_clip_each_message(current, history)
+    elif policy == 'search_optimal':
+        return auto_context_clip_search_optimal(current, history)
+    else:
+        raise RuntimeError(f"未知的自动上下文裁剪策略: {policy}。")
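For reference, the block removed here is the old clip_history that used to live in toolbox.py: it reserves tokens for the input and for an expected output (256/512/1024 depending on the context window), clears the history outright if the input already occupies more than 3/4 of the window, and otherwise shaves the longest history entry by max_entry_tokens // 16 tokens per iteration until everything fits. Below is a small self-contained illustration of that budgeting arithmetic, using a toy whitespace tokenizer in place of the model tokenizer that the real function receives as a parameter.

# Toy illustration of clip_history's budgeting; a whitespace "tokenizer" stands in for the real one.
def get_token_num(txt: str) -> int:
    return len(txt.split())

def history_budget(input_txt: str, max_token_limit: int) -> int:
    # Expected output size, thresholds as in the removed code.
    if max_token_limit < 5000:
        output_token_expect = 256
    elif max_token_limit < 9000:
        output_token_expect = 512
    else:
        output_token_expect = 1024
    input_token_num = get_token_num(input_txt)
    if input_token_num >= max_token_limit * 3 / 4:
        return 0  # input alone eats 3/4 of the window: history is cleared outright
    budget = max_token_limit - input_token_num - output_token_expect
    return 0 if budget < output_token_expect else budget

if __name__ == "__main__":
    # A 4096-token model with a 1000-token input leaves 4096 - 1000 - 256 = 2840 tokens for history.
    print(history_budget(" ".join(["w"] * 1000), 4096))  # -> 2840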