Description
Hi, and thanks for maintaining this great library!
I'm currently using POT (with the PyTorch backend) to compute OT-based losses.
I noticed that ot.emd2 does not appear to be differentiable in the usual PyTorch sense,
yet when I inspect the .grad field of the cost matrix M after calling loss.backward(), I still get non-zero gradients.
So I would like to confirm:
Is ot.emd2 actually differentiable, i.e. does autograd backpropagation work through M and the input distributions a and b?
Or are the observed gradients just residual tensors from detached operations (i.e., not truly backpropagated through the OT solver)?
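As a first check (a minimal sketch, not part of the example below; it assumes POT picks up the torch backend automatically from the input types), one can look at whether the value returned by ot.emd2 is attached to the autograd graph at all:

import torch
import ot

a = torch.ones(5) / 5
b = torch.ones(5) / 5
M = torch.rand(5, 5, requires_grad=True)

loss = ot.emd2(a, b, M)
# If grad_fn is not None, backward() has a path from the loss back to M;
# if it is None, any values found in M.grad must come from elsewhere.
print(loss.requires_grad, loss.grad_fn)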
Minimal reproducible example
import torch
import ot

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n = 10

# Random source/target marginals, normalized to sum to 1
a = torch.rand(n, device=device)
a = a / a.sum()
b = torch.rand(n, device=device)
b = b / b.sum()

# Cost matrix with autograd enabled
M = torch.randn(n, n, device=device, requires_grad=True)

# Exact OT loss (linear program)
loss_emd2 = ot.emd2(a, b, M)
loss_emd2.backward()
grad_emd2 = M.grad.clone()
print("EMD2 loss:", loss_emd2.item())
print("EMD2 grad:\n", grad_emd2)

# Entropic OT loss for comparison
M.grad.zero_()
reg = 0.1
loss_sinkhorn = ot.sinkhorn2(a, b, M, reg)
loss_sinkhorn.backward()
grad_sinkhorn = M.grad.clone()
print("\nSinkhorn loss:", loss_sinkhorn.item())
print("Sinkhorn grad:\n", grad_sinkhorn)
OUTPUT:
EMD2 loss: -1.5605539083480835
EMD2 grad:
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0427],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0095, 0.1873, 0.0095,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0252, 0.0000, 0.0233, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0925, 0.0000, 0.0000, 0.0000, 0.0000, 0.0061,
0.0987],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0481, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0552, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0126, 0.0000, 0.0000, 0.0000, 0.0000, 0.0532, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0765, 0.0000, 0.0000, 0.0548, 0.0000, 0.0000, 0.0000,
0.0000],
[0.1010, 0.0000, 0.0175, 0.0000, 0.0020, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0844, 0.0000, 0.0000, 0.0000,
0.0000]], device='cuda:0')
Sinkhorn loss: -1.5583062171936035
Sinkhorn grad:
tensor([[-1.0759e-15, -4.5089e-05, -2.3012e-06, -4.7454e-05, -3.1261e-15,
-8.1661e-06, -8.5109e-06, -2.6828e-12, -5.4352e-05, 4.2903e-02],
[-1.3629e-10, -5.7191e-13, -5.8734e-08, -1.4337e-03, -2.9875e-06,
-5.2210e-04, 1.2639e-02, 1.8735e-01, 8.2132e-03, -2.8584e-06],
[-1.3665e-12, -5.0650e-08, -4.7080e-04, -9.4670e-06, 2.7954e-02,
-4.7037e-09, 2.1273e-02, -9.7631e-12, -2.6613e-04, -1.6820e-12],
[-7.8764e-13, -1.0603e-12, -1.3319e-03, 9.4009e-02, -1.1775e-03,
-4.5318e-04, -6.8939e-07, -1.8201e-14, 7.6663e-03, 9.8547e-02],
[-6.4536e-11, -5.2054e-08, -3.9438e-09, -5.9460e-09, -1.3147e-07,
-1.8057e-07, 4.8088e-02, -3.1984e-11, -1.9308e-09, -2.1640e-10],
[-4.6134e-17, -7.1095e-07, 5.5260e-02, -3.3341e-17, -4.1714e-20,
-5.8153e-07, -2.3840e-05, -2.2943e-13, -7.5624e-06, -8.0583e-10],
[-1.0100e-13, 1.3838e-02, -1.4756e-08, -1.8981e-10, -2.9864e-12,
-8.6449e-05, 5.2120e-02, -1.4227e-16, -3.6979e-08, -2.9439e-11],
[-2.2352e-07, -1.1892e-05, 7.6544e-02, -2.5317e-08, -1.2087e-06,
5.4697e-02, -9.0246e-08, -8.5752e-18, -3.6054e-07, -1.7776e-11],
[ 1.0100e-01, -1.0979e-05, 1.9160e-02, -6.2655e-06, 4.4209e-04,
-1.1139e-07, -2.8346e-07, -3.8414e-05, -3.5571e-07, -3.9045e-05],
[-1.1940e-24, -1.1416e-03, -6.1396e-10, -7.2483e-15, -9.0970e-12,
8.5502e-02, -9.0134e-09, -6.0799e-18, -1.8524e-11, -1.1368e-18]],
device='cuda:0')
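As a sanity check on the exact-OT gradient (a sketch appended to the script above, not part of the original run; it relies on the envelope theorem, which says that for the linear program emd2(M) = min_G <G, M> the gradient with respect to M at a point of differentiability is the optimal plan G* itself), one can compare M.grad with the plan returned by ot.emd:

# Compare the gradient of the exact OT loss with the optimal transport plan.
G = ot.emd(a, b, M.detach())
# Expected to print True when the optimal plan is unique; degenerate problems
# can admit several optimal plans, in which case the comparison may fail.
print(torch.allclose(grad_emd2, G))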
Please clarify whether:
ot.emd2 is intentionally non-differentiable (since it solves a linear program);
there is a plan to add a differentiable variant (e.g., differentiable EMD via the implicit function theorem or an entropic relaxation);
or the observed gradients are accidental numerical artifacts.
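For what it's worth, here is a finite-difference check along a random direction (again a sketch appended to the script above; eps = 1e-2 is an arbitrary choice that tries to stay within one linear piece of the piecewise-linear LP value while remaining above float32 round-off):

# Directional derivative of the exact OT value vs. the inner product with M.grad.
eps = 1e-2
D = torch.randn_like(M)
with torch.no_grad():
    v_plus = ot.emd2(a, b, M + eps * D)
    v_minus = ot.emd2(a, b, M - eps * D)
fd = (v_plus - v_minus) / (2 * eps)   # central finite difference
an = (grad_emd2 * D).sum()            # <M.grad, D>
# The two numbers should agree (up to round-off and possible breakpoint
# crossings) if the gradient reported for M is genuine.
print(fd.item(), an.item())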
Thanks a lot for your time and for maintaining this library!