From adff16230ad4d4c4e32149eb0e72eea706b40d2e Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 16 Jan 2026 11:03:23 +0800 Subject: [PATCH 1/2] Adjust execution order of activation and masking in MaskedDiceLoss Signed-off-by: ytl0623 --- monai/losses/dice.py | 30 ++++++++++++++++++++++++++- tests/losses/test_masked_dice_loss.py | 6 +++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 948749606b..3d810bc1fe 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -244,7 +244,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: Args follow :py:class:`monai.losses.DiceLoss`. """ super().__init__(*args, **kwargs) - self.spatial_weighted = MaskedLoss(loss=super().forward) + self.dice = DiceLoss( + include_background=self.include_background, + to_onehot_y=self.to_onehot_y, + sigmoid=False, + softmax=False, + other_act=None, + squared_pred=self.squared_pred, + jaccard=self.jaccard, + reduction=self.reduction, + smooth_nr=self.smooth_nr, + smooth_dr=self.smooth_dr, + batch=self.batch, + weight=self.class_weight, + soft_label=self.soft_label, + ) + self.spatial_weighted = MaskedLoss(loss=self.dice.forward) def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ @@ -253,6 +268,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor target: the shape should be BNH[WD]. mask: the shape should B1H[WD] or 11H[WD]. """ + + if self.sigmoid: + input = torch.sigmoid(input) + + n_pred_ch = input.shape[1] + if self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + + if self.other_act is not None: + input = self.other_act(input) return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return] diff --git a/tests/losses/test_masked_dice_loss.py b/tests/losses/test_masked_dice_loss.py index c971723615..ea08254ee5 100644 --- a/tests/losses/test_masked_dice_loss.py +++ b/tests/losses/test_masked_dice_loss.py @@ -27,7 +27,7 @@ "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), "mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]), }, - 0.500, + 0.333333, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -36,7 +36,7 @@ "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), "mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]), }, - 0.422969, + 0.301128, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, @@ -54,7 +54,7 @@ "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), "mask": torch.tensor([[[1.0, 1.0, 0.0]]]), }, - 0.47033, + 0.579184, ], [ # shape: (2, 2, 3), (2, 1, 3) { From 061f0d2844355c0a5b9d79bf6af4fe2aa976ef32 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Tue, 20 Jan 2026 10:25:10 +0800 Subject: [PATCH 2/2] fix activation order and inheritance Signed-off-by: ytl0623 --- monai/losses/dice.py | 65 ++++++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 3d810bc1fe..ec016a3235 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -13,7 +13,6 @@ import warnings from collections.abc import Callable, Sequence -from typing import Any import numpy as np import torch @@ -239,27 +238,53 @@ class MaskedDiceLoss(DiceLoss): """ - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + sigmoid: bool = False, + softmax: bool = False, + other_act: Callable | None = None, + squared_pred: bool = False, + jaccard: bool = False, + reduction: LossReduction | str = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, + weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, + ) -> None: """ Args follow :py:class:`monai.losses.DiceLoss`. """ - super().__init__(*args, **kwargs) - self.dice = DiceLoss( - include_background=self.include_background, - to_onehot_y=self.to_onehot_y, + if other_act is not None and not callable(other_act): + raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") + if sigmoid and softmax: + raise ValueError("Incompatible values: sigmoid=True and softmax=True.") + if other_act is not None and (sigmoid or softmax): + raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.") + + self.pre_sigmoid = sigmoid + self.pre_softmax = softmax + self.pre_other_act = other_act + + super().__init__( + include_background=include_background, + to_onehot_y=to_onehot_y, sigmoid=False, softmax=False, other_act=None, - squared_pred=self.squared_pred, - jaccard=self.jaccard, - reduction=self.reduction, - smooth_nr=self.smooth_nr, - smooth_dr=self.smooth_dr, - batch=self.batch, - weight=self.class_weight, - soft_label=self.soft_label, + squared_pred=squared_pred, + jaccard=jaccard, + reduction=reduction, + smooth_nr=smooth_nr, + smooth_dr=smooth_dr, + batch=batch, + weight=weight, + soft_label=soft_label, ) - self.spatial_weighted = MaskedLoss(loss=self.dice.forward) + + self.spatial_weighted = MaskedLoss(loss=super().forward) def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """ @@ -269,18 +294,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor mask: the shape should B1H[WD] or 11H[WD]. """ - if self.sigmoid: + if self.pre_sigmoid: input = torch.sigmoid(input) n_pred_ch = input.shape[1] - if self.softmax: + if self.pre_softmax: if n_pred_ch == 1: - warnings.warn("single channel prediction, `softmax=True` ignored.") + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2) else: input = torch.softmax(input, 1) - if self.other_act is not None: - input = self.other_act(input) + if self.pre_other_act is not None: + input = self.pre_other_act(input) return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]