"""
Adapted from the following references:
[1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/model_training_3class.py
"""
import torch
import numpy as np
from skimage import segmentation, morphology, measure
import monai
# Public API of this module (controls `from <module> import *`).
__all__ = [
    "create_interior_onehot",
    "identify_instances_from_classmap",
]
@torch.no_grad()
def identify_instances_from_classmap(
    class_map,
    cell_class=1,
    threshold=0.5,
    from_logits=True,
    *,
    min_object_size=16,
    max_hole_area=None,
):
    """Identify individual cell instances from a per-class score map.

    The ``cell_class`` channel of ``class_map`` is thresholded into a binary
    foreground mask. The mask is cleaned with morphological post-processing,
    then split into connected components that each get their own integer label.

    Args:
        class_map: Class scores indexed along the first axis, e.g. a
            ``(C, H, W)`` tensor. Non-tensor inputs such as NumPy arrays are
            converted with ``torch.as_tensor``.
        cell_class: Index of the channel holding the cell class.
        threshold: Pixels whose value in the ``cell_class`` channel is strictly
            greater than this become foreground.
        from_logits: If True, apply a softmax over the first axis before
            thresholding, so ``threshold`` is compared against probabilities.
        min_object_size: Connected foreground components smaller than this
            are discarded. Passed to
            ``skimage.morphology.remove_small_objects``; defaults to 16.
        max_hole_area: Area threshold forwarded to
            ``skimage.morphology.remove_small_holes``. ``None`` keeps
            scikit-image's own default.

    Returns:
        np.ndarray: Label image from ``skimage.measure.label`` with the same
        shape as one channel of ``class_map``. 0 is background and each
        connected foreground component has its own positive label.
    """
    class_map = torch.as_tensor(class_map)
    if from_logits:
        # Normalize over the class axis so `threshold` applies to probabilities.
        class_map = torch.softmax(class_map, dim=0)  # (C, H, W)

    # Binary foreground mask for the requested class.
    fg_mask = class_map[cell_class].cpu().numpy() > threshold

    # Fill small holes inside cells, then drop specks too small to be cells.
    hole_kwargs = {} if max_hole_area is None else {"area_threshold": max_hole_area}
    fg_mask = morphology.remove_small_holes(fg_mask, connectivity=1, **hole_kwargs)
    fg_mask = morphology.remove_small_objects(fg_mask, min_object_size)

    # Each connected component becomes a separate instance id.
    return measure.label(fg_mask)
@torch.no_grad()
def create_interior_onehot(inst_maps):
    """Convert a batch of instance label maps into one-hot 3-class targets.

    Each instance map becomes a three-class map:

    * 0: background
    * 1: interior, i.e. instance pixels not on any (thickened) boundary
    * 2: boundary, i.e. inner instance boundaries from
      ``segmentation.find_boundaries(mode="inner")`` dilated with
      ``morphology.disk(1)``. The dilation may also mark background pixels
      next to an instance.

    Interior fragments smaller than 16 pixels are reset to background.

    Args:
        inst_maps: Tensor of instance label maps. Presumably ``(B, 1, H, W)``
            (axis 1 is squeezed away when it has size 1), with 0 as
            background and a distinct positive id per instance. TODO confirm
            against callers.

    Returns:
        torch.Tensor: ``monai.networks.one_hot`` of the ``(B, 1, H, W)``
        class map with 3 classes, i.e. shape ``(B, 3, H, W)``. It is on the
        same device as ``inst_maps``.
    """
    device = inst_maps.device
    # Cast to int64, not int16: with int16, instance ids above 32767 wrap to
    # negative values and are then misread as background by the `> 0` test.
    label_maps = inst_maps.squeeze(1).cpu().numpy().astype(np.int64)

    # Structuring element used to thicken boundaries (loop-invariant).
    footprint = morphology.disk(1)

    # Preallocated (B, H, W) output. This also works for an empty batch,
    # where np.stack of an empty list would raise.
    interior_maps = np.zeros(label_maps.shape, dtype=np.uint8)
    for interior, inst_map in zip(interior_maps, label_maps):
        # `interior` is a view into `interior_maps`, so writes land in place.
        boundary = segmentation.find_boundaries(inst_map, mode="inner")
        boundary = morphology.binary_dilation(boundary, footprint)

        # Interior = foreground pixels that are not on a (dilated) boundary.
        interior_mask = np.logical_and(~boundary, inst_map > 0)
        interior_mask = morphology.remove_small_objects(interior_mask, min_size=16)

        interior[interior_mask] = 1
        interior[boundary] = 2

    # (B, H, W) uint8 -> (B, 1, H, W) on the original device.
    interior_maps = torch.from_numpy(interior_maps).unsqueeze(1).to(device)
    # One-hot encode along the channel axis: (B, 3, H, W).
    return monai.networks.one_hot(interior_maps, num_classes=3)