import torch
import numpy as np
from tqdm import tqdm
from monai.inferers import sliding_window_inference
from monai.metrics import CumulativeAverage
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    EnsureType,
)
import os, sys
import copy
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from core.utils import print_learning_device, print_with_logging
from train_tools.measures import evaluate_f1_score_cellseg
class BaseTrainer:
    """Abstract base class for trainer implementations."""

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        self.model = model.to(device)
        self.dataloaders = dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.device = device
        self.no_valid = no_valid
        self.valid_frequency = valid_frequency
        self.best_weights = None
        self.best_f1_score = 0.1

        # FP16 gradient scaler for mixed-precision training
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        # Assign algorithm-specific arguments
        if algo_params:
            self.__dict__.update((k, v) for k, v in algo_params.items())

        # Cumulative statistics
        self.loss_metric = CumulativeAverage()
        self.f1_metric = CumulativeAverage()

        # Post-processing functions
        self.post_pred = Compose(
            [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
        )
        self.post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])
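
        # Illustrative note: for a channel-first logit tensor `logits` of
        # shape (C, H, W), `post_pred` behaves roughly like the sketch below
        # (exact semantics depend on the installed MONAI version):
        #
        #   probs = torch.softmax(logits, dim=0)  # Activations(softmax=True)
        #   mask = (probs >= 0.5).float()         # AsDiscrete(threshold=0.5)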

    def train(self):
        """Train the model."""
        # Print learning device name
        print_learning_device(self.device)

        # Learning process
        for epoch in range(1, self.num_epochs + 1):
            print(f"[Round {epoch}/{self.num_epochs}]")

            # Train Epoch Phase
            print(">>> Train Epoch")
            train_results = self._epoch_phase("train")
            print_with_logging(train_results, epoch)

            if self.scheduler is not None:
                self.scheduler.step()

            if epoch % self.valid_frequency == 0:
                if not self.no_valid:
                    # Valid Epoch Phase
                    print(">>> Valid Epoch")
                    valid_results = self._epoch_phase("valid")
                    print_with_logging(valid_results, epoch)

                    if "Valid_F1_Score" in valid_results.keys():
                        current_f1_score = valid_results["Valid_F1_Score"]
                        self._update_best_model(current_f1_score)
                else:
                    print(">>> TuningSet Epoch")
                    tuning_cell_counts = self._tuningset_evaluation()
                    tuning_count_dict = {"TuningSet_Cell_Count": tuning_cell_counts}
                    print_with_logging(tuning_count_dict, epoch)

                    current_cell_count = tuning_cell_counts
                    self._update_best_model(current_cell_count)

            print("-" * 50)

        self.best_f1_score = 0
        if self.best_weights is not None:
            self.model.load_state_dict(self.best_weights)

    def _epoch_phase(self, phase):
        """Learning process for 1 epoch (for different phases).

        Args:
            phase (str): "train", "valid", "test"

        Returns:
            dict: statistics for the phase results
        """
        phase_results = {}

        # Set model mode
        self.model.train() if phase == "train" else self.model.eval()

        # Epoch process
        for batch_data in tqdm(self.dataloaders[phase]):
            images = batch_data["img"].to(self.device)
            labels = batch_data["label"].to(self.device)
            self.optimizer.zero_grad()

            # Forward pass
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                self.loss_metric.append(loss)

            # Backward pass
            if phase == "train":
                loss.backward()
                self.optimizer.step()

        # Update metrics
        phase_results = self._update_results(
            phase_results, self.loss_metric, "loss", phase
        )

        return phase_results
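
    # Hedged sketch (assumption, not implemented in this base class): a
    # subclass that actually consumes `self.scaler` for mixed precision
    # could wrap the train step roughly as
    #
    #   with torch.cuda.amp.autocast():
    #       outputs = self.model(images)
    #       loss = self.criterion(outputs, labels)
    #   self.scaler.scale(loss).backward()
    #   self.scaler.step(self.optimizer)
    #   self.scaler.update()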

    @torch.no_grad()
    def _tuningset_evaluation(self):
        cell_counts_total = []
        self.model.eval()

        for batch_data in tqdm(self.dataloaders["tuning"]):
            images = batch_data["img"].to(self.device)

            # Skip overly large images to avoid exhausting GPU memory
            if images.shape[-1] > 5000:
                continue

            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="constant",
                mode="gaussian",
            )
            outputs = outputs.squeeze(0)
            outputs, _ = self._post_process(outputs, None)

            # Count instance labels, excluding the background label 0
            count = len(np.unique(outputs)) - 1
            cell_counts_total.append(count)

        cell_counts_total_sum = np.sum(cell_counts_total)
        print("Cell Counts Total: (%d)" % cell_counts_total_sum)

        return cell_counts_total_sum
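
    # Worked example for the counting rule above (assumed values): a labeled
    # mask containing the values {0, 1, 2, 5} yields
    # len(np.unique(mask)) - 1 == 3 cells, since the background label 0 is
    # excluded from the count.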

    def _update_results(self, phase_results, metric, metric_key, phase="train"):
        """Aggregate and flush metrics.

        Args:
            phase_results (dict): base dictionary to log metrics
            metric (CumulativeAverage): accumulated metric for the epoch
            metric_key (str): name of the metric
            phase (str, optional): current phase name. Defaults to "train".

        Returns:
            dict: dictionary of metrics for the current phase
        """
        # Refine metric name, e.g. ("train", "loss") -> "Train_Loss"
        metric_key = "_".join([phase, metric_key]).title()

        # Aggregate metrics
        metric_item = round(metric.aggregate().item(), 4)

        # Log metrics to dictionary
        phase_results[metric_key] = metric_item

        # Flush metrics
        metric.reset()

        return phase_results
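
    # Illustrative call (assumed loss value): after a training epoch,
    #   self._update_results({}, self.loss_metric, "loss", "train")
    # returns {"Train_Loss": 0.4321} and resets the running average for the
    # next epoch.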

    def _update_best_model(self, current_f1_score):
        if current_f1_score > self.best_f1_score:
            self.best_weights = copy.deepcopy(self.model.state_dict())
            self.best_f1_score = current_f1_score
            print(
                "\n>>>> Update Best Model with score: {}\n".format(self.best_f1_score)
            )

    def _inference(self, images, phase="train"):
        """Inference methods for different phases."""
        if phase != "train":
            # Full-resolution images can be large at valid/test time, so
            # stitch predictions from 512x512 patches with Gaussian blending
            # and 50% overlap.
            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="reflect",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            outputs = self.model(images)

        return outputs

    def _post_process(self, outputs, labels):
        """Post-processing hook; the base implementation is a no-op."""
        return outputs, labels

    def _get_f1_metric(self, masks_pred, masks_true):
        f1_score = evaluate_f1_score_cellseg(masks_true, masks_pred)[-1]
        return f1_score
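
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module): wires
# BaseTrainer to a toy one-layer "model" and random data. The dataset,
# channel counts, and hyperparameters below are assumptions chosen only to
# show the expected {"img", "label"} batch format and the call sequence.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset

    class _ToyCellDataset(Dataset):
        """Hypothetical dataset yielding the dict batches the trainer expects."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                "img": torch.randn(3, 64, 64),
                "label": torch.randint(0, 3, (64, 64)),
            }

    model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in segmentation head
    dataloaders = {
        "train": DataLoader(_ToyCellDataset(), batch_size=2),
        "valid": DataLoader(_ToyCellDataset(), batch_size=2),
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    trainer = BaseTrainer(
        model,
        dataloaders,
        optimizer,
        criterion=nn.CrossEntropyLoss(),
        num_epochs=1,
        device="cpu",
    )
    trainer.train()  # runs one train + valid epoch and logs the losses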