move token limit conf to bridge_all.py

This commit is contained in:
binary-husky
2023-12-04 10:39:10 +08:00
parent 9bfc3400f9
commit 3c03f240ba
3 changed files with 4 additions and 8 deletions


@@ -8,7 +8,6 @@ from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
 from threading import Thread
 import torch
-MAX_INPUT_TOKEN_LENGTH = get_conf("MAX_INPUT_TOKEN_LENGTH")
 
 def download_huggingface_model(model_name, max_retry, local_dir):
     from huggingface_hub import snapshot_download
     for i in range(1, max_retry):
@@ -94,8 +93,8 @@ class GetCoderLMHandle(LocalLLMHandle):
        history.append({ 'role': 'user', 'content': query})
        messages = history
        inputs = self._tokenizer.apply_chat_template(messages, return_tensors="pt")
-       if inputs.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-           inputs = inputs[:, -MAX_INPUT_TOKEN_LENGTH:]
+       if inputs.shape[1] > max_length:
+           inputs = inputs[:, -max_length:]
        inputs = inputs.to(self._model.device)
        generation_kwargs = dict(
            inputs=inputs,
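
The change itself is small: the hard truncation now compares against a max_length value supplied from bridge_all.py instead of a module-level MAX_INPUT_TOKEN_LENGTH read via get_conf. The truncation keeps only the last max_length token positions, so the newest conversation turns survive when the prompt is too long. A minimal sketch of that left-truncation on a plain tensor, assuming the same (batch, seq_len) layout as the tokenizer output; the helper name truncate_left is illustrative, not part of the repo:

import torch

def truncate_left(inputs: torch.Tensor, max_length: int) -> torch.Tensor:
    # inputs has shape (batch, seq_len); keep only the last max_length
    # token positions so the newest conversation turns are preserved.
    if inputs.shape[1] > max_length:
        inputs = inputs[:, -max_length:]
    return inputs

# Example: a 1x10 tensor of token ids truncated to the last 4 positions.
inputs = torch.arange(10).unsqueeze(0)   # tensor([[0, 1, ..., 9]])
print(truncate_left(inputs, 4))          # tensor([[6, 7, 8, 9]])

Truncating from the left rather than the right is a sensible choice for chat prompts: dropping the oldest tokens degrades gracefully, while dropping the newest ones would cut off the user's current query.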