# Source code for core.Baseline.Trainer

import torch
import os, sys
import monai

from monai.data import decollate_batch

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

from core.BaseTrainer import BaseTrainer
from core.Baseline.utils import create_interior_onehot, identify_instances_from_classmap
from train_tools.measures import evaluate_f1_score_cellseg
from tqdm import tqdm

__all__ = ["Trainer"]


class Trainer(BaseTrainer):
    """Baseline segmentation trainer built on top of ``BaseTrainer``.

    Every constructor argument is forwarded to ``BaseTrainer`` unchanged.
    The training criterion is then fixed to MONAI's ``DiceCELoss`` with
    ``softmax=True``; the ``criterion`` argument is still handed to the
    base class, but ``self.criterion`` is overwritten right afterwards.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        """Set up the base trainer and install the Dice + CE criterion."""
        super().__init__(
            model,
            dataloaders,
            optimizer,
            scheduler,
            criterion,
            num_epochs,
            device,
            no_valid,
            valid_frequency,
            amp,
            algo_params,
        )

        # Dice + cross-entropy is always used as the segmentation loss,
        # replacing whatever the base class stored for `criterion`.
        self.criterion = monai.losses.DiceCELoss(softmax=True)

    def _epoch_phase(self, phase):
        """Run one pass over ``self.dataloaders[phase]``.

        The model is put in train mode when ``phase == "train"`` and in
        eval mode otherwise. The loss is recorded for every batch; the F1
        score is recorded only for non-training phases.

        Returns:
            The dict produced by ``self._update_results`` holding the
            ``"loss"`` entry and, outside of training, the ``"f1_score"``
            entry as well.
        """
        is_training = phase == "train"
        results = {}

        # Switch the network between train / eval behaviour.
        if is_training:
            self.model.train()
        else:
            self.model.eval()

        for batch in tqdm(self.dataloaders[phase]):
            images = batch["img"].to(self.device)
            labels = batch["label"].to(self.device)
            self.optimizer.zero_grad()

            # Convert the label masks into the 3-class one-hot target.
            onehot_target = create_interior_onehot(labels)

            # Forward pass; gradients are tracked only while training.
            with torch.set_grad_enabled(is_training):
                outputs = self._inference(images, phase)
                loss = self.criterion(outputs, onehot_target)
                self.loss_metric.append(loss)

                if not is_training:
                    self.f1_metric.append(self._get_f1_metric(outputs, labels))

            # No optimisation step outside of the training phase.
            if not is_training:
                continue

            if self.amp:
                # Mixed precision: scale the loss, unscale the gradients,
                # then step through the scaler and refresh its scale.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

        # Fold the accumulated metrics into the result dict.
        results = self._update_results(results, self.loss_metric, "loss", phase)
        if not is_training:
            results = self._update_results(
                results, self.f1_metric, "f1_score", phase
            )

        return results

    def _post_process(self, outputs, labels_onehot):
        """Decollate both batches and apply the per-item post transforms.

        ``self.post_pred`` is applied to each prediction and
        ``self.post_gt`` to each one-hot label; both are expected to be
        provided on the instance (they are not defined in this class).

        Returns:
            A ``(predictions, labels)`` pair of lists.
        """
        processed_outputs = []
        for item in decollate_batch(outputs):
            processed_outputs.append(self.post_pred(item))

        processed_labels = []
        for item in decollate_batch(labels_onehot):
            processed_labels.append(self.post_gt(item))

        return processed_outputs, processed_labels

    def _get_f1_metric(self, masks_pred, masks_true):
        """Compute the F1 score between a prediction and its label mask.

        Only the first element of ``masks_pred`` is turned into an
        instance map, and the two leading axes of ``masks_true`` are
        squeezed away before it is moved to the CPU as a NumPy array.
        The score is the last item returned by
        ``evaluate_f1_score_cellseg``.
        """
        pred_instances = identify_instances_from_classmap(masks_pred[0])

        true_instances = masks_true.squeeze(0).squeeze(0)
        true_instances = true_instances.cpu().numpy()

        scores = evaluate_f1_score_cellseg(true_instances, pred_instances)
        return scores[-1]