diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 948749606b..72ee1f27ee 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -548,10 +548,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: elif self.reduction == LossReduction.SUM.value: wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims elif self.reduction == LossReduction.NONE.value: - # If we are not computing voxelwise loss components at least - # make sure a none reduction maintains a broadcastable shape - broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2) - wass_dice_loss = wass_dice_loss.view(broadcast_shape) + # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,) + pass else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') @@ -609,8 +607,9 @@ def _compute_generalized_true_positive( alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) + alpha_extended = torch.squeeze(alpha_extended, dim=1) - return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2]) + return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1) def _compute_denominator( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor @@ -626,8 +625,9 @@ def _compute_denominator( alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) + alpha_extended = torch.squeeze(alpha_extended, dim=1) - return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2]) + return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1) def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor: """ diff --git a/tests/losses/test_generalized_wasserstein_dice_loss.py b/tests/losses/test_generalized_wasserstein_dice_loss.py index 6868c04775..1ed05c137d 100644 --- a/tests/losses/test_generalized_wasserstein_dice_loss.py +++ b/tests/losses/test_generalized_wasserstein_dice_loss.py @@ -218,6 +218,91 @@ def forward(self, x): # check that the predicted segmentation has improved self.assertGreater(diff_start, diff_end) + def test_batch_size_greater_than_one(self): + """ + Regression test for https://github.com/Project-MONAI/MONAI/issues/4650 + With M=identity and batch_size > 1, the GWDL should produce the same + per-sample loss values as with batch_size=1. + """ + target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) + target_single = target_single.unsqueeze(0) # shape (1, H, W) + pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float() + + # Create a batch of size 2 by repeating the same sample + target_batch = target_single.repeat(2, 1, 1) # shape (2, H, W) + pred_batch = pred_single.repeat(2, 1, 1, 1) # shape (2, C, H, W) + + for w_mode in ["default", "GDL"]: + loss_fn = GeneralizedWassersteinDiceLoss( + dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none" + ) + + loss_single = loss_fn(pred_single, target_single) + loss_batch = loss_fn(pred_batch, target_batch) + + # Each sample in the batch should produce the same loss as the single sample + for i in range(2): + self.assertAlmostEqual( + float(loss_batch[i]), + float(loss_single[0]), + places=5, + msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}", + ) + + # Also test with mean reduction using a non-trivial (poor) prediction + # so the expected loss is not near zero + pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float() + pred_poor_batch = pred_poor.repeat(2, 1, 1, 1) + + for w_mode in ["default", "GDL"]: + loss_fn = GeneralizedWassersteinDiceLoss( + dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean" + ) + + loss_single = float(loss_fn(pred_poor, target_single)) + loss_batch = float(loss_fn(pred_poor_batch, target_batch)) + + # Verify the loss is non-trivial (close to 1 for poor predictions) + self.assertGreater(loss_single, 0.5, msg=f"Expected non-trivial loss for weighting_mode={w_mode}") + self.assertAlmostEqual( + loss_batch, + loss_single, + places=5, + msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}", + ) + + def test_batch_size_different_samples(self): + """ + Regression test for https://github.com/Project-MONAI/MONAI/issues/4650 + Verify loss is computed correctly when batch contains different samples. + """ + target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0) + target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0) + + pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float() + # Use a poor prediction for sample b so its loss is non-trivial (~1.0) + pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float() + + # Combine into a batch + target_batch = torch.cat([target_a, target_b], dim=0) + pred_batch = torch.cat([pred_a, pred_b], dim=0) + + for w_mode in ["default", "GDL"]: + loss_fn = GeneralizedWassersteinDiceLoss( + dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none" + ) + + loss_a = float(loss_fn(pred_a, target_a)) + loss_b = float(loss_fn(pred_b, target_b)) + loss_batch = loss_fn(pred_batch, target_batch) + + self.assertAlmostEqual( + float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}" + ) + self.assertAlmostEqual( + float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}" + ) + def test_script(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])