Source code for core.MEDIAR.EnsemblePredictor

import torch
import os, sys, copy
import numpy as np

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

from core.MEDIAR.Predictor import Predictor

__all__ = ["EnsemblePredictor"]


[docs] class EnsemblePredictor(Predictor): def __init__( self, model, model_aux, device, input_path, output_path, make_submission=False, exp_name=None, algo_params=None, ): super(EnsemblePredictor, self).__init__( model, device, input_path, output_path, make_submission, exp_name, algo_params, )
[docs] self.model_aux = model_aux
@torch.no_grad() def _inference(self, img_data): self.model_aux.to(self.device) self.model_aux.eval() img_data = img_data.to(self.device) img_base = img_data outputs_base = self._window_inference(img_base) outputs_base = outputs_base.cpu().squeeze() outputs_aux = self._window_inference(img_base, aux=True) outputs_aux = outputs_aux.cpu().squeeze() img_base.cpu() if not self.use_tta: pred_mask = (outputs_base + outputs_aux) / 2 return pred_mask else: # HorizontalFlip TTA img_hflip = self.hflip_tta.apply_aug_image(img_data, apply=True) outputs_hflip = self._window_inference(img_hflip) outputs_hflip_aux = self._window_inference(img_hflip, aux=True) outputs_hflip = self.hflip_tta.apply_deaug_mask(outputs_hflip, apply=True) outputs_hflip_aux = self.hflip_tta.apply_deaug_mask( outputs_hflip_aux, apply=True ) outputs_hflip = outputs_hflip.cpu().squeeze() outputs_hflip_aux = outputs_hflip_aux.cpu().squeeze() img_hflip = img_hflip.cpu() # VertricalFlip TTA img_vflip = self.vflip_tta.apply_aug_image(img_data, apply=True) outputs_vflip = self._window_inference(img_vflip) outputs_vflip_aux = self._window_inference(img_vflip, aux=True) outputs_vflip = self.vflip_tta.apply_deaug_mask(outputs_vflip, apply=True) outputs_vflip_aux = self.vflip_tta.apply_deaug_mask( outputs_vflip_aux, apply=True ) outputs_vflip = outputs_vflip.cpu().squeeze() outputs_vflip_aux = outputs_vflip_aux.cpu().squeeze() img_vflip = img_vflip.cpu() # Merge Results pred_mask = torch.zeros_like(outputs_base) pred_mask[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3 pred_mask[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3 pred_mask[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3 pred_mask_aux = torch.zeros_like(outputs_aux) pred_mask_aux[0] = ( outputs_aux[0] + outputs_hflip_aux[0] - outputs_vflip_aux[0] ) / 3 pred_mask_aux[1] = ( outputs_aux[1] - outputs_hflip_aux[1] + outputs_vflip_aux[1] ) / 3 pred_mask_aux[2] = ( outputs_aux[2] + outputs_hflip_aux[2] + outputs_vflip_aux[2] ) / 3 pred_mask = (pred_mask + pred_mask_aux) / 2 return pred_mask