Mirrored from https://github.com/xming521/CTAI.git
Synced 2025-12-06 06:36:49 +00:00
upload train code
0  CTAI_model/net/__init__.py  (normal file)
53  CTAI_model/net/test.py  (normal file)
@@ -0,0 +1,53 @@
import os
import sys

import cv2

sys.path.append("..")
import torch
from torch.utils.data import DataLoader

from data_set import make

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_num_threads(4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
res = {'epoch': [], 'loss': [], 'dice': []}

test_data_path = '../data/all/d2/'
rate = 0.5

test_dataset = make.get_d1_local(test_data_path)


def mkdir(path):
    # Create the directory if it does not exist yet; makedirs also creates
    # any missing parent directories along the path.
    os.makedirs(path, exist_ok=True)


def onlytest():
    unet = torch.load('../result/0.5unet.pkl').to(device)
    unet.eval()  # put BatchNorm layers into inference mode
    global res, img_y
    with torch.no_grad():
        dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
        for x in dataloaders:
            id = x[1:]  # e.g. [('1026',), ('10018',)]: patient ID first, then slice number
            print(id, 'id')
            x = x[0].to(device)
            y = unet(x)
            img_y = torch.squeeze(y).cpu().numpy()
            img_y[img_y >= rate] = 1
            img_y[img_y < rate] = 0
            img_y = img_y * 255
            mkdir(f'../data/out/{id[0][0]}/arterial phase/')
            cv2.imwrite(f'../data/out/{id[0][0]}/arterial phase/{id[1][0]}_mask.png', img_y,
                        (cv2.IMWRITE_PNG_COMPRESSION, 0))


if __name__ == '__main__':
    # train()
    onlytest()
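The binarization step above turns the network's sigmoid probabilities into a hard 0/255 mask before writing it out as a PNG. A minimal, self-contained illustration of that thresholding (the probability values here are made up for the example):

import numpy as np

prob = np.array([[0.12, 0.73],
                 [0.49, 0.51]], dtype=np.float32)  # stand-in sigmoid outputs
rate = 0.5
mask = np.where(prob >= rate, 255, 0).astype(np.uint8)  # 255 = foreground, 0 = background
print(mask)  # [[  0 255]
             #  [  0 255]]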
121  CTAI_model/net/train.py  (normal file)
@@ -0,0 +1,121 @@
import sys

sys.path.append("..")
import torch
from torch.nn import init
from torch.utils.data import DataLoader

from data_set import make
from net import unet
from utils import dice_loss
import matplotlib.pyplot as plt
import numpy as np

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
res = {'epoch': [], 'loss': [], 'dice': []}


def weights_init(m):
    classname = m.__class__.__name__
    # Xavier-initialize the conv and linear layers. The original matched
    # 'Conv3d' (the network only contains Conv2d layers, so it never fired)
    # and passed gain=0.0 to xavier_normal, which would zero every weight.
    if classname.find('Conv2d') != -1:
        init.xavier_normal_(m.weight.data)
        init.constant_(m.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data)
        init.constant_(m.bias.data, 0.0)


# hyperparameters
rate = 0.50
learn_rate = 0.001
epochs = 1
# train_dataset_path = '../data/all/d1/'
train_dataset_path = 'E:/projects/python projects/ct_data/'

train_dataset, test_dataset = make.get_d1(train_dataset_path)
unet = unet.Unet(1, 1).to(device).apply(weights_init)
criterion = torch.nn.BCELoss().to(device)
optimizer = torch.optim.Adam(unet.parameters(), learn_rate)


def train():
    global res
    dataloaders = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
    for epoch in range(epochs):
        dt_size = len(dataloaders.dataset)
        epoch_loss, epoch_dice = 0, 0
        step = 0
        for x, y in dataloaders:
            id = x[1:]
            step += 1
            x = x[0].to(device)
            y = y[1].to(device)
            print(x.size())
            print(y.size())
            optimizer.zero_grad()
            outputs = unet(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            # dice
            # a = outputs.cpu().detach().squeeze(1).numpy()
            # a[a >= rate] = 1
            # a[a < rate] = 0
            # b = y.cpu().detach().numpy()
            # dice = dice_loss.dice(a, b)
            # epoch_loss += float(loss.item())
            # epoch_dice += dice

            if step % 100 == 0:
                res['epoch'].append((epoch + 1) * step)
                res['loss'].append(loss.item())
                print("epoch%d step%d/%d train_loss:%0.3f" % (
                    epoch, step, (dt_size - 1) // dataloaders.batch_size + 1, loss.item()),
                    end='')
                test()
        # print("epoch %d loss:%0.3f,dice %f" % (epoch, epoch_loss / step, epoch_dice / step))

    plt.plot(res['epoch'], np.squeeze(res['loss']), label='Train loss')  # was res['cost'], a key that is never filled
    plt.ylabel('loss')
    plt.xlabel('steps')
    plt.title("Model: train loss")
    plt.legend()

    plt.plot(res['epoch'], np.squeeze(res['dice']), label='Validation dice', color='#FF9966')  # was np.squeeze(res) on the whole dict
    plt.ylabel('dice')
    plt.xlabel('steps')
    plt.title("Model: validation dice")
    plt.legend()

    plt.savefig("examples.jpg")

    # torch.save(unet, 'unet.pkl')
    # model = torch.load('unet.pkl')
    test()


def test():
    global res, img_y, mask_array
    epoch_dice = 0
    unet.eval()  # run BatchNorm in inference mode during evaluation
    with torch.no_grad():
        dataloaders = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)
        for x, mask in dataloaders:
            id = x[1:]  # e.g. [('1026',), ('10018',)]: patient ID first, then slice number
            x = x[0].to(device)
            y = unet(x)
            mask_array = mask[1].cpu().squeeze(0).detach().numpy()
            img_y = torch.squeeze(y).cpu().numpy()
            img_y[img_y >= rate] = 1
            img_y[img_y < rate] = 0
            img_y = img_y * 255
            epoch_dice += dice_loss.dice(img_y, mask_array)
            # cv.imwrite(f'data/out/{mask[0][0]}-result.png', img_y, (cv.IMWRITE_PNG_COMPRESSION, 0))
    print('test dice %f' % (epoch_dice / len(dataloaders)))
    res['dice'].append(epoch_dice / len(dataloaders))
    unet.train()  # restore training mode for subsequent steps


if __name__ == '__main__':
    train()
    test()
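train.py imports utils.dice_loss, which is not part of this commit. For orientation, here is a rough sketch of what such a Dice-coefficient helper typically computes; this is an assumption about the helper's behavior, not the repository's actual implementation:

import numpy as np

def dice(pred, target, eps=1e-6):
    # Hypothetical stand-in for utils.dice_loss.dice:
    # Dice = 2 * |A ∩ B| / (|A| + |B|), computed on binary masks.
    pred = (np.asarray(pred) > 0).astype(np.float64)    # treats 0/255 and 0/1 masks alike
    target = (np.asarray(target) > 0).astype(np.float64)
    inter = (pred * target).sum()
    return 2.0 * inter / (pred.sum() + target.sum() + eps)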
68  CTAI_model/net/unet.py  (normal file)
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),  # batch normalization
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()

        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # contracting path
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # expanding path with skip connections
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = torch.sigmoid(c10)  # per-pixel probabilities for BCELoss
        return out
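Each of the four pooling stages halves the spatial resolution, so inputs whose height and width are multiples of 16 round-trip to the same size. A quick shape sanity check (the 512x512 input size is an assumption for illustration; the actual CT slice size is not shown in this commit):

import torch

net = Unet(1, 1)  # 1 input channel (grayscale CT), 1 output channel (mask)
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 1, 512, 512))  # (batch, channels, H, W)
print(out.shape)  # torch.Size([1, 1, 512, 512]), values in [0, 1]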