# Source code for core.Baseline.Predictor

import torch
import os, sys
from skimage import morphology, measure
from monai.inferers import sliding_window_inference

# Make the directory that contains the ``core`` package importable so that
# ``from core.BasePredictor import BasePredictor`` resolves. That directory is
# the grandparent of this file's directory (core/Baseline -> core -> root).
# Anchoring on ``__file__`` instead of ``os.getcwd()`` keeps this correct no
# matter which directory the interpreter was launched from; the cwd-based form
# only worked when run from inside core/Baseline and otherwise prepended an
# unrelated directory to sys.path.
_REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
sys.path.insert(0, _REPO_ROOT)

from core.BasePredictor import BasePredictor

# Public API of this module: only the Predictor class is exported by
# ``from core.Baseline.Predictor import *``.
__all__ = ["Predictor"]


class Predictor(BasePredictor):
    """Baseline predictor: sliding-window inference plus morphological cleanup.

    All constructor arguments are forwarded unchanged to ``BasePredictor``;
    this subclass only supplies the ``_inference`` and ``_post_process`` steps.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Forward every argument, in order, to ``BasePredictor.__init__``."""
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )

    def _inference(self, img_data):
        """Run ``self.model`` over ``img_data`` with MONAI sliding-window inference.

        Windows use ``roi_size=512`` with 4 windows per model call, 60% overlap
        between neighbouring windows, Gaussian-weighted blending of the
        overlapping outputs, and constant padding at the borders.

        Args:
            img_data: Model input. Presumably a batched ``(B, C, H, W)`` tensor
                already on the model's device -- TODO confirm against caller.

        Returns:
            The blended model output, as returned by
            ``monai.inferers.sliding_window_inference``.
        """
        pred_mask = sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )
        return pred_mask

    def _post_process(self, pred_mask):
        """Convert raw logits into a labeled connected-component mask.

        Args:
            pred_mask: Logits with the class axis first. Softmax is taken over
                axis 0 and channel 1 is kept, so at least two channels are
                required. May be a NumPy array or a ``torch.Tensor`` (on any
                device).

        Returns:
            An integer NumPy array from ``skimage.measure.label``: 0 for
            background and a distinct positive label per connected component.
        """
        # torch.as_tensor accepts NumPy arrays (sharing memory, like
        # torch.from_numpy) as well as tensors that may already be on the GPU,
        # which torch.from_numpy would reject.
        logits = torch.as_tensor(pred_mask)

        # Probability map for channel 1 (presumably the foreground class).
        # detach() so a tensor that still tracks gradients can be converted.
        prob = torch.softmax(logits, dim=0)[1].detach().cpu().numpy()

        # Binarize, fill small holes (skimage's default area threshold),
        # drop components smaller than 16 pixels, then label what remains.
        mask = prob > 0.5
        mask = morphology.remove_small_holes(mask, connectivity=1)
        mask = morphology.remove_small_objects(mask, 16)
        return measure.label(mask)