From 4887d9d53d26e0bf4907ec20a1550a1eff5b78d4 Mon Sep 17 00:00:00 2001
From: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
Date: Thu, 19 Feb 2026 12:16:44 -0500
Subject: [PATCH 1/2] Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss (#4650)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while
`wasserstein_distance_map` has shape (B, S). When batch size > 1, the
element-wise multiply broadcasts to (B, B, S), mixing values across samples.
Fixed by squeezing dim=1 after the gather in both
`_compute_generalized_true_positive` and `_compute_denominator`, and by
reducing with `dim=1` instead of `dim=[1, 2]`.

Also fixed the `reduction="none"` code path, which incorrectly tried to
reshape the per-sample loss tensor of shape (B,) to (B, C, 1, ...); GWDL
aggregates over classes internally, so the class dimension does not apply.

Added regression tests that verify batch consistency:
- identical samples in a batch produce the same loss as a single sample
- batched per-sample losses match individually computed losses

Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
---
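Note (illustration only, not part of the diff): a minimal sketch of the broadcast
described above. The shapes B=2 and S=4 and the variable names are made up for
the example; they only mirror what `_compute_generalized_true_positive` works with.

import torch

B, S = 2, 4  # hypothetical batch size and number of flattened voxels
alpha_gathered = torch.rand(B, 1, S)  # shape after torch.gather along dim=1
wasserstein_map = torch.rand(B, S)    # per-voxel Wasserstein distance map

# Old behaviour: (B, 1, S) * (B, S) broadcasts to (B, B, S), so sample i picks up
# alpha values gathered from every sample j, and sum(dim=[1, 2]) then collapses
# those mixed values into a single number per batch entry.
mixed = alpha_gathered * (1.0 - wasserstein_map)
print(mixed.shape)  # torch.Size([2, 2, 4])

# Fixed behaviour: squeeze dim=1 first so both operands are (B, S); the reduction
# over dim=1 then stays within each sample.
per_sample = torch.sum(alpha_gathered.squeeze(1) * (1.0 - wasserstein_map), dim=1)
print(per_sample.shape)  # torch.Size([2])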
 monai/losses/dice.py                          | 12 +--
 .../test_generalized_wasserstein_dice_loss.py | 78 +++++++++++++++++++
 2 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 948749606b..72ee1f27ee 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -548,10 +548,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         elif self.reduction == LossReduction.SUM.value:
             wass_dice_loss = torch.sum(wass_dice_loss)  # sum over the batch and channel dims
         elif self.reduction == LossReduction.NONE.value:
-            # If we are not computing voxelwise loss components at least
-            # make sure a none reduction maintains a broadcastable shape
-            broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
-            wass_dice_loss = wass_dice_loss.view(broadcast_shape)
+            # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
 
@@ -609,8 +607,9 @@ def _compute_generalized_true_positive(
         alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
         flat_target_extended = torch.unsqueeze(flat_target, dim=1)
         alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        alpha_extended = torch.squeeze(alpha_extended, dim=1)
 
-        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2])
+        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
 
     def _compute_denominator(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
@@ -626,8 +625,9 @@ def _compute_denominator(
         alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
         flat_target_extended = torch.unsqueeze(flat_target, dim=1)
         alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        alpha_extended = torch.squeeze(alpha_extended, dim=1)
 
-        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2])
+        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
 
     def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:
         """
diff --git a/tests/losses/test_generalized_wasserstein_dice_loss.py b/tests/losses/test_generalized_wasserstein_dice_loss.py
index 6868c04775..9e10b252e3 100644
--- a/tests/losses/test_generalized_wasserstein_dice_loss.py
+++ b/tests/losses/test_generalized_wasserstein_dice_loss.py
@@ -218,6 +218,84 @@ def forward(self, x):
         # check that the predicted segmentation has improved
         self.assertGreater(diff_start, diff_end)
 
+    def test_batch_size_greater_than_one(self):
+        """
+        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
+        With M=identity and batch_size > 1, the GWDL should produce the same
+        per-sample loss values as with batch_size=1.
+        """
+        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
+        pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()
+
+        # Create a batch of size 2 by repeating the same sample
+        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
+        pred_batch = pred_single.repeat(2, 1, 1, 1)  # shape (2, C, H, W)
+
+        for w_mode in ["default", "GDL"]:
+            loss_fn = GeneralizedWassersteinDiceLoss(
+                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
+            )
+
+            loss_single = loss_fn(pred_single, target_single)
+            loss_batch = loss_fn(pred_batch, target_batch)
+
+            # Each sample in the batch should produce the same loss as the single sample
+            for i in range(2):
+                self.assertAlmostEqual(
+                    float(loss_batch[i]),
+                    float(loss_single[0]),
+                    places=5,
+                    msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
+                )
+
+        # Also test with mean reduction: batch loss should equal single-sample loss
+        for w_mode in ["default", "GDL"]:
+            loss_fn = GeneralizedWassersteinDiceLoss(
+                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean"
+            )
+
+            loss_single = float(loss_fn(pred_single, target_single))
+            loss_batch = float(loss_fn(pred_batch, target_batch))
+
+            self.assertAlmostEqual(
+                loss_batch,
+                loss_single,
+                places=5,
+                msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}",
+            )
+
+    def test_batch_size_different_samples(self):
+        """
+        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
+        Verify loss is computed correctly when batch contains different samples.
+ """ + target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0) + target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0) + + pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float() + pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float() + + # Combine into a batch + target_batch = torch.cat([target_a, target_b], dim=0) + pred_batch = torch.cat([pred_a, pred_b], dim=0) + + for w_mode in ["default", "GDL"]: + loss_fn = GeneralizedWassersteinDiceLoss( + dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none" + ) + + loss_a = float(loss_fn(pred_a, target_a)) + loss_b = float(loss_fn(pred_b, target_b)) + loss_batch = loss_fn(pred_batch, target_batch) + + self.assertAlmostEqual( + float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}" + ) + self.assertAlmostEqual( + float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}" + ) + def test_script(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) From d8b8b019c7cbcc17b6931616ca655a0009996b31 Mon Sep 17 00:00:00 2001 From: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:38:32 -0500 Subject: [PATCH 2/2] Strengthen regression tests with non-trivial loss assertions Address review feedback: use poor predictions in mean-reduction and different-samples tests so the expected loss values are non-trivial (~1.0 instead of ~0.0), ensuring the assertions are meaningful. Signed-off-by: Jeffrey Qiu Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com> --- .../test_generalized_wasserstein_dice_loss.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/losses/test_generalized_wasserstein_dice_loss.py b/tests/losses/test_generalized_wasserstein_dice_loss.py index 9e10b252e3..1ed05c137d 100644 --- a/tests/losses/test_generalized_wasserstein_dice_loss.py +++ b/tests/losses/test_generalized_wasserstein_dice_loss.py @@ -249,15 +249,21 @@ def test_batch_size_greater_than_one(self): msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}", ) - # Also test with mean reduction: batch loss should equal single-sample loss + # Also test with mean reduction using a non-trivial (poor) prediction + # so the expected loss is not near zero + pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float() + pred_poor_batch = pred_poor.repeat(2, 1, 1, 1) + for w_mode in ["default", "GDL"]: loss_fn = GeneralizedWassersteinDiceLoss( dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean" ) - loss_single = float(loss_fn(pred_single, target_single)) - loss_batch = float(loss_fn(pred_batch, target_batch)) + loss_single = float(loss_fn(pred_poor, target_single)) + loss_batch = float(loss_fn(pred_poor_batch, target_batch)) + # Verify the loss is non-trivial (close to 1 for poor predictions) + self.assertGreater(loss_single, 0.5, msg=f"Expected non-trivial loss for weighting_mode={w_mode}") self.assertAlmostEqual( loss_batch, loss_single, @@ -274,7 +280,8 @@ def test_batch_size_different_samples(self): target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0) pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float() - pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float() + 
 .../test_generalized_wasserstein_dice_loss.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tests/losses/test_generalized_wasserstein_dice_loss.py b/tests/losses/test_generalized_wasserstein_dice_loss.py
index 9e10b252e3..1ed05c137d 100644
--- a/tests/losses/test_generalized_wasserstein_dice_loss.py
+++ b/tests/losses/test_generalized_wasserstein_dice_loss.py
@@ -249,15 +249,21 @@ def test_batch_size_greater_than_one(self):
                     msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
                 )
 
-        # Also test with mean reduction: batch loss should equal single-sample loss
+        # Also test with mean reduction using a non-trivial (poor) prediction
+        # so the expected loss is not near zero
+        pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
+        pred_poor_batch = pred_poor.repeat(2, 1, 1, 1)
+
         for w_mode in ["default", "GDL"]:
             loss_fn = GeneralizedWassersteinDiceLoss(
                 dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean"
             )
 
-            loss_single = float(loss_fn(pred_single, target_single))
-            loss_batch = float(loss_fn(pred_batch, target_batch))
+            loss_single = float(loss_fn(pred_poor, target_single))
+            loss_batch = float(loss_fn(pred_poor_batch, target_batch))
 
+            # Verify the loss is non-trivial (close to 1 for poor predictions)
+            self.assertGreater(loss_single, 0.5, msg=f"Expected non-trivial loss for weighting_mode={w_mode}")
             self.assertAlmostEqual(
                 loss_batch,
                 loss_single,
@@ -274,7 +280,8 @@ def test_batch_size_different_samples(self):
         target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)
 
         pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()
-        pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float()
+        # Use a poor prediction for sample b so its loss is non-trivial (~1.0)
+        pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()
 
         # Combine into a batch
         target_batch = torch.cat([target_a, target_b], dim=0)
         pred_batch = torch.cat([pred_a, pred_b], dim=0)