Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 22:46:48 +00:00
Optimize the truncation strategy for chatgpt conversations
This commit is contained in:

toolbox.py: 46 additions, 0 deletions

The commit adds a new clip_history helper to toolbox.py:
@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
         return {"message": f"Gradio is running at: {custom_path}"}
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
+
+
+def clip_history(inputs, history, tokenizer, max_token_limit):
+    """
+    Reduce the length of input/history by clipping.
+    This function searches for the longest entries and clips them, little by little,
+    until the token count of input/history is reduced below the threshold.
+    """
+    import numpy as np
+    from request_llm.bridge_all import model_info
+    def get_token_num(txt):
+        return len(tokenizer.encode(txt, disallowed_special=()))
+    input_token_num = get_token_num(inputs)
+    if input_token_num < max_token_limit * 3 / 4:
+        # When the input takes up less than 3/4 of the limit, reserve the input's share before clipping
+        max_token_limit = max_token_limit - input_token_num
+        if max_token_limit < 128:
+            # The remaining budget is too small; clear the history entirely
+            history = []
+            return history
+    else:
+        # When the input takes up more than 3/4 of the limit, clear the history entirely
+        history = []
+        return history
+
+    everything = ['']
+    everything.extend(history)
+    n_token = get_token_num('\n'.join(everything))
+    everything_token = [get_token_num(e) for e in everything]
+
+    # Granularity of each truncation step
+    delta = max(everything_token) // 16
+
+    while n_token > max_token_limit:
+        where = np.argmax(everything_token)
+        encoded = tokenizer.encode(everything[where], disallowed_special=())
+        clipped_encoded = encoded[:len(encoded)-delta]
+        everything[where] = tokenizer.decode(clipped_encoded)[:-1]    # -1 to drop a possibly malformed trailing char
+        everything_token[where] = get_token_num(everything[where])
+        n_token = get_token_num('\n'.join(everything))
+
+    history = everything[1:]
+    return history
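For context, here is a minimal usage sketch of the new helper (not part of the commit). It assumes tiktoken's cl100k_base encoding as a stand-in for the tokenizer that the project normally supplies; the token limit and the sample history are invented for illustration.

# A minimal usage sketch, assuming tiktoken's cl100k_base encoding as the
# tokenizer; the limit and the sample history below are made up.
import tiktoken
from toolbox import clip_history

tokenizer = tiktoken.get_encoding("cl100k_base")

inputs = "Please summarize the discussion so far."
history = [
    "short question",
    "a very long answer " * 200,    # the longest entry is clipped first
    "follow-up question",
    "another long answer " * 150,
]

clipped = clip_history(inputs, history, tokenizer, max_token_limit=1024)
for entry in clipped:
    print(len(tokenizer.encode(entry, disallowed_special=())), repr(entry[:40]))

Because the loop always clips whichever entry currently has the most tokens, removing delta tokens at a time (1/16 of the longest entry's initial token count), long replies lose their tails first while short entries survive intact. The empty string at index 0 of everything is a placeholder, so the returned history is simply everything[1:].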