12 changes: 6 additions & 6 deletions monai/losses/dice.py
@@ -548,10 +548,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        elif self.reduction == LossReduction.SUM.value:
            wass_dice_loss = torch.sum(wass_dice_loss)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
-            # If we are not computing voxelwise loss components at least
-            # make sure a none reduction maintains a broadcastable shape
-            broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
-            wass_dice_loss = wass_dice_loss.view(broadcast_shape)
+            # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
+            pass
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

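Note (not part of the diff): a minimal sketch of the behaviour this change assumes for reduction="none". GWDL already aggregates over classes and voxels per sample, so the unreduced loss is one scalar per batch element. Shapes below are illustrative only.

import numpy as np
import torch
from monai.losses import GeneralizedWassersteinDiceLoss

# With reduction="none", each batch element gets a single scalar loss value.
loss_fn = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), reduction="none")
pred = torch.randn(4, 2, 8, 8)            # (B, C, H, W) logits
target = torch.randint(0, 2, (4, 8, 8))   # (B, H, W) class indices
print(loss_fn(pred, target).shape)        # expected after this change: torch.Size([4])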
@@ -609,8 +607,9 @@ def _compute_generalized_true_positive(
        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        alpha_extended = torch.squeeze(alpha_extended, dim=1)

-        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2])
+        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)

    def _compute_denominator(
        self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
@@ -626,8 +625,9 @@ def _compute_denominator(
        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        alpha_extended = torch.squeeze(alpha_extended, dim=1)

-        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2])
+        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)

    def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:
        """
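Why the added squeeze and the dim change matter (illustrative sketch, not part of the diff): after torch.gather, alpha_extended has shape (B, 1, N), so multiplying it by the (B, N) Wasserstein distance map broadcasts to (B, B, N), and summing over dims [1, 2] mixes contributions across batch samples whenever B > 1. Squeezing to (B, N) and summing over dim=1 keeps each sample separate.

import torch

B, N = 2, 5
alpha_gathered = torch.rand(B, 1, N)  # shape after torch.gather(..., dim=1)
wdm = torch.rand(B, N)                # wasserstein_distance_map

# Old behaviour: (B, 1, N) * (B, N) broadcasts to (B, B, N) and leaks terms across samples.
old = torch.sum(alpha_gathered * (1.0 - wdm), dim=[1, 2])

# Fixed behaviour: squeeze to (B, N), then reduce per sample.
new = torch.sum(alpha_gathered.squeeze(1) * (1.0 - wdm), dim=1)

print((alpha_gathered * (1.0 - wdm)).shape)  # torch.Size([2, 2, 5])
print(old, new)                              # equal only when B == 1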
85 changes: 85 additions & 0 deletions tests/losses/test_generalized_wasserstein_dice_loss.py
@@ -218,6 +218,91 @@ def forward(self, x):
            # check that the predicted segmentation has improved
            self.assertGreater(diff_start, diff_end)

    def test_batch_size_greater_than_one(self):
        """
        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
        With M=identity and batch_size > 1, the GWDL should produce the same
        per-sample loss values as with batch_size=1.
        """
        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
        target_single = target_single.unsqueeze(0)  # shape (1, H, W)
        pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()

        # Create a batch of size 2 by repeating the same sample
        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)
        pred_batch = pred_single.repeat(2, 1, 1, 1)  # shape (2, C, H, W)

        for w_mode in ["default", "GDL"]:
            loss_fn = GeneralizedWassersteinDiceLoss(
                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
            )

            loss_single = loss_fn(pred_single, target_single)
            loss_batch = loss_fn(pred_batch, target_batch)

            # Each sample in the batch should produce the same loss as the single sample
            for i in range(2):
                self.assertAlmostEqual(
                    float(loss_batch[i]),
                    float(loss_single[0]),
                    places=5,
                    msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
                )

        # Also test with mean reduction using a non-trivial (poor) prediction
        # so the expected loss is not near zero
        pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
        pred_poor_batch = pred_poor.repeat(2, 1, 1, 1)

        for w_mode in ["default", "GDL"]:
            loss_fn = GeneralizedWassersteinDiceLoss(
                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean"
            )

            loss_single = float(loss_fn(pred_poor, target_single))
            loss_batch = float(loss_fn(pred_poor_batch, target_batch))

            # Verify the loss is non-trivial (close to 1 for poor predictions)
            self.assertGreater(loss_single, 0.5, msg=f"Expected non-trivial loss for weighting_mode={w_mode}")
            self.assertAlmostEqual(
                loss_batch,
                loss_single,
                places=5,
                msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}",
            )

    def test_batch_size_different_samples(self):
        """
        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
        Verify the loss is computed correctly when the batch contains different samples.
        """
        target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)
        target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)

        pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()
        # Use a poor prediction for sample b so its loss is non-trivial (~1.0)
        pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()

        # Combine into a batch
        target_batch = torch.cat([target_a, target_b], dim=0)
        pred_batch = torch.cat([pred_a, pred_b], dim=0)

        for w_mode in ["default", "GDL"]:
            loss_fn = GeneralizedWassersteinDiceLoss(
                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
            )

            loss_a = float(loss_fn(pred_a, target_a))
            loss_b = float(loss_fn(pred_b, target_b))
            loss_batch = loss_fn(pred_batch, target_batch)

            self.assertAlmostEqual(
                float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}"
            )
            self.assertAlmostEqual(
                float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}"
            )

    def test_script(self):
        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
