镜像自地址
https://github.com/SCIR-HI/Med-ChatGLM.git
已同步 2025-12-05 22:26:50 +00:00
v0.1 commit
这个提交包含在:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -127,3 +127,7 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# mac
|
||||
.DS_Store
|
||||
.idea
|
||||
|
||||
102
README.md
102
README.md
@@ -1,2 +1,100 @@
|
||||
# Med-ChatGLM
|
||||
Repo for Chinese Medical ChatGLM
|
||||
[**中文**](./README.md) | [**English**](./README_EN.md)
|
||||
|
||||
# ChatGLM-Med: 基于中文医学知识的ChatGLM模型微调
|
||||
|
||||
[](https://github.com/SCIR-HI/Med-ChatGLM/blob/main/LICENSE)
|
||||
[](https://www.python.org/downloads/release/python-390/)
|
||||
|
||||
|
||||
本项目开源了经过中文医学指令精调/指令微调(Instruct-tuning) 的ChatGLM-6B模型。我们通过医学知识图谱和GPT3.5 API构建了中文医学指令数据集,并在此基础上对ChatGLM-6B进行了指令微调,提高了ChatGLM在医疗领域的问答效果。
|
||||
|
||||
## A Quick Start
|
||||
首先安装依赖包,python环境建议3.9+
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
## 模型下载
|
||||
训练好的模型参数可以通过如下方式下载:
|
||||
|
||||
| 模型名称 | 大小 | 模型下载地址 |
|
||||
| :----------------- | :------: | :-------------------------: |
|
||||
| ChatGLM-6B-Med | 约13.4GB | [[百度网盘]](https://pan.baidu.com/s/1Sfi1bRwV741GIChIEOUW0A?pwd=i73e)<br>[[GoogleDrive]](https://) 上传中 |
|
||||
|
||||
|
||||
## 交互式测试
|
||||
在安装好环境后,即可进行交互式测试:
|
||||
|
||||
```
|
||||
python infer.py
|
||||
```
|
||||
## 数据集构建
|
||||
我们采用了公开和自建的中文医学知识库,主要参考了[cMeKG](https://github.com/king-yyf/CMeKG_tools)。
|
||||
医学知识库围绕疾病、药物、检查指标等构建,字段包括并发症,高危因素,组织学检查,临床症状,药物治疗,辅助治疗等。知识库示例如下:
|
||||
|
||||
```
|
||||
{"中心词": "偏头痛", "相关疾病": ["妊娠合并偏头痛", "恶寒发热"], "相关症状": ["皮肤变硬", "头部及眼后部疼痛并能听到连续不断的隆隆声", "晨起头痛加重"], "所属科室": ["中西医结合科", "内科"], "发病部位": ["头部"]}
|
||||
```
|
||||
我们利用GPT3.5接口围绕医学知识库构建问答数据,设置了多种Prompt形式来充分利用知识。
|
||||
|
||||
指令微调的训练集数据示例如下:
|
||||
|
||||
```
|
||||
"问题:患者30岁,有胰腺假性囊肿病史,出现黄疸,怀疑胰腺假性囊肿并发了门静脉高压症,如何治疗?"
|
||||
"回答: "胰腺假性囊肿并发门静脉高压症需要密切观察病情,积极治疗原发病,进行药物干预,如利尿剂、强心剂等,注意肝功能保护和营养支持。严重病例需要考虑行肝移植。"
|
||||
```
|
||||
指令微调数据集质量仍有限,后续将进行不断迭代,同时医学知识库和数据集构建代码还在整理中,整理完成将会发布。
|
||||
|
||||
## Instruct-tuning 指令微调
|
||||
### 下载ChatGLM-6B原始参数
|
||||
请参照[https://github.com/THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B),并将下载的模型放入`model`目录下。
|
||||
|
||||
### 指令微调
|
||||
|
||||
```
|
||||
sh script/sft_medchat.sh
|
||||
```
|
||||
请根据实际情况调整batch size等超参,默认为我们使用的参数。
|
||||
|
||||
### 新数据指令微调
|
||||
|
||||
如果想用自己的数据集继续指令微调/指令精调(Instruct-tuning)ChatGLM,可以构建与`data`目录中训练集类似的文件格式进行。或参照[ChatGLM-Instruct-Tuning](https://github.com/thinksoso/ChatGLM-Instruct-Tuning/issues)仓库。
|
||||
### 计算资源需求
|
||||
我们在一张A100-SXM-80GB显卡上进行了训练,根据经验,训练显存建议选择32G及以上。
|
||||
|
||||
## 模型效果对比
|
||||
|
||||
| 测试输入 | ChatGLM-6B输出 | ChatGLM-6B-Med输出 |
|
||||
|:------------- |:---------------| :-------------|
|
||||
| 谷氨酰转肽酶水平会因吸毒或饮酒而升高吗? | 吸毒和饮酒都会对身体的健康产生负面影响,包括对肝脏和心脏的影响。<br>饮酒可能导致肝脏损伤和心脏疾病,因为酒精会破坏肝脏和心脏的功能。长期饮酒还可能导致脂肪肝、肝硬化、心脏病等健康问题。<br>吸毒会对肝脏和心脏产生严重的损伤,导致肝脏纤维化、肝硬化、心脏血管疾病等。吸毒还可能导致精神障碍、社会退缩、暴力和犯罪等问题。<br>因此,无论是饮酒还是吸毒,都会对健康产生负面影响,应该避免过度使用。 | 谷氨酰转肽酶升高的原因有很多,包括感染、消化性溃疡、感染、药物副作用等,而吸毒或饮酒等不良行为可能会导致谷氨酰转肽酶的升高。 |
|
||||
| 小李最近出现了心动过速的症状,伴有轻度胸痛。体检发现P-R间期延长,伴有T波低平和ST段异常 | 非常抱歉,我不理解您的问题。请提供更多背景信息或上下文,以便我更好地回答您的问题。谢谢! | 小李出现了心动过速的症状,并且伴有胸痛,需要考虑是否有心肌病、冠状动脉粥样硬化等心血管疾病,建议进行心电图检查、血液检查、心脏超声检查等 |
|
||||
| ...... | ...... |......|
|
||||
|
||||
## 项目参与者
|
||||
本项目由哈尔滨工业大学社会计算与信息检索研究中心健康智能组[王昊淳](https://github.com/s65b40) 、[刘驰](https://github.com/thinksoso)完成,指导教师为赵森栋副教授,秦兵教授以及刘挺教授。
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目参考了以下开源项目,在此对相关项目和研究开发人员表示感谢。
|
||||
|
||||
- ChatGLM: [https://github.com/THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)
|
||||
- ChatGLM-Instruct-Tuning: [https://github.com/thinksoso/ChatGLM-Instruct-Tuning/issues](https://github.com/thinksoso/ChatGLM-Instruct-Tuning/issues)
|
||||
- CMeKG: [https://github.com/king-yyf/CMeKG_tools](https://github.com/king-yyf/CMeKG_tools)
|
||||
|
||||
##免责声明
|
||||
本项目相关资源仅供学术研究之用,严禁用于商业用途。使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目无法对其准确性作出保证。本项目数据集绝大部分由模型生成,即使符合某些医学事实,也不能被用作实际医学诊断的依据。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。
|
||||
|
||||
|
||||
## Citation
|
||||
如果你使用了本项目的数据或者代码,请声明引用
|
||||
|
||||
```
|
||||
@misc{ChatGLM-Med,
|
||||
author={Haochun Wang, Chi Liu},
|
||||
title = {ChatGLM-Med: 基于中文医学知识的ChatGLM模型微调},
|
||||
year = {2023},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/SCIR-HI/Med-ChatGLM}},
|
||||
}
|
||||
```
|
||||
126
chat_dataset.py
普通文件
126
chat_dataset.py
普通文件
@@ -0,0 +1,126 @@
|
||||
""" PyTorch ChatGLM Dataset. """
|
||||
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"THUDM/chatglm-6b", trust_remote_code=True)
|
||||
|
||||
|
||||
def get_masks_and_position_ids(
|
||||
seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
|
||||
):
|
||||
mask_position = (
|
||||
seq_len - 2
|
||||
) # is equal to `seq.index(mask_token)` or `seq.index(150001)`
|
||||
attention_mask = torch.ones(
|
||||
(1, context_length, context_length), device=device)
|
||||
attention_mask.tril_()
|
||||
attention_mask[..., : mask_position - 1] = 1
|
||||
attention_mask = (attention_mask < 0.5).bool()
|
||||
|
||||
if position_encoding_2d:
|
||||
# is equal to `seq_length = seq.index(150004)`
|
||||
seq_length = seq_len - 1
|
||||
position_ids = torch.arange(
|
||||
context_length, dtype=torch.long, device=device)
|
||||
if not gmask:
|
||||
position_ids[seq_length:] = mask_position
|
||||
block_position_ids = torch.cat(
|
||||
(
|
||||
torch.zeros(seq_length, dtype=torch.long, device=device),
|
||||
torch.arange(
|
||||
context_length - seq_length, dtype=torch.long, device=device
|
||||
)
|
||||
+ 1,
|
||||
)
|
||||
)
|
||||
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
|
||||
else:
|
||||
position_ids = torch.arange(
|
||||
context_length, dtype=torch.long, device=device)
|
||||
if not gmask:
|
||||
position_ids[context_length - 1:] = mask_position
|
||||
return attention_mask, position_ids
|
||||
|
||||
|
||||
def chat_data_collator(features: list) -> dict:
|
||||
# 只对target的部分计算loss
|
||||
len_ids = [len(feature["input_ids"]) for feature in features]
|
||||
longest = max(len_ids) + 1
|
||||
input_ids = []
|
||||
attention_mask_list = []
|
||||
position_ids_list = []
|
||||
labels_list = []
|
||||
for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
|
||||
ids = feature["input_ids"]
|
||||
seq_len = feature["seq_len"]
|
||||
labels = (
|
||||
[-100] * (seq_len - 1)
|
||||
+ ids[(seq_len - 1):]
|
||||
+ [tokenizer.eos_token_id]
|
||||
+ [-100] * (longest - ids_l - 1)
|
||||
)
|
||||
ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
|
||||
_ids = torch.LongTensor(ids)
|
||||
attention_mask, position_ids = get_masks_and_position_ids(
|
||||
ids, seq_len, longest, _ids.device, gmask=False
|
||||
)
|
||||
labels_list.append(torch.LongTensor(labels))
|
||||
input_ids.append(_ids)
|
||||
attention_mask_list.append(attention_mask)
|
||||
position_ids_list.append(position_ids)
|
||||
input_ids = torch.stack(input_ids)
|
||||
labels = torch.stack(labels_list)
|
||||
attention_mask = torch.stack(attention_mask_list)
|
||||
position_ids = torch.stack(position_ids_list)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
|
||||
class Chat_Dataset(Dataset):
|
||||
def __init__(self, data_dir, max_seq_length) -> None:
|
||||
super().__init__()
|
||||
self.content = self.load_json(data_dir)
|
||||
self.encoded_content = self.encode(
|
||||
tokenizer, self.content, max_seq_length)
|
||||
self.features = self.encoded_content[0].keys()
|
||||
|
||||
def load_json(self, data_dir):
|
||||
content = []
|
||||
with open(data_dir, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
content.append(json.loads(line))
|
||||
return content
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.encoded_content[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoded_content)
|
||||
|
||||
def get_ori_item(self, index):
|
||||
return self.content[index]
|
||||
|
||||
def encode(self, tokenizer, content, max_seq_length):
|
||||
encoded_content = []
|
||||
for example in content:
|
||||
prompt = example["context"]
|
||||
target = example["target"]
|
||||
prompt_ids = tokenizer.encode(
|
||||
prompt, max_length=max_seq_length, truncation=True)
|
||||
target_ids = tokenizer.encode(
|
||||
target, max_length=max_seq_length, truncation=True, add_special_tokens=False
|
||||
)
|
||||
input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
|
||||
encoded_content.append(
|
||||
{"input_ids": input_ids, "seq_len": len(prompt_ids)})
|
||||
return encoded_content
|
||||
92
configuration_chatglm.py
普通文件
92
configuration_chatglm.py
普通文件
@@ -0,0 +1,92 @@
|
||||
""" ChatGLM model configuration """
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`~ChatGLMModel`].
|
||||
It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
|
||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used
|
||||
to control the model outputs. Read the documentation from [`PretrainedConfig`]
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 150528):
|
||||
Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~ChatGLMModel`] or
|
||||
[`~TFChatGLMModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
inner_hidden_size (`int`, *optional*, defaults to 16384):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model should return the last key/values attentions (not used by all models).
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from configuration_chatglm import ChatGLMConfig
|
||||
>>> from modeling_chatglm import ChatGLMModel
|
||||
|
||||
>>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
|
||||
>>> configuration = ChatGLMConfig()
|
||||
|
||||
>>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
|
||||
>>> model = ChatGLMModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
model_type = "chatglm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=150528,
|
||||
hidden_size=4096,
|
||||
num_layers=28,
|
||||
num_attention_heads=32,
|
||||
layernorm_epsilon=1e-5,
|
||||
use_cache=False,
|
||||
bos_token_id=150004,
|
||||
eos_token_id=150005,
|
||||
pad_token_id=0,
|
||||
max_sequence_length=2048,
|
||||
inner_hidden_size=16384,
|
||||
position_encoding_2d=True,
|
||||
**kwargs
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.layernorm_epsilon = layernorm_epsilon
|
||||
self.inner_hidden_size = inner_hidden_size
|
||||
self.use_cache = use_cache
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.position_encoding_2d = position_encoding_2d
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs
|
||||
)
|
||||
7622
data/train.txt
普通文件
7622
data/train.txt
普通文件
文件差异内容过多而无法显示
加载差异
13
infer.py
普通文件
13
infer.py
普通文件
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"./model", trust_remote_code=True)
|
||||
model = ChatGLMForConditionalGeneration.from_pretrained(
|
||||
"./model").half().cuda()
|
||||
while True:
|
||||
a = input("请输入您的问题:(输入q以退出)")
|
||||
if a.strip() == 'q':
|
||||
exit()
|
||||
response, history = model.chat(tokenizer, "问题:" + a.strip() + '\n答案:', max_length=256, history=[])
|
||||
print("回答:", response)
|
||||
0
model/put model here
普通文件
0
model/put model here
普通文件
1227
modeling_chatglm.py
普通文件
1227
modeling_chatglm.py
普通文件
文件差异内容过多而无法显示
加载差异
16
requirements.txt
普通文件
16
requirements.txt
普通文件
@@ -0,0 +1,16 @@
|
||||
# int8
|
||||
bitsandbytes==0.37.1
|
||||
accelerate==0.17.1
|
||||
|
||||
# chatglm
|
||||
protobuf>=3.19.5,<3.20.1
|
||||
transformers==4.27.1
|
||||
icetk
|
||||
cpm_kernels==1.0.11
|
||||
torch>=1.13.1
|
||||
evaluate
|
||||
scikit-learn
|
||||
|
||||
#
|
||||
datasets==2.10.1
|
||||
git+https://github.com/huggingface/peft.git # 最新版本 >=0.3.0.dev0
|
||||
564
run_clm.py
普通文件
564
run_clm.py
普通文件
@@ -0,0 +1,564 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=text-generation
|
||||
"""
|
||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from typing import Dict
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
is_torch_tpu_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.testing_utils import CaptureLogger
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers import TrainerCallback, TrainerControl, TrainingArguments, TrainerState
|
||||
from modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
from chat_dataset import chat_data_collator, Chat_Dataset
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
# check_min_version("4.28.0.dev0")
|
||||
|
||||
# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
def forward(self, x):
|
||||
return super().forward(x).to(torch.float32)
|
||||
|
||||
class LoggingLossCallback(TrainerCallback):
|
||||
def __init__(self, log_interval: int, log_file: str):
|
||||
super().__init__()
|
||||
self.log_interval = log_interval
|
||||
self.log_file = log_file
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs: Dict[str, float], **kwargs):
|
||||
if state.global_step % self.log_interval == 0:
|
||||
loss = logs.get("loss", None)
|
||||
if loss is not None:
|
||||
with open(self.log_file, "a") as f:
|
||||
f.write(f"Step: {state.global_step}, Loss: {loss}\n")
|
||||
return control
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||
"""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
||||
)
|
||||
config_overrides: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override some existing default config settings when a model is trained from scratch. Example: "
|
||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||
)
|
||||
},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
)
|
||||
},
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
),
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
low_cpu_mem_usage: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
||||
"set True will benefit LLM loading time and RAM consumption."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
||||
raise ValueError(
|
||||
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=256, metadata={"help": "The longest of prompt or target"}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
log_file: Optional[str] = field(default=None, metadata={"help": "log file path"})
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
||||
block_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Optional input sequence length after tokenization. "
|
||||
"The training dataset will be truncated in block of this size for training. "
|
||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
validation_split_percentage: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
keep_linebreaks: bool = field(
|
||||
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.streaming:
|
||||
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
||||
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
||||
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
||||
send_example_telemetry("run_clm", model_args, data_args)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
if training_args.should_log:
|
||||
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
lm_datasets = raw_datasets
|
||||
else:
|
||||
data_files = {}
|
||||
dataset_args = {}
|
||||
lm_datasets = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
train_dataset = Chat_Dataset(data_files["train"],data_args.max_seq_length)
|
||||
lm_datasets["train"] = train_dataset
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
valid_dataset = Chat_Dataset(data_files["validation"],data_args.max_seq_length)
|
||||
lm_datasets["validation"] = valid_dataset
|
||||
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, trust_remote_code=True, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path,trust_remote_code=True, **config_kwargs)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
if model_args.config_overrides is not None:
|
||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||
config.update_from_string(model_args.config_overrides)
|
||||
logger.info(f"New config: {config}")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"use_fast": model_args.use_fast_tokenizer,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name,trust_remote_code=True,**tokenizer_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True,**tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
if model_args.model_name_or_path:
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model = ChatGLMForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
# model = ChatGLMForConditionalGeneration.from_pretrained(
|
||||
# "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
|
||||
# )
|
||||
else:
|
||||
model = ChatGLMForConditionalGeneration.from_config(config)
|
||||
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
|
||||
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
model.is_parallelizable = False
|
||||
model.model_parallel = False
|
||||
# model.lm_head = CastOutputToFloat(model.lm_head)
|
||||
model.config.use_cache = (
|
||||
False # silence the warnings. Please re-enable for inference!
|
||||
)
|
||||
|
||||
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
||||
# on a small vocab and want a smaller embedding size, remove this test.
|
||||
embedding_size = model.get_input_embeddings().weight.shape[0]
|
||||
if len(tokenizer) > embedding_size:
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
||||
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
||||
|
||||
if data_args.max_seq_length is None:
|
||||
max_seq_length = tokenizer.model_max_length/2
|
||||
else:
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
||||
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
||||
)
|
||||
block_size = min(data_args.max_seq_length, tokenizer.model_max_length/2)
|
||||
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in lm_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = lm_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in lm_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = lm_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
if isinstance(logits, tuple):
|
||||
# Depending on the model and config, logits may contain extra tensors,
|
||||
# like past_key_values, but logits always come first
|
||||
logits = logits[0]
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
metric = evaluate.load("accuracy")
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
preds, labels = eval_preds
|
||||
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
||||
# by preprocess_logits_for_metrics but we need to shift the labels
|
||||
labels = labels[:, 1:].reshape(-1)
|
||||
preds = preds[:, :-1].reshape(-1)
|
||||
return metric.compute(predictions=preds, references=labels)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it.
|
||||
data_collator=chat_data_collator,
|
||||
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
||||
if training_args.do_eval and not is_torch_tpu_available()
|
||||
else None,
|
||||
callbacks=[LoggingLossCallback(log_interval=10, log_file=data_args.log_file)],
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
metrics = train_result.metrics
|
||||
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
|
||||
if data_args.dataset_name is not None:
|
||||
kwargs["dataset_tags"] = data_args.dataset_name
|
||||
if data_args.dataset_config_name is not None:
|
||||
kwargs["dataset_args"] = data_args.dataset_config_name
|
||||
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
||||
else:
|
||||
kwargs["dataset"] = data_args.dataset_name
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
else:
|
||||
trainer.create_model_card(**kwargs)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
28
scripts/sft_medchat.sh
普通文件
28
scripts/sft_medchat.sh
普通文件
@@ -0,0 +1,28 @@
|
||||
wandb online
|
||||
exp_tag="chatglm_tuning"
|
||||
|
||||
python run_clm.py \
|
||||
--model_name_or_path MODEL_PATH\
|
||||
--per_device_train_batch_size 8 \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--train_file ./data/train.txt \
|
||||
--max_seq_length 256 \
|
||||
--output_dir ./output/ \
|
||||
--do_train \
|
||||
--logging_steps 30 \
|
||||
--log_file ./log/$exp_tag \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--learning_rate 5e-5 \
|
||||
--group_by_length False \
|
||||
--num_train_epochs 3 \
|
||||
--lr_scheduler_type linear \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_dir ./log \
|
||||
--logging_steps 10 \
|
||||
--save_strategy epoch \
|
||||
--seed 2023 \
|
||||
--remove_unused_columns False \
|
||||
--torch_dtype auto \
|
||||
--adam_epsilon 1e-3 \
|
||||
--report_to wandb \
|
||||
--run_name $exp_tag
|
||||
345
tokenization_chatglm.py
普通文件
345
tokenization_chatglm.py
普通文件
@@ -0,0 +1,345 @@
|
||||
"""Tokenization classes for ChatGLM."""
|
||||
import sys
|
||||
import unicodedata
|
||||
from typing import List, Optional, Union
|
||||
from functools import lru_cache
|
||||
import os
|
||||
import collections
|
||||
import re
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from icetk.text_tokenizer import TextTokenizer
|
||||
from icetk.utils import auto_create
|
||||
import icetk.sentencepiece_model_pb2 as sp_model
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"THUDM/chatglm-6b": 2048,
|
||||
}
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
max_blank_length=80,
|
||||
byte_fallback=True,
|
||||
):
|
||||
assert vocab_file is not None
|
||||
self.vocab_file = vocab_file
|
||||
self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
|
||||
self.max_blank_length = max_blank_length
|
||||
self.byte_fallback = byte_fallback
|
||||
self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False)
|
||||
self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True)
|
||||
|
||||
@staticmethod
|
||||
def _configure_tokenizer(
|
||||
text_tokenizer: TextTokenizer,
|
||||
special_tokens: List[str],
|
||||
max_blank_length: int,
|
||||
byte_fallback: bool,
|
||||
encode_special_tokens=False,
|
||||
):
|
||||
# special token
|
||||
special_token_type = 4 if encode_special_tokens else 3 # 3 - CONTROL, 4 - USER_DEFINE
|
||||
for token in special_tokens:
|
||||
text_tokenizer.proto.pieces.append(
|
||||
sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type)
|
||||
)
|
||||
# whitespaces
|
||||
for token in [SPTokenizer.get_tab_token()] + [
|
||||
SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1)
|
||||
]:
|
||||
text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4))
|
||||
# byte fallback
|
||||
if byte_fallback:
|
||||
text_tokenizer.proto.trainer_spec.byte_fallback = True
|
||||
for i in range(256):
|
||||
text_tokenizer.proto.pieces.append(
|
||||
sp_model.ModelProto.SentencePiece(piece="<0x{:02X}>".format(i), score=0.0, type=6)
|
||||
)
|
||||
text_tokenizer.refresh()
|
||||
|
||||
def _build_text_tokenizer(self, encode_special_tokens=False):
|
||||
tokenizer = TextTokenizer(self.vocab_file)
|
||||
self._configure_tokenizer(
|
||||
tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
def _get_text_tokenizer(self, encode_special_tokens=False):
|
||||
if encode_special_tokens:
|
||||
return self.special_text_tokenizer
|
||||
else:
|
||||
return self.text_tokenizer
|
||||
|
||||
@staticmethod
|
||||
def get_blank_token(length: int):
|
||||
assert length >= 2
|
||||
return f"<|blank_{length}|>"
|
||||
|
||||
@staticmethod
|
||||
def get_tab_token():
|
||||
return f"<|tab|>"
|
||||
|
||||
@property
|
||||
def num_image_tokens(self):
|
||||
return 20000
|
||||
|
||||
@property
|
||||
def num_text_tokens(self):
|
||||
return self.text_tokenizer.num_tokens
|
||||
|
||||
@property
|
||||
def num_tokens(self):
|
||||
return self.num_image_tokens + self.num_text_tokens
|
||||
|
||||
@staticmethod
|
||||
def _encode_whitespaces(text: str, max_len: int = 80):
|
||||
text = text.replace("\t", SPTokenizer.get_tab_token())
|
||||
for i in range(max_len, 1, -1):
|
||||
text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
|
||||
return text
|
||||
|
||||
def _preprocess(self, text: str, linebreak=True, whitespaces=True):
|
||||
if linebreak:
|
||||
text = text.replace("\n", "<n>")
|
||||
if whitespaces:
|
||||
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
|
||||
return text
|
||||
|
||||
def encode(
|
||||
self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
|
||||
) -> List[int]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
||||
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
||||
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
||||
"""
|
||||
text = self._preprocess(text, linebreak, whitespaces)
|
||||
if not add_dummy_prefix:
|
||||
text = "<n>" + text
|
||||
tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text)
|
||||
tokens = [x + self.num_image_tokens for x in tmp]
|
||||
return tokens if add_dummy_prefix else tokens[2:]
|
||||
|
||||
def decode(self, text_ids: List[int], special_tokens=False) -> str:
|
||||
ids = [int(_id) - self.num_image_tokens for _id in text_ids]
|
||||
text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids)
|
||||
text = text.replace("<n>", "\n")
|
||||
text = text.replace(SPTokenizer.get_tab_token(), "\t")
|
||||
for i in range(2, self.max_blank_length + 1):
|
||||
text = text.replace(self.get_blank_token(i), " " * i)
|
||||
return text
|
||||
|
||||
def tokenize(
|
||||
self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True
|
||||
) -> List[str]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
||||
@param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
||||
@param add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
||||
"""
|
||||
text = self._preprocess(text, linebreak, whitespaces)
|
||||
if not add_dummy_prefix:
|
||||
text = "<n>" + text
|
||||
tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text)
|
||||
return tokens if add_dummy_prefix else tokens[2:]
|
||||
|
||||
def __getitem__(self, x: Union[int, str]):
|
||||
if isinstance(x, int):
|
||||
if x < self.num_image_tokens:
|
||||
return "<image_{}>".format(x)
|
||||
else:
|
||||
return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
|
||||
elif isinstance(x, str):
|
||||
if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
|
||||
return int(x[7:-1])
|
||||
else:
|
||||
return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
|
||||
else:
|
||||
raise ValueError("The key should be str or int.")
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "ice_text.model"}
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
remove_space=False,
|
||||
bos_token='sop',
|
||||
eos_token='eos',
|
||||
eop_token='eop',
|
||||
mask_token='[MASK]',
|
||||
gmask_token='[gMASK]',
|
||||
padding_side="left",
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
remove_space=remove_space,
|
||||
padding_side=padding_side,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
self.remove_space = remove_space
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
self.bos_token = bos_token
|
||||
self.eos_token = eos_token
|
||||
self.eop_token = eop_token
|
||||
self.mask_token = mask_token
|
||||
self.gMASK_token = gmask_token
|
||||
|
||||
self.sp_tokenizer = SPTokenizer(vocab_file)
|
||||
|
||||
""" Initialisation """
|
||||
|
||||
@property
|
||||
def eop_token_id(self) -> Optional[int]:
|
||||
"""
|
||||
`Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
|
||||
set.
|
||||
"""
|
||||
if self.eop_token is None:
|
||||
return None
|
||||
return self.convert_tokens_to_ids(self.eop_token)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Returns vocab size """
|
||||
return self.sp_tokenizer.num_tokens
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def preprocess_text(self, inputs):
|
||||
if self.remove_space:
|
||||
outputs = " ".join(inputs.strip().split())
|
||||
else:
|
||||
outputs = inputs
|
||||
|
||||
if self.do_lower_case:
|
||||
outputs = outputs.lower()
|
||||
|
||||
return outputs
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
""" Returns a tokenized string. """
|
||||
text = self.preprocess_text(text)
|
||||
|
||||
seq = self.sp_tokenizer.tokenize(text)
|
||||
|
||||
return seq
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[List[int], List[List[int]]],
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if isinstance(token_ids[0], list):
|
||||
tokens = []
|
||||
for single_token_ids in token_ids:
|
||||
if self.pad_token_id in single_token_ids: # remove pad
|
||||
single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
|
||||
tokens.append(self.sp_tokenizer.decode(single_token_ids))
|
||||
return (tokens)
|
||||
else:
|
||||
if self.pad_token_id in token_ids: # remove pad
|
||||
token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
|
||||
return self.sp_tokenizer.decode(token_ids)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.sp_tokenizer[token]
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.sp_tokenizer[index]
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the named of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: `[CLS] X [SEP]`
|
||||
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 += token_ids_1
|
||||
mask_ids = self.sp_tokenizer[self.mask_token]
|
||||
gmask_ids = self.sp_tokenizer[self.gMASK_token]
|
||||
if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0:
|
||||
token_ids_0 += [gmask_ids]
|
||||
|
||||
if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids:
|
||||
token_ids_0 += [self.sp_tokenizer[self.eos_token]]
|
||||
|
||||
token_ids_0 += [self.sp_tokenizer[self.bos_token]]
|
||||
|
||||
return token_ids_0
|
||||
在新工单中引用
屏蔽一个用户