修正internlm输入设备bug

这个提交包含在:
qingxu fu
2023-11-11 23:22:50 +08:00
父节点 f75e39dc27
当前提交 2d91e438d6
共有 2 个文件被更改,包括 17 次插入19 次删除

查看文件

@@ -94,8 +94,9 @@ class GetInternlmHandle(LocalLLMHandle):
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
input_length = len(inputs["input_ids"][0])
device = get_conf('LOCAL_MODEL_DEVICE')
for k, v in inputs.items():
inputs[k] = v.cuda()
inputs[k] = v.to(device)
input_ids = inputs["input_ids"]
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None: