Source code for train_tools.data_utils.transforms

from .custom import *

from monai.transforms import *

__all__ = [
    "train_transforms",
    "public_transforms",
    "valid_transforms",
    "tuning_transforms",
    "unlabeled_transforms",
]

[docs] train_transforms = Compose( [ # >>> Load and refine data --- img: (H, W, 3); label: (H, W) CustomLoadImaged(keys=["img", "label"], image_only=True), CustomNormalizeImaged( keys=["img"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1), RemoveRepeatedChanneld(keys=["label"], repeats=3), # label: (H, W) ScaleIntensityd(keys=["img"], allow_missing_keys=True), # Do not scale label # >>> Spatial transforms RandZoomd( keys=["img", "label"], prob=0.5, min_zoom=0.25, max_zoom=1.5, mode=["area", "nearest"], keep_size=False, ), SpatialPadd(keys=["img", "label"], spatial_size=512), RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False), RandAxisFlipd(keys=["img", "label"], prob=0.5), RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]), IntensityDiversification(keys=["img", "label"], allow_missing_keys=True), # # >>> Intensity transforms RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1), RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)), RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)), RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3), RandGaussianSharpend(keys=["img"], prob=0.25), EnsureTyped(keys=["img", "label"]), ] )
[docs] public_transforms = Compose( [ CustomLoadImaged(keys=["img", "label"], image_only=True), BoundaryExclusion(keys=["label"]), CustomNormalizeImaged( keys=["img"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1), RemoveRepeatedChanneld(keys=["label"], repeats=3), # label: (H, W) ScaleIntensityd(keys=["img"], allow_missing_keys=True), # Do not scale label # >>> Spatial transforms SpatialPadd(keys=["img", "label"], spatial_size=512), RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False), RandAxisFlipd(keys=["img", "label"], prob=0.5), RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]), Rotate90d(k=1, keys=["label"], spatial_axes=(0, 1)), Flipd(keys=["label"], spatial_axis=0), EnsureTyped(keys=["img", "label"]), ] )
[docs] valid_transforms = Compose( [ CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True), CustomNormalizeImaged( keys=["img"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1), RemoveRepeatedChanneld(keys=["label"], repeats=3), ScaleIntensityd(keys=["img"], allow_missing_keys=True), EnsureTyped(keys=["img", "label"], allow_missing_keys=True), ] )
[docs] tuning_transforms = Compose( [ CustomLoadImaged(keys=["img"], image_only=True), CustomNormalizeImaged( keys=["img"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), EnsureChannelFirstd(keys=["img"], channel_dim=-1), ScaleIntensityd(keys=["img"]), EnsureTyped(keys=["img"]), ] )
[docs] unlabeled_transforms = Compose( [ # >>> Load and refine data --- img: (H, W, 3); label: (H, W) CustomLoadImaged(keys=["img"], image_only=True), CustomNormalizeImaged( keys=["img"], allow_missing_keys=True, channel_wise=False, percentiles=[0.0, 99.5], ), EnsureChannelFirstd(keys=["img"], channel_dim=-1), RandZoomd( keys=["img"], prob=0.5, min_zoom=0.25, max_zoom=1.25, mode=["area"], keep_size=False, ), ScaleIntensityd(keys=["img"], allow_missing_keys=True), # Do not scale label # >>> Spatial transforms SpatialPadd(keys=["img"], spatial_size=512), RandSpatialCropd(keys=["img"], roi_size=512, random_size=False), EnsureTyped(keys=["img"]), ] )
def get_pred_transforms(): """Prediction preprocessing""" pred_transforms = Compose( [ # >>> Load and refine data CustomLoadImage(image_only=True), CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]), EnsureChannelFirst(channel_dim=-1), # image: (3, H, W) ScaleIntensity(), EnsureType(data_type="tensor"), ] ) return pred_transforms