Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API #8703
base: dev
```diff
@@ -11,10 +11,17 @@
 from __future__ import annotations
 
+import warnings
+from collections.abc import Callable
+
 import torch
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.dice import DiceLoss
+from monai.networks import one_hot
+from monai.utils import LossReduction
+
 
 def soft_erode(img: torch.Tensor) -> torch.Tensor:  # type: ignore
     """
```
```diff
@@ -92,26 +99,6 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
     return skel
 
 
-def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
-    """
-    Function to compute soft dice loss
-
-    Adapted from:
-        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22
-
-    Args:
-        y_true: the shape should be BCH(WD)
-        y_pred: the shape should be BCH(WD)
-
-    Returns:
-        dice loss
-    """
-    intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
-    coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
-    soft_dice: torch.Tensor = 1.0 - coeff
-    return soft_dice
-
-
 class SoftclDiceLoss(_Loss):
     """
     Compute the Soft clDice loss defined in:
```
```diff
@@ -121,64 +108,245 @@ class SoftclDiceLoss(_Loss):
     Adapted from:
         https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7
 
+    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
+    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,
+    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`
+    can be 1 or N (one-hot format).
+
     """
 
-    def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
+    def __init__(
+        self,
+        iter_: int = 3,
+        smooth: float = 1.0,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ) -> None:
         """
         Args:
-            iter_: Number of iterations for skeletonization
-            smooth: Smoothing parameter
+            iter_: Number of iterations for skeletonization.
+            smooth: Smoothing parameter to avoid division by zero. Defaults to 1.0.
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                if the non-background segmentations are small compared to the total image size they can get overwhelmed
+                by the signal from the background so excluding it in such cases helps convergence.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
+                Incompatible values.
+
         """
-        super().__init__()
+        super().__init__(reduction=LossReduction(reduction).value)
+        if other_act is not None and not callable(other_act):
+            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
+        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
+        if smooth <= 0:
+            raise ValueError(f"smooth must be a positive value but got {smooth}.")
         self.iter = iter_
         self.smooth = smooth
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.sigmoid = sigmoid
+        self.softmax = softmax
+        self.other_act = other_act
 
-    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
-        skel_pred = soft_skel(y_pred, self.iter)
-        skel_true = soft_skel(y_true, self.iter)
-        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
-        )
-        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_true[:, 1:, ...]) + self.smooth
-        )
-        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input: the shape should be BNH[WD], where N is the number of classes.
+            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+
+        Raises:
+            AssertionError: When input and target (after one hot transform if set)
+                have different shapes.
+
+        """
+        n_pred_ch = input.shape[1]
+
+        if self.sigmoid:
+            input = torch.sigmoid(input)
+
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, dim=1)
+
+        if self.other_act is not None:
+            input = self.other_act(input)
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                target = target[:, 1:]
+                input = input[:, 1:]
+
+        if target.shape != input.shape:
+            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
+
+        skel_pred = soft_skel(input, self.iter)
+        skel_true = soft_skel(target, self.iter)
+
+        # Compute per-batch clDice by reducing over channel and spatial dimensions
+        # reduce_axis includes all dimensions except batch (dim 0)
+        reduce_axis: list[int] = list(range(1, len(input.shape)))
+
+        tprec = (torch.sum(torch.multiply(skel_pred, target), dim=reduce_axis) + self.smooth) / (
+            torch.sum(skel_pred, dim=reduce_axis) + self.smooth
+        )
+        tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth) / (
+            torch.sum(skel_true, dim=reduce_axis) + self.smooth
+        )
+        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-8)
```
Member: Instead of hard-coding …
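The comment above is cut off in the page, but it is anchored on the hard-coded `1e-8` stabilizer added in the line it annotates. One way to avoid a fixed constant (an illustrative sketch, not the reviewer's verbatim suggestion) is to derive the epsilon from the input dtype:

```python
import torch

# Sketch: replace the fixed 1e-8 with a machine epsilon tied to the tensors'
# floating-point type, so e.g. float16 inputs keep a meaningful stabilizer.
# `tprec` and `tsens` stand in for the tensors computed in forward() above.
def cl_dice_from_rates(tprec: torch.Tensor, tsens: torch.Tensor) -> torch.Tensor:
    eps = torch.finfo(tprec.dtype).eps  # dtype-aware epsilon instead of 1e-8
    return 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + eps)
```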
```diff
+
+        # Apply reduction
+        if self.reduction == LossReduction.MEAN.value:
+            cl_dice = torch.mean(cl_dice)
+        elif self.reduction == LossReduction.SUM.value:
+            cl_dice = torch.sum(cl_dice)
+        elif self.reduction == LossReduction.NONE.value:
+            pass  # keep per-batch values
+        else:
+            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
 
         return cl_dice
 
 
 class SoftDiceclDiceLoss(_Loss):
     """
-    Compute the Soft clDice loss defined in:
+    Compute both Dice loss and clDice loss, and return the weighted sum of these two losses.
+    The details of Dice loss is shown in ``monai.losses.DiceLoss``.
+    The details of clDice loss is shown in ``monai.losses.SoftclDiceLoss``.
 
+    Adapted from:
         Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
         for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)
 
     Adapted from:
         https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
     """
 
-    def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
+    def __init__(
+        self,
+        iter_: int = 3,
+        alpha: float = 0.5,
+        smooth: float = 1.0,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        reduction: LossReduction | str = LossReduction.MEAN,
+    ) -> None:
         """
         Args:
-            iter_: Number of iterations for skeletonization
-            smooth: Smoothing parameter
-            alpha: Weighing factor for cldice
+            iter_: Number of iterations for skeletonization, used by clDice.
+            alpha: Weighing factor for cldice component. Total loss = (1 - alpha) * dice + alpha * cldice.
+                Defaults to 0.5.
+            smooth: Smoothing parameter to avoid division by zero, used by both Dice and clDice. Defaults to 1.0.
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+                if the non-background segmentations are small compared to the total image size they can get overwhelmed
+                by the signal from the background so excluding it in such cases helps convergence.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+        Raises:
+            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
+            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
+                Incompatible values.
+
         """
         super().__init__()
-        self.iter = iter_
-        self.smooth = smooth
-        self.alpha = alpha
-
-    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
-        dice = soft_dice(y_true, y_pred, self.smooth)
-        skel_pred = soft_skel(y_pred, self.iter)
-        skel_true = soft_skel(y_true, self.iter)
-        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_pred[:, 1:, ...]) + self.smooth
-        )
-        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
-            torch.sum(skel_true[:, 1:, ...]) + self.smooth
-        )
-        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
-        total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
+        if smooth <= 0:
+            raise ValueError(f"smooth must be a positive value but got {smooth}.")
+        self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=False,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            other_act=other_act,
+            reduction=reduction,
+            smooth_nr=smooth,
+            smooth_dr=smooth,
+        )
+        self.cldice = SoftclDiceLoss(
+            iter_=iter_,
+            smooth=smooth,
+            include_background=include_background,
+            to_onehot_y=False,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            other_act=other_act,
+            reduction=reduction,
+        )
+        self.alpha = alpha
+        self.to_onehot_y = to_onehot_y
 
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
```
Member: Same with the names here.
| """ | ||
| Args: | ||
| input: the shape should be BNH[WD], where N is the number of classes. | ||
| target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. | ||
|
|
||
| Raises: | ||
| ValueError: When number of dimensions for input and target are different. | ||
| ValueError: When number of channels for target is neither 1 nor the same as input. | ||
|
|
||
| """ | ||
| if input.dim() != target.dim(): | ||
| raise ValueError( | ||
| "the number of dimensions for input and target should be the same, " | ||
| f"got shape {input.shape} and {target.shape}." | ||
| ) | ||
|
|
||
| if target.shape[1] != 1 and target.shape[1] != input.shape[1]: | ||
| raise ValueError( | ||
| "number of channels for target is neither 1 nor the same as input, " | ||
| f"got shape {input.shape} and {target.shape}." | ||
| ) | ||
|
|
||
| if self.to_onehot_y: | ||
| n_pred_ch = input.shape[1] | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||
| else: | ||
| target = one_hot(target, num_classes=n_pred_ch) | ||
|
|
||
| dice_loss = self.dice(input, target) | ||
| cldice_loss = self.cldice(input, target) | ||
| total_loss: torch.Tensor = (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss | ||
|
|
||
| return total_loss | ||
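For orientation, a minimal usage sketch of the extended, DiceLoss-compatible API as it appears in this diff. Parameter names are taken from the signatures above; the import path, tensor shapes, and values are illustrative assumptions, not part of the PR:

```python
import torch

# Assumes both classes remain exported from monai.losses after this PR.
from monai.losses import SoftDiceclDiceLoss

logits = torch.randn(2, 3, 32, 32)            # BNHW logits for 3 classes
labels = torch.randint(0, 3, (2, 1, 32, 32))  # B1HW integer labels

loss_fn = SoftDiceclDiceLoss(
    iter_=3,
    alpha=0.5,                 # total = 0.5 * dice + 0.5 * cldice
    softmax=True,              # convert logits to probabilities along axis 1
    to_onehot_y=True,          # expand B1HW labels to BNHW one-hot
    include_background=False,  # drop channel 0 from both losses
    reduction="mean",
)
loss = loss_fn(logits, labels)  # scalar tensor with reduction="mean"
```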
Review comment: Missing validation for the `iter_` parameter. `smooth` is validated but `iter_` is not; a non-positive value would produce incorrect skeletonization.

Proposed fix:

```diff
         if smooth <= 0:
             raise ValueError(f"smooth must be a positive value but got {smooth}.")
+        if iter_ < 0:
+            raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
         self.iter = iter_
```

🪛 Ruff (0.14.11): 162-162: Avoid specifying long messages outside the exception class (TRY003)
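If the proposed validation were adopted, a small regression test along these lines could cover both checks (a sketch under that assumption, not part of this PR):

```python
import pytest

from monai.losses import SoftclDiceLoss


def test_invalid_iter_raises():
    # Assumes the proposed `iter_ < 0` validation above is merged.
    with pytest.raises(ValueError):
        SoftclDiceLoss(iter_=-1)


def test_invalid_smooth_raises():
    # The `smooth <= 0` validation is already part of this diff.
    with pytest.raises(ValueError):
        SoftclDiceLoss(smooth=0.0)
```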