logging -> loguru: final stage

这个提交包含在:
binary-husky
2024-09-15 15:51:51 +00:00
父节点 bbf9e9f868
当前提交 2f343179a2
共有 55 个文件被更改,包括 237 次插入529 次删除

查看文件

@@ -1,12 +1,13 @@
from transformers import AutoModel, AutoTokenizer
from loguru import logger
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
import time
import os
import json
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存CPU或显存GPU,也许会导致低配计算机卡死 ……"
@@ -78,7 +79,7 @@ class GetGLMFTHandle(Process):
config.pre_seq_len = model_args['pre_seq_len']
config.prefix_projection = model_args['prefix_projection']
print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
logger.info(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
new_prefix_state_dict = {}
@@ -88,7 +89,7 @@ class GetGLMFTHandle(Process):
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
if model_args['quantization_bit'] is not None and model_args['quantization_bit'] != 0:
print(f"Quantized to {model_args['quantization_bit']} bit")
logger.info(f"Quantized to {model_args['quantization_bit']} bit")
model = model.quantize(model_args['quantization_bit'])
model = model.cuda()
if model_args['pre_seq_len'] is not None: