Mirrored from
https://github.com/SCIR-HI/Med-ChatGLM.git
Synced 2025-12-06 06:36:50 +00:00
127 lines · 4.3 KiB · Python
""" PyTorch ChatGLM Dataset. """
|
|
|
|
import json
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from transformers import AutoTokenizer
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
"THUDM/chatglm-6b", trust_remote_code=True)
|
|
|
|
|
|
def get_masks_and_position_ids(
|
|
seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
|
|
):
    # Index of the mask token: the prompt ends with [..., mask, <bos>].
    mask_position = (
        seq_len - 2
    )  # is equal to `seq.index(mask_token)` or `seq.index(150001)`
    # Causal (lower-triangular) mask; the prefix before the mask token stays
    # fully visible. After the comparison, True marks positions that must
    # NOT be attended to.
    attention_mask = torch.ones(
        (1, context_length, context_length), device=device)
    attention_mask.tril_()
    attention_mask[..., : mask_position - 1] = 1
    attention_mask = (attention_mask < 0.5).bool()

    if position_encoding_2d:
        # ChatGLM's 2D positional encoding: one row of global positions and
        # one row of positions within the generated block.
        # is equal to `seq_length = seq.index(150004)`
        seq_length = seq_len - 1
        position_ids = torch.arange(
            context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[seq_length:] = mask_position
        block_position_ids = torch.cat(
            (
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(
                    context_length - seq_length, dtype=torch.long, device=device
                )
                + 1,
            )
        )
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(
            context_length, dtype=torch.long, device=device)
        if not gmask:
            position_ids[context_length - 1:] = mask_position
    return attention_mask, position_ids
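

# --- Illustrative sketch (an addition for this write-up, not in the upstream
# file). Shows the shapes get_masks_and_position_ids returns for a single toy
# sequence; the token ids below are made up, following ChatGLM's convention
# that the prompt ends with [gMASK (150001), <bos> (150004)].
def _demo_masks_and_position_ids():
    ids = [10, 11, 150001, 150004, 20, 21, 22, 23]  # hypothetical ids, prompt length 4
    attention_mask, position_ids = get_masks_and_position_ids(
        ids, seq_len=4, context_length=len(ids), device="cpu", gmask=False
    )
    print(attention_mask.shape)  # torch.Size([1, 8, 8]); True = masked out
    print(position_ids.shape)    # torch.Size([2, 8]); rows: global ids, block ids
    print(position_ids)
    # tensor([[0, 1, 2, 2, 2, 2, 2, 2],
    #         [0, 0, 0, 1, 2, 3, 4, 5]])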


def chat_data_collator(features: list) -> dict:
    # Only compute the loss on the target part of each sequence
    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids) + 1
    input_ids = []
    attention_mask_list = []
    position_ids_list = []
    labels_list = []
    # Sort by length, longest first, then pad everything to `longest`.
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        # Mask the prompt (everything before <bos>) with -100 so it is
        # ignored by the loss; keep <bos> + target + EOS; pad the tail
        # with -100 as well.
        labels = (
            [-100] * (seq_len - 1)
            + ids[(seq_len - 1):]
            + [tokenizer.eos_token_id]
            + [-100] * (longest - ids_l - 1)
        )
        # Right-pad input_ids with EOS up to the longest sequence in the batch.
        ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
        _ids = torch.LongTensor(ids)
        attention_mask, position_ids = get_masks_and_position_ids(
            ids, seq_len, longest, _ids.device, gmask=False
        )
        labels_list.append(torch.LongTensor(labels))
        input_ids.append(_ids)
        attention_mask_list.append(attention_mask)
        position_ids_list.append(position_ids)
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    attention_mask = torch.stack(attention_mask_list)
    position_ids = torch.stack(position_ids_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
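

# --- Illustrative sketch (an addition, not in the upstream file). Feeds two
# hand-built, variable-length features through the collator and prints the
# padded batch shapes; all token ids are made up.
def _demo_collator():
    feats = [
        {"input_ids": [10, 11, 150001, 150004, 20, 21], "seq_len": 4},
        {"input_ids": [12, 150001, 150004, 30], "seq_len": 3},
    ]
    batch = chat_data_collator(feats)
    print(batch["input_ids"].shape)       # torch.Size([2, 7])   longest + 1
    print(batch["labels"].shape)          # torch.Size([2, 7])
    print(batch["attention_mask"].shape)  # torch.Size([2, 1, 7, 7])
    print(batch["position_ids"].shape)    # torch.Size([2, 2, 7])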


class Chat_Dataset(Dataset):
    def __init__(self, data_dir, max_seq_length) -> None:
        super().__init__()
        self.content = self.load_json(data_dir)
        self.encoded_content = self.encode(
            tokenizer, self.content, max_seq_length)
        self.features = self.encoded_content[0].keys()

    def load_json(self, data_dir):
        # Reads a JSON-lines file: one {"context": ..., "target": ...}
        # object per line.
        content = []
        with open(data_dir, "r") as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                content.append(json.loads(line))
        return content

    def __getitem__(self, index):
        return self.encoded_content[index]

    def __len__(self):
        return len(self.encoded_content)

    def get_ori_item(self, index):
        # Returns the raw (untokenized) example.
        return self.content[index]

    def encode(self, tokenizer, content, max_seq_length):
        encoded_content = []
        for example in content:
            prompt = example["context"]
            target = example["target"]
            prompt_ids = tokenizer.encode(
                prompt, max_length=max_seq_length, truncation=True)
            target_ids = tokenizer.encode(
                target, max_length=max_seq_length, truncation=True, add_special_tokens=False
            )
            input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
            # seq_len marks the prompt/target boundary used for loss masking.
            encoded_content.append(
                {"input_ids": input_ids, "seq_len": len(prompt_ids)})
        return encoded_content
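

# --- Usage sketch (an addition for illustration, not in the upstream repo).
# Wires the dataset and collator into a DataLoader the way a training script
# would; "data/train.jsonl" is a hypothetical path to a JSON-lines file whose
# rows look like {"context": "...", "target": "..."}.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Chat_Dataset("data/train.jsonl", max_seq_length=512)
    loader = DataLoader(dataset, batch_size=4, collate_fn=chat_data_collator)
    batch = next(iter(loader))
    for name, tensor in batch.items():
        print(name, tuple(tensor.shape))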