镜像自地址
https://github.com/xming521/CTAI.git
已同步 2025-12-06 14:46:48 +00:00
添加使用TensorRT转换并推理模型
这个提交包含在:
47
CTAI_tensorRT/to_oonx.py
普通文件
47
CTAI_tensorRT/to_oonx.py
普通文件
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from CTAI_model.net import unet
|
||||
import SimpleITK as sitk
|
||||
|
||||
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def data_in_one(inputdata):
|
||||
if not inputdata.any():
|
||||
return inputdata
|
||||
inputdata = (inputdata - inputdata.min()) / (inputdata.max() - inputdata.min())
|
||||
return inputdata
|
||||
|
||||
|
||||
def get_data(data_path):
|
||||
image = sitk.ReadImage(data_path)
|
||||
image_array = sitk.GetArrayFromImage(image)
|
||||
|
||||
ROI_mask = np.zeros(shape=image_array.shape)
|
||||
ROI_mask_mini = np.zeros(shape=(1, 160, 100))
|
||||
ROI_mask_mini[0] = image_array[0][270:430, 200:300]
|
||||
ROI_mask_mini = data_in_one(ROI_mask_mini)
|
||||
ROI_mask[0][270:430, 200:300] = ROI_mask_mini[0]
|
||||
test_image = ROI_mask
|
||||
image_tensor = torch.from_numpy(ROI_mask).float().unsqueeze(1)
|
||||
return image_tensor
|
||||
|
||||
def init_model():
|
||||
model = unet.Unet(1, 1).to(device)
|
||||
if torch.cuda.is_available():
|
||||
model.load_state_dict(torch.load("./model.pth", map_location=device))
|
||||
else:
|
||||
model.load_state_dict(torch.load("./model.pth", map_location='cpu'))
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
data = get_data('20014.dcm').to(device)
|
||||
model = init_model()
|
||||
torch.onnx.export(model, data, "unet.onnx")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
在新工单中引用
屏蔽一个用户