# Source code for train_tools.measures

"""
Adapted from the following references:
[1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/compute_metric.py
[2] https://github.com/stardist/stardist/blob/master/stardist/matching.py

"""

import numpy as np
from skimage import segmentation
from scipy.optimize import linear_sum_assignment
from numba import jit

# Names exported by ``from <module> import *``; the other functions in this
# module are helpers used by these two entry points.
__all__ = ["evaluate_f1_score_cellseg", "evaluate_f1_score"]


def evaluate_f1_score_cellseg(masks_true, masks_pred, threshold=0.5):
    """
    Compute precision, recall and F1-score for cell segmentation results.

    Cells touching the border of the image (or of a patch, for large images)
    are removed from both masks before matching.

    Parameters
    ----------
    masks_true : np.ndarray
        2D ground-truth label image (0 = background, 1, 2, ... = cells).
    masks_pred : np.ndarray
        2D predicted label image with the same shape as ``masks_true``.
    threshold : float, optional
        IoU threshold above which a GT/prediction pair counts as a match.

    Returns
    -------
    precision, recall, f1_score
        Scores computed from the accumulated (TP, FP, FN) counts.
    """
    if masks_true.size < (5000 * 5000):
        masks_true = _remove_boundary_cells(masks_true.astype(np.int32))
        masks_pred = _remove_boundary_cells(masks_pred.astype(np.int32))
        tp, fp, fn = get_confusion(masks_true, masks_pred, threshold)

    # Compute by Patch-based way for large images
    else:
        H, W = masks_true.shape
        roi_size = 2000

        # Number of patches per axis (ceil division) and the padded extent
        n_H = -(-H // roi_size)
        n_W = -(-W // roi_size)
        new_H = roi_size * n_H
        new_W = roi_size * n_W

        # Allocate values on the grid. Both buffers use int32, like the
        # small-image branch, so predicted labels are never squeezed into a
        # narrower ground-truth dtype (which would wrap or merge label ids).
        gt_pad = np.zeros((new_H, new_W), dtype=np.int32)
        pred_pad = np.zeros((new_H, new_W), dtype=np.int32)
        gt_pad[:H, :W] = masks_true
        pred_pad[:H, :W] = masks_pred

        tp, fp, fn = 0, 0, 0

        # Calculate confusion elements for each patch
        for i in range(n_H):
            rows = slice(roi_size * i, roi_size * (i + 1))
            for j in range(n_W):
                cols = slice(roi_size * j, roi_size * (j + 1))

                gt_roi = _remove_boundary_cells(gt_pad[rows, cols])
                pred_roi = _remove_boundary_cells(pred_pad[rows, cols])

                tp_i, fp_i, fn_i = get_confusion(gt_roi, pred_roi, threshold)
                tp += tp_i
                fp += fp_i
                fn += fn_i

    # Calculate f1 score
    precision, recall, f1_score = evaluate_f1_score(tp, fp, fn)

    return precision, recall, f1_score
def evaluate_f1_score(tp, fp, fn):
    """Return (precision, recall, F1-score) for the given confusion counts.

    When there are no true positives all three scores are reported as 0,
    which also avoids dividing by zero.
    """
    # Trivial result: nothing was matched
    if tp == 0:
        return 0, 0, 0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score
def _remove_boundary_cells(mask):
    """Remove cells on the boundary from the mask.

    A cell is treated as a boundary cell when any of its pixels lies in the
    outermost two rows or columns of ``mask``.

    Parameters
    ----------
    mask : np.ndarray
        2D label image (0 = background, 1, 2, ... = cells).

    Returns
    -------
    np.ndarray
        Label image without boundary cells, relabelled sequentially.
        The input array is left untouched.
    """
    n_rows, n_cols = mask.shape

    # Identify boundary cells: True on the 2-pixel frame, False inside
    border = np.ones((n_rows, n_cols), dtype=bool)
    border[2 : n_rows - 2, 2 : n_cols - 2] = False

    # Drop background explicitly instead of assuming it is the first unique
    # value; a frame without any background pixel would otherwise keep a cell.
    bd_cells = np.unique(mask[border])
    bd_cells = bd_cells[bd_cells != 0]

    # Remove cells on the boundary (on a copy, so the caller's array survives)
    cleaned = mask.copy()
    cleaned[np.isin(cleaned, bd_cells)] = 0

    # Allocate labels as sequential manner
    new_label, _, _ = segmentation.relabel_sequential(cleaned)

    return new_label


def get_confusion(masks_true, masks_pred, threshold=0.5):
    """Calculate confusion matrix elements: (TP, FP, FN).

    Parameters
    ----------
    masks_true, masks_pred : np.ndarray
        Label images (0 = background); the largest label is taken as the
        number of instances in each mask.
    threshold : float, optional
        IoU threshold for counting a matched pair as a true positive.

    Returns
    -------
    tp, fp, fn
        Numbers of true positives, false positives and false negatives.
    """
    num_gt_instances = np.max(masks_true)
    num_pred_instances = np.max(masks_pred)

    if num_pred_instances == 0:
        print("No segmentation results!")
        # NOTE(review): FN is reported as 0 here even when ground-truth cells
        # exist. This mirrors the referenced baseline metric; confirm whether
        # it is intended for patch-wise accumulation.
        tp, fp, fn = 0, 0, 0

    else:
        # Calculate IoU and exclude background label (0)
        iou = _get_iou(masks_true, masks_pred)
        iou = iou[1:, 1:]

        # Calculate true positives
        tp = _get_true_positive(iou, threshold)
        fp = num_pred_instances - tp
        fn = num_gt_instances - tp

    return tp, fp, fn


def _get_true_positive(iou, threshold=0.5):
    """Count matched (GT, prediction) pairs whose IoU reaches the threshold.

    Parameters
    ----------
    iou : np.ndarray
        IoU matrix of shape (number of GT instances, number of predictions).
    threshold : float, optional
        Minimum IoU for a matched pair to count as a true positive.

    Returns
    -------
    int
        Number of true positives under an optimal one-to-one matching.
    """
    # Number of instances to be matched
    num_matched = min(iou.shape[0], iou.shape[1])

    # Nothing to match (also avoids dividing by zero below)
    if num_matched == 0:
        return 0

    # Find optimal matching by using IoU as tie-breaker.
    # np.float64 replaces the ``np.float`` alias removed in NumPy 1.24.
    costs = -(iou >= threshold).astype(np.float64) - iou / (2 * num_matched)
    matched_gt_label, matched_pred_label = linear_sum_assignment(costs)

    # Consider as the same instance only if the IoU is above the threshold
    match_ok = iou[matched_gt_label, matched_pred_label] >= threshold
    tp = match_ok.sum()

    return tp


def _get_iou(masks_true, masks_pred):
    """Get the IoU between every label pair of masks_true and masks_pred.

    Returns
    -------
    np.ndarray
        Float matrix of shape (max GT label + 1, max predicted label + 1);
        entries whose union is empty are 0.
    """
    # Get overlap matrix (GT Instances Num, Pred Instance Num)
    overlap = _label_overlap(masks_true, masks_pred)

    # Predicted instance pixels
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)

    # GT instance pixels
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)

    # Calculate intersection over union (IoU)
    union = n_pixels_pred + n_pixels_true - overlap

    # Divide only where the union is non-empty; the rest stays 0 without
    # producing NaNs or divide-by-zero warnings.
    iou = np.divide(
        overlap,
        union,
        out=np.zeros(overlap.shape, dtype=np.float64),
        where=union != 0,
    )

    return iou
@jit(nopython=True)
def _label_overlap(x, y):
    """Count the pixels shared by every pair of labels in two masks.

    Parameters
    ------------
    x, y (np array; dtype int): 0=NO masks; 1,2... are mask labels

    Returns
    ------------
    overlap (np array; dtype int): Overlaps of size [x.max()+1, y.max()+1]
    """
    # Flatten both masks so they can be walked pixel by pixel
    flat_x = x.ravel()
    flat_y = y.ravel()

    # One row per label of x and one column per label of y (background included)
    n_rows = 1 + flat_x.max()
    n_cols = 1 + flat_y.max()
    overlap = np.zeros((n_rows, n_cols), dtype=np.uint)

    # Tally each co-occurring (x label, y label) pair
    for k in range(len(flat_x)):
        overlap[flat_x[k], flat_y[k]] += 1

    return overlap