镜像自地址
https://github.com/xming521/CTAI.git
已同步 2025-12-06 14:46:48 +00:00
upload flask
这个提交包含在:
39
CTAI_flask/core/predict.py
普通文件
39
CTAI_flask/core/predict.py
普通文件
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import torch
|
||||
import core.net.unet as net
|
||||
import numpy as np
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
torch.set_num_threads(4)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
import os
|
||||
|
||||
rate = 0.5
|
||||
|
||||
|
||||
def predict(dataset,model):
|
||||
|
||||
# unet = torch.load('./core/0.5unet.pkl').to(device)
|
||||
# torch.save(unet.state_dict(), "model_new.pth")
|
||||
|
||||
global res, img_y, mask_arrary
|
||||
with torch.no_grad():
|
||||
x = dataset[0][0].to(device)
|
||||
file_name = dataset[1]
|
||||
y = model(x)
|
||||
img_y = torch.squeeze(y).cpu().numpy()
|
||||
img_y[img_y >= rate] = 1
|
||||
img_y[img_y < rate] = 0
|
||||
img_y = img_y * 255
|
||||
cv2.imwrite(f'./tmp/mask/{file_name}_mask.png', img_y,
|
||||
(cv2.IMWRITE_PNG_COMPRESSION, 0))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 写保存模型
|
||||
# train()
|
||||
predict()
|
||||
在新工单中引用
屏蔽一个用户