Source code for arcjetCV.segmentation.contour.cnn

import torch
from pathlib import Path
import cv2
from torchvision import transforms as T


[docs] class CNN: def __init__(self): # predict fails from GUI (Could not load library libcudnn_ops_infer.so), works from script. Fix: run it cpu only self.device = torch.device( "cpu" ) # torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.checkpoint_path = ( Path(__file__) .parent.absolute() .joinpath(Path("Unet-xception_25_original.pt")) ) self.model = torch.load(self.checkpoint_path, map_location=self.device) self.mean = [0.485, 0.456, 0.406] self.std = [0.229, 0.224, 0.225] self.t = T.Compose( [ T.ToPILImage(), T.Resize((512, 512)), T.ToTensor(), T.Normalize(self.mean, self.std), ] )
[docs] def predict(self, img): self.model.eval() image = self.t(img) self.model.to(self.device) image = image.to(self.device) with torch.no_grad(): image = image.unsqueeze(0) output = self.model(image) masked = torch.argmax(output, dim=1) masked = masked.cpu().squeeze(0).numpy() masked = cv2.resize( masked, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST ) return masked