diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 948749606b..ec016a3235 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -13,7 +13,6 @@ import warnings from collections.abc import Callable, Sequence -from typing import Any import numpy as np import torch @@ -239,11 +238,52 @@ class MaskedDiceLoss(DiceLoss): """ - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + squared_pred: bool = False, + jaccard: bool = False, + reduction: LossReduction | str = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, + weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, + ) -> None: """ Args follow :py:class:`monai.losses.DiceLoss`. """ - super().__init__(*args, **kwargs) + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if sigmoid and softmax: + raise ValueError("Incompatible values: sigmoid=True and softmax=True.") + if other_act is not None and (sigmoid or softmax): + raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.") + + self.pre_sigmoid = sigmoid + self.pre_softmax = softmax + self.pre_other_act = other_act + + super().__init__( + include_background=include_background, + to_onehot_y=to_onehot_y, + sigmoid=False, + softmax=False, + other_act=None, + squared_pred=squared_pred, + jaccard=jaccard, + reduction=reduction, + smooth_nr=smooth_nr, + smooth_dr=smooth_dr, + batch=batch, + weight=weight, + soft_label=soft_label, + ) + self.spatial_weighted = MaskedLoss(loss=super().forward) def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: @@ -253,6 +293,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor target: the shape should be BNH[WD]. mask: the shape should B1H[WD] or 11H[WD]. """ + + if self.pre_sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.pre_softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2) + else: + input = torch.softmax(input, 1) + + if self.pre_other_act is not None: + input = self.pre_other_act(input) return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return] diff --git a/tests/losses/test_masked_dice_loss.py b/tests/losses/test_masked_dice_loss.py index c971723615..ea08254ee5 100644 --- a/tests/losses/test_masked_dice_loss.py +++ b/tests/losses/test_masked_dice_loss.py @@ -27,7 +27,7 @@ "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), "mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]), }, - 0.500, + 0.333333, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -36,7 +36,7 @@ "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), "mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]), }, - 0.422969, + 0.301128, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, @@ -54,7 +54,7 @@ "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), "mask": torch.tensor([[[1.0, 1.0, 0.0]]]), }, - 0.47033, + 0.579184, ], [ # shape: (2, 2, 3), (2, 1, 3) {