Optimize the truncation strategy for ChatGPT conversations

Author: binary-husky
Date: 2023-04-23 17:32:44 +08:00
Parent: 0b89673ee9
Commit: 676fe40d39
3 changed files with 58 additions and 8 deletions

@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
        return {"message": f"Gradio is running at: {custom_path}"}
    app = gr.mount_gradio_app(app, demo, path=custom_path)
    uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth
def clip_history(inputs, history, tokenizer, max_token_limit):
    """
    Reduce the length of input/history by clipping.
    This function searches for the longest entries and clips them, little
    by little, until the token count of input/history drops below the threshold.
    """
    import numpy as np
    from request_llm.bridge_all import model_info

    def get_token_num(txt):
        return len(tokenizer.encode(txt, disallowed_special=()))

    input_token_num = get_token_num(inputs)
    if input_token_num < max_token_limit * 3 / 4:
        # The input takes less than 3/4 of the limit: reserve its share of the
        # budget and clip only the history.
        max_token_limit = max_token_limit - input_token_num
        if max_token_limit < 128:
            # Too little budget left for the history; clear it entirely.
            history = []
            return history
    else:
        # The input alone takes more than 3/4 of the limit; clear the history.
        history = []
        return history
    everything = ['']  # leading placeholder; only everything[1:] is returned as history
    everything.extend(history)
    n_token = get_token_num('\n'.join(everything))
    everything_token = [get_token_num(e) for e in everything]
    # Granularity of truncation: clip 1/16 of the longest entry per pass,
    # but at least one token, so the loop below always makes progress.
    delta = max(max(everything_token) // 16, 1)
    while n_token > max_token_limit:
        where = np.argmax(everything_token)  # index of the longest entry
        encoded = tokenizer.encode(everything[where], disallowed_special=())
        clipped_encoded = encoded[:len(encoded)-delta]
        everything[where] = tokenizer.decode(clipped_encoded)[:-1]  # -1 drops a possibly half-decoded trailing char
        everything_token[where] = get_token_num(everything[where])
        n_token = get_token_num('\n'.join(everything))
    history = everything[1:]
    return history
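
Below is a minimal usage sketch, not part of the commit: it assumes the clip_history defined above is in scope and uses a tiktoken encoding as a stand-in for the project's tokenizer wrapper; the sample strings and the 1024-token limit are illustrative only.

# Sketch: exercising clip_history with a tiktoken encoding.
# `enc`, the sample strings and the 1024-token limit are assumptions.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # stand-in for the real tokenizer

inputs = "Summarize the discussion so far."
history = [
    "What does clip_history do?",
    "It trims the longest entries first. " * 200,  # oversized entry
    "Thanks, that helps.",
]

clipped = clip_history(inputs, history, tokenizer=enc, max_token_limit=1024)

# The oversized middle entry loses ~1/16 of its tokens per pass until
# input + history fit the budget; the short entries survive intact.
print([len(enc.encode(h, disallowed_special=())) for h in clipped])

Clipping the longest entry first, in small steps, keeps short conversational turns intact while long pasted documents or answers absorb the truncation.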