import numpy as np
from skimage import exposure
from monai.config import KeysCollection
from monai.transforms.transform import Transform
from monai.transforms.compose import MapTransform
from typing import Dict, Hashable, Mapping
# Public names exported by ``from <module> import *``: the array-based
# transform, followed by the dictionary-based transform and its aliases.
__all__ = [
    "CustomNormalizeImage",
    "CustomNormalizeImageD",
    "CustomNormalizeImageDict",
    "CustomNormalizeImaged",
]
class CustomNormalizeImage(Transform):
    """Rescale image intensities to ``uint8`` using a percentile window.

    The intensity window is taken from the ``lower``/``upper`` percentiles of
    the image's *non-zero* values, so a zero-valued background does not drag
    the window down. ``skimage.exposure.rescale_intensity`` then maps that
    window onto the ``uint8`` range ``[0, 255]``.

    Args:
        percentiles: ``(lower, upper)`` percentiles (0-100) defining the input
            intensity window. Defaults to ``(0, 99.5)``.
        channel_wise: if ``True``, compute a separate window for each slice
            along the last axis. Otherwise, compute one window for the whole
            array.
    """

    def __init__(self, percentiles=(0, 99.5), channel_wise=False):
        # A tuple default avoids the shared-mutable-default pitfall; any
        # two-element sequence is still accepted.
        self.lower, self.upper = percentiles
        self.channel_wise = channel_wise

    def _normalize(self, img) -> np.ndarray:
        """Rescale ``img`` to uint8 using percentiles of its non-zero values.

        Returns an all-zero uint8 array of the same shape when ``img`` has no
        non-zero values. In that case no percentile window can be computed,
        and ``np.percentile`` would fail on the empty selection.
        """
        non_zero_vals = img[np.nonzero(img)]
        if non_zero_vals.size == 0:
            return np.zeros(img.shape, dtype=np.uint8)
        low, high = np.percentile(non_zero_vals, [self.lower, self.upper])
        img_norm = exposure.rescale_intensity(
            img, in_range=(low, high), out_range="uint8"
        )
        return img_norm.astype(np.uint8)

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Return a uint8 copy of ``img`` normalized as configured."""
        if not self.channel_wise:
            return self._normalize(img)
        # Channels are taken along the last axis, matching the loop bound.
        # ``[..., i]`` keeps indexing consistent with ``shape[-1]`` for any
        # rank, whereas ``[:, :, i]`` only did so for 3-D input.
        out = np.zeros(img.shape, dtype=np.uint8)
        for i in range(img.shape[-1]):
            # All-zero channels come back as zeros from _normalize.
            out[..., i] = self._normalize(img[..., i])
        return out
class CustomNormalizeImaged(MapTransform):
    """Dictionary-based wrapper of :class:`CustomNormalizeImage`.

    Applies the same percentile-based uint8 normalization to each entry of
    the input mapping selected by ``keys``.

    Args:
        keys: keys of the entries to normalize.
        percentiles: ``(lower, upper)`` percentiles forwarded to
            :class:`CustomNormalizeImage`. Note that this default,
            ``(1, 99)``, differs from the array transform's ``(0, 99.5)``.
        channel_wise: forwarded to :class:`CustomNormalizeImage`.
        allow_missing_keys: if ``True``, keys absent from the input are
            skipped instead of raising ``KeyError``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        percentiles=(1, 99),
        channel_wise: bool = False,
        allow_missing_keys: bool = False,
    ):
        # Zero-argument super() avoids depending on the module-level alias
        # ``CustomNormalizeImageD``, which is only bound after this class.
        super().__init__(keys, allow_missing_keys)
        self.normalizer = CustomNormalizeImage(percentiles, channel_wise)

    def __call__(
        self, data: Mapping[Hashable, np.ndarray]
    ) -> Dict[Hashable, np.ndarray]:
        """Return a shallow copy of ``data`` with the selected entries normalized."""
        d = dict(data)
        # key_iterator honours allow_missing_keys. Iterating self.keys
        # directly raised KeyError on a missing key even when missing keys
        # were allowed.
        for key in self.key_iterator(d):
            d[key] = self.normalizer(d[key])
        return d
# Alternative public names (both exported via __all__) bound to the same
# dictionary-based transform class.
CustomNormalizeImageD = CustomNormalizeImaged
CustomNormalizeImageDict = CustomNormalizeImaged