"""Source code for train_tools.models.MEDIARFormer."""
import torch
import torch.nn as nn
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation
__all__ = ["MEDIARFormer"]
class MEDIARFormer(MAnet):
    """MEDIAR-Former model.

    An MAnet encoder/decoder whose ReLU activations are swapped for Mish and
    whose stock segmentation head is replaced by two task-specific heads:
    a 1-channel cell-probability head and a 2-channel gradient-flow head.
    """

    def __init__(
        self,
        encoder_name="mit_b5",  # Default encoder
        encoder_weights="imagenet",  # Pre-trained weights
        decoder_channels=(1024, 512, 256, 128, 64),  # Decoder configuration
        decoder_pab_channels=256,  # Decoder Pyramid Attention Block channels
        in_channels=3,  # Number of input channels
        classes=3,  # Number of output classes
    ):
        # Initialize the MAnet model with provided parameters
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # Remove the default segmentation head: it is never used because
        # forward() below routes the decoder output through the two custom
        # heads instead. (The original comment promised this removal but the
        # assignment was missing.)
        self.segmentation_head = None

        # Modify all activation functions in the encoder and decoder
        # from ReLU to Mish
        _convert_activations(self.encoder, nn.ReLU, nn.Mish(inplace=True))
        _convert_activations(self.decoder, nn.ReLU, nn.Mish(inplace=True))

        # Custom segmentation heads: cell probability (1 channel) and
        # gradient flow (2 channels), both fed from the last decoder stage.
        self.cellprob_head = DeepSegmentationHead(
            in_channels=decoder_channels[-1], out_channels=1
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=decoder_channels[-1], out_channels=2
        )

    def forward(self, x):
        """Forward pass through the network.

        Returns a (B, 3, H, W) tensor: channels 0-1 come from the
        gradient-flow head, channel 2 from the cell-probability head.
        """
        # Ensure the input shape is correct (MAnet divisibility check)
        self.check_input_shape(x)

        # Encode the input and then decode it
        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # Generate masks for cell probability and gradient flows
        cellprob_mask = self.cellprob_head(decoder_output)
        gradflow_mask = self.gradflow_head(decoder_output)

        # Concatenate the masks for output (flow channels first)
        masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
        return masks
class DeepSegmentationHead(nn.Sequential):
    """Two-stage convolutional head for producing task-specific masks.

    A first conv halves the channel count (followed by Mish and BatchNorm),
    a second conv projects to ``out_channels``; optional bilinear upsampling
    and a final activation wrap the result.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        mid_channels = in_channels // 2
        pad = kernel_size // 2  # "same" padding for odd kernel sizes

        # Channel-reduction stage
        reduce_conv = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=kernel_size,
            padding=pad,
        )
        # Projection to the requested number of output channels
        project_conv = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=pad,
        )
        # Optional spatial upscaling; identity when no upsampling requested
        upsample = (
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity()
        )
        # Optional output activation; identity when none given
        final_activation = Activation(activation) if activation else nn.Identity()

        super().__init__(
            reduce_conv,
            nn.Mish(inplace=True),
            nn.BatchNorm2d(mid_channels),
            project_conv,
            upsample,
            final_activation,
        )
def _convert_activations(module, from_activation, to_activation):
"""Recursively convert activation functions in a module"""
for name, child in module.named_children():
if isinstance(child, from_activation):
setattr(module, name, to_activation)
else:
_convert_activations(child, from_activation, to_activation)