Mirrored from https://github.com/xming521/CTAI.git
Synced 2025-12-06 06:36:49 +00:00
Add TensorRT model conversion and inference
117
CTAI_tensorRT/main.py
Normal file
@@ -0,0 +1,117 @@
import cv2
import numpy as np
import onnx
import pycuda.autoinit
import pycuda.driver as cuda
import SimpleITK as sitk
import tensorrt as trt
from PIL import Image

from CTAI_tensorRT.to_oonx import data_in_one

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

onnx_path = 'unet.onnx'
engine_path = 'engine.trt'
shape = (1, 1, 512, 512)


def build():
    """Parse the ONNX model and build a serialized TensorRT engine on disk."""
    model = onnx.load(onnx_path)
    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    # Create the TensorRT engine
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        # Parse the ONNX model
        parser.parse(model.SerializeToString())

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 25)  # 32 MiB workspace
        config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
        config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
        config.set_flag(trt.BuilderFlag.FP16)

        # Dynamic batch dimension: min 1, optimal 32, max 64
        profile = builder.create_optimization_profile()
        profile.set_shape("input", (1, 1, 512, 512), (32, 1, 512, 512), (64, 1, 512, 512))
        config.add_optimization_profile(profile)

        serialized_engine = builder.build_serialized_network(network, config)

        # Save the TensorRT engine to disk
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)


# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


def get_data(data_path):
    """Read a DICOM slice and min-max normalize the ROI used as network input."""
    image = sitk.ReadImage(data_path)
    image_array = sitk.GetArrayFromImage(image)

    ROI_mask = np.zeros(shape=image_array.shape)
    ROI_mask_mini = np.zeros(shape=(1, 160, 100))
    ROI_mask_mini[0] = image_array[0][270:430, 200:300]
    ROI_mask_mini = data_in_one(ROI_mask_mini)
    ROI_mask[0][270:430, 200:300] = ROI_mask_mini[0]
    return ROI_mask


# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    input_data = get_data('20014.dcm')
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers (page-locked host memory)
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            np.copyto(host_mem, input_data.ravel())
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream


# build() must have been run once so that the serialized engine exists on disk.
with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
    with engine.create_execution_context() as context:
        # Allocate input/output buffers
        inputs, outputs, bindings, stream = allocate_buffers(engine)
        # Transfer input data to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async(batch_size=1, bindings=bindings, stream_handle=stream.handle)
        # Transfer predictions back from the GPU.
        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
        # Synchronize the stream
        stream.synchronize()
        # Keep only the host outputs and binarize the predicted mask.
        res = [out.host for out in outputs][0]
        res = res.reshape(engine.get_binding_shape(1)[1:])
        res[res >= 0.5] = 1
        res[res < 0.5] = 0
        res = res * 255
        # cv2.imwrite(f'res.png', res, (cv2.IMWRITE_PNG_COMPRESSION, 0))
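Note (not part of the commit): because build() registers an optimization profile with a dynamic batch dimension, the deserialized engine reports its input binding shape as (-1, 1, 512, 512). Sizing buffers from engine.get_binding_shape() and calling execute_async() with a batch_size argument can therefore misbehave on an explicit-batch engine. A minimal sketch of one way to handle this with the TensorRT 8.x Python API, assuming binding 0 is the input and binding 1 the output, and that the host input buffer is filled separately:

import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

with open("engine.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
    with engine.create_execution_context() as context:
        # Resolve the dynamic batch dimension before allocating anything.
        context.set_binding_shape(0, (1, 1, 512, 512))
        stream = cuda.Stream()
        bindings, host_bufs, dev_bufs = [], [], []
        for i in range(engine.num_bindings):
            shape = context.get_binding_shape(i)  # concrete shape, no -1 left
            dtype = trt.nptype(engine.get_binding_dtype(i))
            host = cuda.pagelocked_empty(trt.volume(shape), dtype)
            dev = cuda.mem_alloc(host.nbytes)
            bindings.append(int(dev))
            host_bufs.append(host)
            dev_bufs.append(dev)
        # host_bufs[0] would be filled with the preprocessed slice here.
        cuda.memcpy_htod_async(dev_bufs[0], host_bufs[0], stream)
        # execute_async_v2 is the explicit-batch entry point (no batch_size argument).
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_bufs[1], dev_bufs[1], stream)
        stream.synchronize()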
47
CTAI_tensorRT/to_oonx.py
Normal file
@@ -0,0 +1,47 @@
import numpy as np
import SimpleITK as sitk
import torch

from CTAI_model.net import unet

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")


def data_in_one(inputdata):
    """Min-max normalize to [0, 1]; all-zero inputs are returned unchanged."""
    if not inputdata.any():
        return inputdata
    inputdata = (inputdata - inputdata.min()) / (inputdata.max() - inputdata.min())
    return inputdata


def get_data(data_path):
    """Read a DICOM slice, normalize the ROI and return it as a (1, 1, H, W) tensor."""
    image = sitk.ReadImage(data_path)
    image_array = sitk.GetArrayFromImage(image)

    ROI_mask = np.zeros(shape=image_array.shape)
    ROI_mask_mini = np.zeros(shape=(1, 160, 100))
    ROI_mask_mini[0] = image_array[0][270:430, 200:300]
    ROI_mask_mini = data_in_one(ROI_mask_mini)
    ROI_mask[0][270:430, 200:300] = ROI_mask_mini[0]
    image_tensor = torch.from_numpy(ROI_mask).float().unsqueeze(1)
    return image_tensor


def init_model():
    """Load the trained UNet weights and put the model in eval mode."""
    model = unet.Unet(1, 1).to(device)
    if torch.cuda.is_available():
        model.load_state_dict(torch.load("./model.pth", map_location=device))
    else:
        model.load_state_dict(torch.load("./model.pth", map_location='cpu'))
    model.eval()
    return model


def main():
    # Export the PyTorch model to ONNX using a real DICOM slice as the example input.
    data = get_data('20014.dcm').to(device)
    model = init_model()
    torch.onnx.export(model, data, "unet.onnx")


if __name__ == '__main__':
    main()
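Note (an assumption, not part of the commit): torch.onnx.export() without input_names gives the graph input an auto-generated name, while build() in main.py sets its optimization profile on a tensor literally named "input" and declares a dynamic batch dimension. A sketch of an export call that keeps the two files in sync might look like:

torch.onnx.export(
    model,
    data,
    "unet.onnx",
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)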