Skip to content

Commit 063df92

Browse files
committed
Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss (#4650)
After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while `wasserstein_distance_map` has shape (B, S). When batch size > 1 the element-wise multiply broadcasts to (B, B, S), mixing values across samples. Fixed by squeezing dim=1 after gather in both `_compute_generalized_true_positive` and `_compute_denominator`, and reducing with `dim=1` instead of `dim=[1, 2]`. Also fixed the `reduction="none"` code path which incorrectly tried to reshape the per-sample loss tensor (B,) to (B, C, 1, ...) — GWDL aggregates over classes internally so the class dimension doesn't apply. Added regression tests that verify batch consistency: - identical samples in a batch produce the same loss as a single sample - batched per-sample losses match individually computed losses Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
1 parent 2f10e18 commit 063df92

2 files changed

Lines changed: 96 additions & 6 deletions

File tree

monai/losses/dice.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -548,10 +548,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
548548
elif self.reduction == LossReduction.SUM.value:
549549
wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims
550550
elif self.reduction == LossReduction.NONE.value:
551-
# If we are not computing voxelwise loss components at least
552-
# make sure a none reduction maintains a broadcastable shape
553-
broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
554-
wass_dice_loss = wass_dice_loss.view(broadcast_shape)
551+
# GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
552+
pass
555553
else:
556554
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
557555

@@ -609,8 +607,9 @@ def _compute_generalized_true_positive(
609607
alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
610608
flat_target_extended = torch.unsqueeze(flat_target, dim=1)
611609
alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
610+
alpha_extended = torch.squeeze(alpha_extended, dim=1)
612611

613-
return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2])
612+
return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
614613

615614
def _compute_denominator(
616615
self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
@@ -626,8 +625,9 @@ def _compute_denominator(
626625
alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
627626
flat_target_extended = torch.unsqueeze(flat_target, dim=1)
628627
alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
628+
alpha_extended = torch.squeeze(alpha_extended, dim=1)
629629

630-
return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2])
630+
return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
631631

632632
def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:
633633
"""

tests/losses/test_generalized_wasserstein_dice_loss.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,96 @@ def forward(self, x):
218218
# check that the predicted segmentation has improved
219219
self.assertGreater(diff_start, diff_end)
220220

221+
def test_batch_size_greater_than_one(self):
222+
"""
223+
Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
224+
With M=identity and batch_size > 1, the GWDL should produce the same
225+
per-sample loss values as with batch_size=1.
226+
"""
227+
target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
228+
target_single = target_single.unsqueeze(0) # shape (1, H, W)
229+
pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()
230+
231+
# Create a batch of size 2 by repeating the same sample
232+
target_batch = target_single.repeat(2, 1, 1) # shape (2, H, W)
233+
pred_batch = pred_single.repeat(2, 1, 1, 1) # shape (2, C, H, W)
234+
235+
for w_mode in ["default", "GDL"]:
236+
loss_fn = GeneralizedWassersteinDiceLoss(
237+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]),
238+
weighting_mode=w_mode,
239+
reduction="none",
240+
)
241+
242+
loss_single = loss_fn(pred_single, target_single)
243+
loss_batch = loss_fn(pred_batch, target_batch)
244+
245+
# Each sample in the batch should produce the same loss as the single sample
246+
for i in range(2):
247+
self.assertAlmostEqual(
248+
float(loss_batch[i]),
249+
float(loss_single[0]),
250+
places=5,
251+
msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
252+
)
253+
254+
# Also test with mean reduction: batch loss should equal single-sample loss
255+
for w_mode in ["default", "GDL"]:
256+
loss_fn = GeneralizedWassersteinDiceLoss(
257+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]),
258+
weighting_mode=w_mode,
259+
reduction="mean",
260+
)
261+
262+
loss_single = float(loss_fn(pred_single, target_single))
263+
loss_batch = float(loss_fn(pred_batch, target_batch))
264+
265+
self.assertAlmostEqual(
266+
loss_batch,
267+
loss_single,
268+
places=5,
269+
msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}",
270+
)
271+
272+
def test_batch_size_different_samples(self):
273+
"""
274+
Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
275+
Verify loss is computed correctly when batch contains different samples.
276+
"""
277+
target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)
278+
target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)
279+
280+
pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()
281+
pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float()
282+
283+
# Combine into a batch
284+
target_batch = torch.cat([target_a, target_b], dim=0)
285+
pred_batch = torch.cat([pred_a, pred_b], dim=0)
286+
287+
for w_mode in ["default", "GDL"]:
288+
loss_fn = GeneralizedWassersteinDiceLoss(
289+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]),
290+
weighting_mode=w_mode,
291+
reduction="none",
292+
)
293+
294+
loss_a = float(loss_fn(pred_a, target_a))
295+
loss_b = float(loss_fn(pred_b, target_b))
296+
loss_batch = loss_fn(pred_batch, target_batch)
297+
298+
self.assertAlmostEqual(
299+
float(loss_batch[0]),
300+
loss_a,
301+
places=5,
302+
msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}",
303+
)
304+
self.assertAlmostEqual(
305+
float(loss_batch[1]),
306+
loss_b,
307+
places=5,
308+
msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}",
309+
)
310+
221311
def test_script(self):
222312
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
223313

0 commit comments

Comments
 (0)