镜像自地址
https://github.com/xming521/CTAI.git
已同步 2025-12-06 06:36:49 +00:00
122 行
3.8 KiB
Python
122 行
3.8 KiB
Python
import sys
|
|
|
|
sys.path.append("..")
|
|
import torch
|
|
from torch.nn import init
|
|
from torch.utils.data import DataLoader
|
|
from data_set import make
|
|
from net import unet
|
|
from utils import dice_loss
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
|
torch.set_num_threads(1)
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
torch.cuda.empty_cache()
|
|
res = {'epoch': [], 'loss': [], 'dice': []}
|
|
|
|
|
|
def weights_init(m):
|
|
classname = m.__class__.__name__
|
|
# print(classname)
|
|
if classname.find('Conv3d') != -1:
|
|
init.xavier_normal(m.weight.data, 0.0)
|
|
init.constant_(m.bias.data, 0.0)
|
|
elif classname.find('Linear') != -1:
|
|
init.xavier_normal(m.weight.data, 0.0)
|
|
init.constant_(m.bias.data, 0.0)
|
|
|
|
|
|
# 参数
|
|
rate = 0.50
|
|
learn_rate = 0.001
|
|
epochs = 1
|
|
# train_dataset_path = '../data/all/d1/'
|
|
train_dataset_path = 'E:/projects/python projects/ct_data/'
|
|
|
|
train_dataset, test_dataset = make.get_d1(train_dataset_path)
|
|
unet = unet.Unet(1, 1).to(device).apply(weights_init)
|
|
criterion = torch.nn.BCELoss().to(device)
|
|
optimizer = torch.optim.Adam(unet.parameters(), learn_rate)
|
|
|
|
|
|
def train():
|
|
global res
|
|
dataloaders = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
|
|
for epoch in range(epochs):
|
|
dt_size = len(dataloaders.dataset)
|
|
epoch_loss, epoch_dice = 0, 0
|
|
step = 0
|
|
for x, y in dataloaders:
|
|
id = x[1:]
|
|
step += 1
|
|
x = x[0].to(device)
|
|
y = y[1].to(device)
|
|
print(x.size())
|
|
print(y.size())
|
|
optimizer.zero_grad()
|
|
outputs = unet(x)
|
|
loss = criterion(outputs, y)
|
|
loss.backward()
|
|
optimizer.step()
|
|
# dice
|
|
# a = outputs.cpu().detach().squeeze(1).numpy()
|
|
# a[a >= rate] = 1
|
|
# a[a < rate] = 0
|
|
# b = y.cpu().detach().numpy()
|
|
# dice = dice_loss.dice(a, b)
|
|
# epoch_loss += float(loss.item())
|
|
# epoch_dice += dice
|
|
|
|
if step % 100 == 0:
|
|
res['epoch'].append((epoch + 1) * step)
|
|
res['loss'].append(loss.item())
|
|
print("epoch%d step%d/%d train_loss:%0.3f" % (
|
|
epoch, step, (dt_size - 1) // dataloaders.batch_size + 1, loss.item()),
|
|
end='')
|
|
test()
|
|
# print("epoch %d loss:%0.3f,dice %f" % (epoch, epoch_loss / step, epoch_dice / step))
|
|
plt.plot(res['epoch'], np.squeeze(res['cost']), label='Train cost')
|
|
plt.ylabel('cost')
|
|
plt.xlabel('epochs')
|
|
plt.title("Model: train cost")
|
|
plt.legend()
|
|
|
|
plt.plot(res['epoch'], np.squeeze(res), label='Validation cost', color='#FF9966')
|
|
plt.ylabel('loss')
|
|
plt.xlabel('epochs')
|
|
plt.title("Model:validation loss")
|
|
plt.legend()
|
|
|
|
plt.savefig("examples.jpg")
|
|
|
|
# torch.save(unet, 'unet.pkl')
|
|
# model = torch.load('unet.pkl')
|
|
test()
|
|
|
|
|
|
def test():
|
|
global res, img_y, mask_arrary
|
|
epoch_dice = 0
|
|
with torch.no_grad():
|
|
dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)
|
|
for x, mask in dataloaders:
|
|
id = x[1:] # ('1026',), ('10018',)]先病人号后片号
|
|
x = x[0].to(device)
|
|
y = unet(x)
|
|
mask_arrary = mask[1].cpu().squeeze(0).detach().numpy()
|
|
img_y = torch.squeeze(y).cpu().numpy()
|
|
img_y[img_y >= rate] = 1
|
|
img_y[img_y < rate] = 0
|
|
img_y = img_y * 255
|
|
epoch_dice += dice_loss.dice(img_y, mask_arrary)
|
|
# cv.imwrite(f'data/out/{mask[0][0]}-result.png', img_y, (cv.IMWRITE_PNG_COMPRESSION, 0))
|
|
print('test dice %f' % (epoch_dice / len(dataloaders)))
|
|
res['dice'].append(epoch_dice / len(dataloaders))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
train()
|
|
test()
|