Source code for core.MEDIAR.Predictor

import torch
import numpy as np
import os, sys
from monai.inferers import sliding_window_inference

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

from core.BasePredictor import BasePredictor
from core.MEDIAR.utils import compute_masks

__all__ = ["Predictor"]


class Predictor(BasePredictor):
    """Predictor that runs sliding-window inference and flow-based post-processing.

    The model output is handled as a channel-first array whose first two
    channels are flow components and whose last channel is a cell-probability
    logit (see ``_post_process``).  Optional horizontal/vertical flip
    test-time augmentation (TTA) is applied in ``_inference``.

    NOTE(review): ``self.use_tta`` and ``self.model_aux`` are read by this
    class but are not assigned here; presumably ``BasePredictor`` (or the
    caller) provides them -- confirm.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Forward all arguments to ``BasePredictor`` and create the flip transforms."""
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )
        self.hflip_tta = HorizontalFlip()
        self.vflip_tta = VerticalFlip()

    @torch.no_grad()
    def _inference(self, img_data):
        """Conduct model prediction, optionally averaged over flip TTA.

        Args:
            img_data: input tensor; it is moved to ``self.device`` before
                inference (assumed to be a (B, C, H, W) batch, since the flip
                transforms reverse dimensions 2 and 3 -- confirm).

        Returns:
            The squeezed model output on the CPU.  With TTA enabled, the
            three predictions (base, h-flip, v-flip) are merged channel-wise.
        """
        img_data = img_data.to(self.device)
        outputs_base = self._window_inference(img_data).cpu().squeeze()

        if not self.use_tta:
            return outputs_base

        outputs_hflip = self._flip_inference(self.hflip_tta, img_data)
        outputs_vflip = self._flip_inference(self.vflip_tta, img_data)

        # Merge results.  A flip mirrors the flow component along the flipped
        # axis, so that component is subtracted instead of added: channel 1 for
        # the horizontal flip and channel 0 for the vertical flip (presumably
        # channel 0 = vertical flow, channel 1 = horizontal flow -- confirm).
        # Channel 2 is unaffected by flips and is a plain average.
        pred_mask = torch.zeros_like(outputs_base)
        pred_mask[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3
        pred_mask[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3
        pred_mask[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3

        return pred_mask

    def _flip_inference(self, tta, img_data):
        """Predict on the flipped image and flip the prediction back (CPU, squeezed)."""
        flipped = tta.apply_aug_image(img_data, apply=True)
        outputs = self._window_inference(flipped)
        outputs = tta.apply_deaug_mask(outputs, apply=True)
        return outputs.cpu().squeeze()

    def _window_inference(self, img_data, aux=False):
        """Inference on RoI-sized windows.

        Args:
            img_data: input tensor passed to MONAI's sliding-window inferer.
            aux: when True, ``self.model_aux`` is used instead of ``self.model``.

        Returns:
            The stitched output of ``sliding_window_inference``.
        """
        outputs = sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model_aux if aux else self.model,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )

        return outputs

    def _post_process(self, pred_mask):
        """Generate cell instance masks.

        Images with fewer than 5000 * 5000 pixels are processed in one call.
        Larger images are zero-padded to a multiple of 2000 pixels and
        processed tile by tile; tile labels are shifted by a running offset so
        that instance ids stay unique across the whole image.

        Args:
            pred_mask: channel-first prediction; ``pred_mask[:2]`` are the flow
                components and ``pred_mask[-1]`` is the cell-probability logit.

        Returns:
            The label mask produced by ``compute_masks`` (small images) or a
            ``uint32`` array of shape (H, W) (tiled images).
        """
        dP, cellprob = pred_mask[:2], self._sigmoid(pred_mask[-1])
        H, W = pred_mask.shape[-2], pred_mask.shape[-1]

        if H * W < 5000 * 5000:
            return self._masks_from_flows(dP, cellprob)

        print("\n[Whole Slide] Grid Prediction starting...")
        roi_size = 2000

        # Get patch grid by roi_size (ceil division), then the padded extent.
        n_H = -(-H // roi_size)
        n_W = -(-W // roi_size)
        new_H, new_W = roi_size * n_H, roi_size * n_W

        # Allocate values on the grid
        pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
        dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
        cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)

        dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob

        # Running count of labels already used by previous tiles.
        # Assumes each compute_masks call labels background as 0 and instances
        # as positive integers starting over per call -- confirm.  Without the
        # offset, cells in different tiles would share the same id.
        # NOTE: cells straddling a tile border are still split at the border.
        label_offset = 0

        for i in range(n_H):
            rows = slice(roi_size * i, roi_size * (i + 1))
            for j in range(n_W):
                cols = slice(roi_size * j, roi_size * (j + 1))
                print("Pred on Grid (%d, %d) processing..." % (i, j))

                # Copy into uint32 (the dtype of pred_pad) before shifting so
                # the offset cannot overflow a narrower label dtype.
                tile_mask = np.array(
                    self._masks_from_flows(
                        dP_pad[:, rows, cols], cellprob_pad[rows, cols]
                    ),
                    dtype=np.uint32,
                )

                tile_max = int(tile_mask.max())
                if tile_max:
                    tile_mask[tile_mask > 0] += np.uint32(label_offset)
                    label_offset += tile_max

                pred_pad[rows, cols] = tile_mask

        # Drop the padding again.
        return pred_pad[:H, :W]

    def _masks_from_flows(self, dP, cellprob):
        """Run ``compute_masks`` with this predictor's fixed thresholds.

        Returns element 0 of the ``compute_masks`` result, which this class
        uses as the instance label mask.
        """
        return compute_masks(
            dP,
            cellprob,
            use_gpu=True,
            flow_threshold=0.4,
            device=self.device,
            cellprob_threshold=0.5,
        )[0]

    def _sigmoid(self, z):
        """Element-wise logistic function ``1 / (1 + exp(-z))``."""
        return 1 / (1 + np.exp(-z))


""" Adapted from the following references: [1] https://github.com/qubvel/ttach/blob/master/ttach/transforms.py """ def hflip(x): """flip batch of images horizontally""" return x.flip(3) def vflip(x): """flip batch of images vertically""" return x.flip(2) class DualTransform: identity_param = None def __init__( self, name: str, params, ): self.params = params self.pname = name def apply_aug_image(self, image, *args, **params): raise NotImplementedError def apply_deaug_mask(self, mask, *args, **params): raise NotImplementedError class HorizontalFlip(DualTransform): """Flip images horizontally (left -> right)""" identity_param = False def __init__(self): super().__init__("apply", [False, True]) def apply_aug_image(self, image, apply=False, **kwargs): if apply: image = hflip(image) return image def apply_deaug_mask(self, mask, apply=False, **kwargs): if apply: mask = hflip(mask) return mask class VerticalFlip(DualTransform): """Flip images vertically (up -> down)""" identity_param = False def __init__(self): super().__init__("apply", [False, True]) def apply_aug_image(self, image, apply=False, **kwargs): if apply: image = vflip(image) return image def apply_deaug_mask(self, mask, apply=False, **kwargs): if apply: mask = vflip(mask) return mask