upload train code

这个提交包含在:
xming521
2020-02-17 16:17:23 +08:00
父节点 e6f084f765
当前提交 d01af7bdd8
共有 13 个文件被更改,包括 831 次插入0 次删除

查看文件

@@ -0,0 +1,38 @@
import numpy as np
def dice(im1, im2):
"""
Computes the Dice coefficient, a measure of set similarity.
Parameters
----------
im1 : array-like, bool
Any array of arbitrary size. If not boolean, will be converted.
im2 : array-like, bool
Any other array of identical size. If not boolean, will be converted.
Returns
-------
dice : float
Dice coefficient as a float on range [0,1].
Maximum similarity = 1
No similarity = 0
Notes
-----
The order of inputs for `dice` is irrelevant. The result will be
identical if `im1` and `im2` are switched.
"""
im1 = np.asarray(im1).astype(np.bool)
im2 = np.asarray(im2).astype(np.bool)
if im1.shape != im2.shape:
raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")
# 俩都为全黑
if not (im1.any() or im2.any()):
return 1.0
# Compute Dice coefficient
intersection = np.logical_and(im1, im2)
res = 2. * intersection.sum() / (im1.sum() + im2.sum())
return np.round(res, 5)

36
CTAI_model/utils/draw.py 普通文件
查看文件

@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
data = []
data_true = []
with open('../result/0.50nohup50.txt', 'r') as f:
data = [i.replace('\n', '') for i in f.readlines()]
for i in range(len(data)):
if i % 3 == 0:
x = data[i].split(' ')
data_true.append([x[2].replace('test', '').replace('train_loss:', ''), x[-1]])
print(data_true)
ax = plt.gca()
plt.rcParams['savefig.dpi'] = 300 # 图片像素
plt.rcParams['figure.dpi'] = 200 # 分辨率
# plt.plot(range(1,51), np.squeeze([i[0] for i in data_true]), label='Train loss')
# plt.ylabel('loss')
# plt.xlabel('epochs')
# plt.title("Model: train loss")
# plt.legend()
# plt.show()
ax.invert_yaxis()
plt.plot(range(1, 51), np.squeeze([i[0] for i in data_true]), label='Train loss')
plt.ylabel('loss')
plt.xlabel('epochs')
plt.title("Model: train loss")
plt.legend()
# plt.show()
plt.savefig('plot123_2.png', dpi=200) # 指定分辨率保存

查看文件

@@ -0,0 +1,44 @@
import os
import SimpleITK as sitk
import cv2
import numpy as np
from data_set import make
def mkdir(path):
folder = os.path.exists(path)
if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹
os.makedirs(path) # makedirs 创建文件时如果路径不存在会创建这个路径
filename_list = make.get_person_files('../data/all/d2/')
for i in filename_list:
pid = i[0]
print(pid)
for j in i[1]:
image = sitk.ReadImage(j)
image_array = sitk.GetArrayFromImage(image).swapaxes(0, 2)
image_array = np.rot90(image_array, -1)
image_array = np.fliplr(image_array).squeeze()
# ret, image_array = cv2.threshold(image_array, 150, 255, cv2.THRESH_BINARY)
mkdir(f'../data/png/{pid}/')
name = j.replace('.dcm', '').split('/')[-1]
# cv2.imwrite(f'../data/jpg/{pid}/{name}.jpg', image_array, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
cv2.imwrite(f'../data/png/{pid}/{name}.png', image_array, (cv2.IMWRITE_PNG_COMPRESSION, 0))
# print(filename_list)
# for i in filename_list:
# if '.dcm' in i:
# image = sitk.ReadImage(data_path + '/' + i)
# image_array = sitk.GetArrayFromImage(image).swapaxes(0, 2)
# image_array = np.rot90(image_rray, -1)
# image_array = np.fliplr(image_array)
# name = i.replace('.dcm', '')
# cv2.imwrite(f'{data_path}/{name}_train.png', image_array, (cv2.IMWRITE_PNG_COMPRESSION, 0))
# t=cv2.imread('data/out/mask-tttt.png',cv2.IMREAD_GRAYSCALE)
# print(t)

查看文件

@@ -0,0 +1,44 @@
import os
def get_train_files(data_path, file_type='dcm', all=True):
file_type = '.' + file_type
image_list, mask_list, ROI_list = [], [], []
dir_list = [data_path + i for i in os.listdir(data_path)]
filename_list = []
for dir in dir_list:
# 所有数据跑
if all:
temp = os.listdir(dir + '/arterial phase')
filename_list.extend([dir + '/arterial phase/' + name for name in temp])
if not all:
filename_list.append(dir)
# temp = os.listdir(dir)
# filename_list.extend([dir + '/' + name for name in temp])
for i in filename_list:
if file_type in i:
image_list.append(i)
if '_mask' in i:
mask_list.append(i)
# 校验文件正确
return image_list, mask_list
if __name__ == '__main__':
a, _ = get_train_files('../data/all/d2/')
_, b = get_train_files('../data/out/')
for i in range(len(a)):
aa = a[i].split('/')[-1].replace('.dcm', '')
bb = b[i].split('/')[-1].replace('_mask.png', '')
if aa != bb:
print(aa, bb, b[i])
print(a[i] + 'file list error!')
for i in range(len(b)):
aa = a[i].split('/')[-1].replace('.dcm', '')
bb = b[i].split('/')[-1].replace('_mask.png', '')
if aa != bb:
print(a[i] + 'file list error!')