upload flask

这个提交包含在:
xming521
2020-02-17 16:12:00 +08:00
父节点 8c361a09b5
当前提交 64b9d1cc70
共有 11 个文件被更改,包括 502 次插入0 次删除

39
CTAI_flask/core/predict.py 普通文件
查看文件

@@ -0,0 +1,39 @@
import os
import sys
import cv2
import torch
import core.net.unet as net
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_num_threads(4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
import os
rate = 0.5
def predict(dataset,model):
# unet = torch.load('./core/0.5unet.pkl').to(device)
# torch.save(unet.state_dict(), "model_new.pth")
global res, img_y, mask_arrary
with torch.no_grad():
x = dataset[0][0].to(device)
file_name = dataset[1]
y = model(x)
img_y = torch.squeeze(y).cpu().numpy()
img_y[img_y >= rate] = 1
img_y[img_y < rate] = 0
img_y = img_y * 255
cv2.imwrite(f'./tmp/mask/{file_name}_mask.png', img_y,
(cv2.IMWRITE_PNG_COMPRESSION, 0))
if __name__ == '__main__':
# 写保存模型
# train()
predict()