镜像自地址
https://github.com/xming521/CTAI.git
已同步 2025-12-06 06:36:49 +00:00
179 行
5.6 KiB
Python
179 行
5.6 KiB
Python
import os
|
|
|
|
import SimpleITK as sitk
|
|
import cv2 as cv
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils import data
|
|
from torch.utils.data import random_split
|
|
|
|
# Root directory of the training data. This is an absolute Windows drive path,
# so it has to be adjusted per machine. The trailing '/' is intentional: the
# loaders in this file append entry names directly to the directory string.
train_data_path = 'E:/projects/python projects/ct_data/'
|
|
# Directory of the test data, given relative to the current working directory.
# Written with a trailing '/' like the other directory paths in this file.
test_data_path = '../data/test/'
|
|
|
|
|
|
def get_person_files(data_path):
    """Collect the arterial-phase DICOM and mask files of every patient.

    Every sub-directory of ``data_path`` is treated as one patient and must
    contain an ``arterial phase`` folder. Entries of ``data_path`` that are
    not directories are ignored.

    Args:
        data_path: Root directory holding one folder per patient. A trailing
            ``/`` is optional.

    Returns:
        A list with one ``[person_id, image_list, mask_list]`` entry per
        patient, ordered by patient folder name. ``image_list`` holds the
        paths of the files whose name contains ``.dcm`` and ``mask_list``
        those whose name contains ``_mask``. All paths are joined with ``/``.

    Raises:
        FileNotFoundError: If ``data_path`` or a patient's ``arterial phase``
            folder does not exist.
    """
    # Paths are deliberately joined with '/' (not os.path.join): other loaders
    # in this module split the resulting strings on '/'.
    root = data_path
    if root and not root.endswith('/'):
        root += '/'

    persons = []
    # os.listdir() returns entries in arbitrary order; sort for reproducibility.
    for person_id in sorted(os.listdir(root)):
        person_dir = root + person_id
        if not os.path.isdir(person_dir):
            # A stray file next to the patient folders is not a patient.
            continue

        phase_dir = person_dir + '/arterial phase/'
        image_list, mask_list = [], []
        for name in sorted(os.listdir(phase_dir)):
            # Test the file name only, so a directory whose name happens to
            # contain '.dcm' or '_mask' cannot produce false matches.
            if '.dcm' in name:
                image_list.append(phase_dir + name)
            if '_mask' in name:
                mask_list.append(phase_dir + name)

        persons.append([person_id, image_list, mask_list])

    return persons
|
|
|
|
|
|
def get_train_files(data_path, all, get_dice=False):  # noqa: A002 - `all` is part of the public signature
    """List the DICOM slices (and optionally the mask files) under ``data_path``.

    Args:
        data_path: Root directory to scan. A trailing ``/`` is optional.
        all: When truthy, every sub-directory of ``data_path`` is treated as a
            patient folder and the files inside its ``arterial phase`` folder
            are scanned (non-directory entries are skipped). When falsy, the
            entries of ``data_path`` themselves are scanned.
        get_dice: Selects the return shape, see below.

    Returns:
        ``image_list`` when ``get_dice`` is false, otherwise the tuple
        ``(image_list, mask_list, id_list)``.

        * ``image_list`` contains one ``(path, owner, stem)`` tuple per file
          whose name contains ``.dcm``. ``owner`` is the third-from-last
          ``/``-separated component of ``path`` (the patient folder when
          ``all`` is truthy) and ``stem`` is the file name with ``.dcm``
          removed.
        * ``mask_list`` contains the paths of files whose name contains
          ``_mask``.
        * ``id_list`` is always empty; nothing in this function fills it. It
          is returned only to keep the three-element result shape.

    Raises:
        FileNotFoundError: If ``data_path`` or an expected ``arterial phase``
            folder does not exist.
        IndexError: If a DICOM path has fewer than three ``/``-separated
            components.
    """
    # Paths are deliberately joined with '/' because they are split on '/'
    # again further down.
    root = data_path
    if root and not root.endswith('/'):
        root += '/'

    # os.listdir() returns entries in arbitrary order; sort for reproducibility.
    entries = sorted(os.listdir(root))

    if all:
        # Scan every patient's arterial-phase folder.
        filename_list = []
        for entry in entries:
            if not os.path.isdir(root + entry):
                continue
            phase_dir = root + entry + '/arterial phase/'
            filename_list.extend(
                phase_dir + name for name in sorted(os.listdir(phase_dir))
            )
    else:
        filename_list = [root + entry for entry in entries]

    image_list, mask_list = [], []
    for path in filename_list:
        parts = path.split('/')
        name = parts[-1]
        # Classify by file name only, so directory names containing '.dcm'
        # or '_mask' cannot cause false matches.
        if '.dcm' in name:
            image_list.append((path, parts[-3], name.replace('.dcm', '')))
        if '_mask' in name:
            mask_list.append(path)

    if get_dice:
        id_list = []
        return image_list, mask_list, id_list
    return image_list
|
|
|
|
|
|
def data_in_one(inputdata):
    """Min-max scale an array so that its values span the range [0, 1].

    Args:
        inputdata: NumPy array to normalise.

    Returns:
        ``inputdata`` itself (same object, unchanged) when every element is
        zero. Otherwise a new floating-point array in which the smallest
        input value maps to 0 and the largest to 1. A constant non-zero
        input has no value range to stretch and yields an all-zero float
        array of the same shape.
    """
    if not inputdata.any():
        return inputdata

    low = inputdata.min()
    span = inputdata.max() - low
    if span == 0:
        # Constant input: the plain formula would evaluate 0 / 0 and fill the
        # result with NaN.
        return np.zeros_like(inputdata, dtype=float)
    return (inputdata - low) / span
|
|
|
|
|
|
def get_dataset(data_path, have):
    """Load every arterial-phase slice under ``data_path`` with its mask.

    For each DICOM file returned by ``get_train_files(data_path, all=True)``
    the sibling ``*_mask.png`` is read as a grayscale image, and the region
    rows 270-429 / columns 200-299 of the first image plane is min-max
    normalised and pasted into an otherwise all-zero array of the image's
    shape.

    Side effect: the module-level name ``test_image`` is set to the last
    processed image array.

    Args:
        data_path: Root directory holding one folder per patient.
        have: When truthy, slices whose mask is entirely zero are skipped.

    Returns:
        ``(image_data, mask_data)``, two index-aligned lists.
        ``image_data`` items are ``(image_tensor, owner, stem)`` where
        ``owner`` and ``stem`` are the second and third fields produced by
        ``get_train_files``. ``mask_data`` items are
        ``(dicom_file_name, mask_tensor)``. Both tensors are float32.

    Raises:
        FileNotFoundError: If the mask image of a slice cannot be read.
        ValueError: If the cropped region is not exactly 160x100, i.e. the
            image plane is smaller than 430x300.
    """
    global test_image

    # Fixed region of interest inside each image plane.
    roi_rows, roi_cols = slice(270, 430), slice(200, 300)

    image_data, mask_data = [], []
    for dcm_path, owner, stem in get_train_files(data_path, all=True):
        image_array = sitk.GetArrayFromImage(sitk.ReadImage(dcm_path))

        mask_path = dcm_path.replace('.dcm', '_mask.png')
        mask_array = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if mask_array is None:
            # cv.imread signals a missing/unreadable file by returning None
            # instead of raising; fail here with a message naming the file.
            raise FileNotFoundError('mask image could not be read: ' + mask_path)

        if have and not mask_array.any():
            continue

        mask_tensor = torch.from_numpy(data_in_one(mask_array)).float()
        # dcm_path is the DICOM path, so this replace() normally changes
        # nothing; it is kept so the emitted name stays exactly as before.
        file_name = dcm_path.split('/')[-1].replace('_mask.png', '')
        mask_data.append((file_name, mask_tensor))

        # Normalise only the region of interest and blank everything else.
        roi = np.zeros(shape=(1, 160, 100))
        roi[0] = image_array[0][roi_rows, roi_cols]
        roi = data_in_one(roi)
        roi_image = np.zeros(shape=image_array.shape)
        roi_image[0][roi_rows, roi_cols] = roi[0]

        test_image = roi_image
        image_data.append((torch.from_numpy(roi_image).float(), owner, stem))

    return image_data, mask_data
|
|
|
|
|
|
def get_onlytest(data_path, have):
    """Load every arterial-phase slice under ``data_path`` without masks.

    For each DICOM file returned by ``get_train_files(data_path, all=True)``
    the region rows 270-429 / columns 200-299 of the first image plane is
    min-max normalised and pasted into an otherwise all-zero array of the
    image's shape.

    Side effect: the module-level name ``test_image`` is set to the last
    processed image array.

    Args:
        data_path: Root directory holding one folder per patient.
        have: Accepted for signature symmetry with ``get_dataset``; it is
            not used here.

    Returns:
        A list of ``(image_tensor, owner, stem)`` tuples, where ``owner`` and
        ``stem`` are the second and third fields produced by
        ``get_train_files`` and ``image_tensor`` is float32.
    """
    global test_image, test_mask

    image_data = []
    for record in get_train_files(data_path, all=True):
        volume = sitk.GetArrayFromImage(sitk.ReadImage(record[0]))

        # Normalise only the fixed region of interest.
        crop = np.zeros(shape=(1, 160, 100))
        crop[0] = volume[0][270:430, 200:300]
        crop = data_in_one(crop)

        # Everything outside the region of interest stays zero.
        canvas = np.zeros(shape=volume.shape)
        canvas[0][270:430, 200:300] = crop[0]

        test_image = canvas
        image_data.append((torch.from_numpy(canvas).float(), record[1], record[2]))

    return image_data
|
|
|
|
|
|
class Dataset(data.Dataset):
    """In-memory dataset of (image, mask) pairs loaded by ``get_dataset``.

    ``self.imgs`` stores the ``(image_data, mask_data)`` pair returned by
    ``get_dataset``; both lists are indexed in parallel.
    """

    def __init__(self, path, have=True, transform=None):
        """Load all slices under ``path``.

        Args:
            path: Root data directory, forwarded to ``get_dataset``.
            have: Forwarded to ``get_dataset``.
            transform: Accepted but not applied.
        """
        self.imgs = get_dataset(data_path=path, have=have)

    def __getitem__(self, index):
        """Return the ``(image_entry, mask_entry)`` pair at ``index``."""
        return self.imgs[0][index], self.imgs[1][index]

    def __len__(self):
        """Return the number of image entries."""
        image_entries = self.imgs[0]
        return len(image_entries)
|
|
|
|
|
|
class testDataset(data.Dataset):
    """In-memory dataset of image entries loaded by ``get_onlytest``.

    ``self.imgs`` stores the list returned by ``get_onlytest``.
    """

    def __init__(self, path, have=True, transform=None):
        """Load all slices under ``path``.

        Args:
            path: Root data directory, forwarded to ``get_onlytest``.
            have: Forwarded to ``get_onlytest``.
            transform: Accepted but not applied.
        """
        self.imgs = get_onlytest(data_path=path, have=have)

    def __getitem__(self, index):
        """Return the image entry stored at ``index``."""
        return self.imgs[index]

    def __len__(self):
        """Return the number of image entries."""
        entries = self.imgs
        return len(entries)
|
|
|
|
|
|
def get_d1(path):
    """Build the dataset under ``path`` and split it randomly 90 % / 10 %.

    The dataset is created with ``have=True``. The first part receives
    ``int(0.9 * len(dataset))`` items and the second part the remainder.

    Args:
        path: Root data directory, forwarded to ``Dataset``.

    Returns:
        ``(train_dataset, test_dataset)`` as produced by ``random_split``.
    """
    full_set = Dataset(path, have=True)
    n_train = int(0.9 * len(full_set))
    lengths = [n_train, len(full_set) - n_train]
    train_dataset, test_dataset = random_split(full_set, lengths)
    return train_dataset, test_dataset
|
|
|
|
|
|
def get_d1_local(path):
    """Return a ``testDataset`` over every slice under ``path``.

    The dataset is created with ``have=False`` and returned whole; no
    train/test split is performed.

    Args:
        path: Root data directory, forwarded to ``testDataset``.
    """
    return testDataset(path, have=False)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: load the local test set.
    # get_d1_local() has a required ``path`` parameter; calling it without an
    # argument raised TypeError, so pass the module's test data directory.
    bag = get_d1_local(test_data_path)
|