
How does ot.dist work with sequences? #522

@huutuongtu


Hello,

I treat OT as a black box, so this may be a naive question.

I'm following the "Wasserstein 2 Minibatch GAN with PyTorch" example to train my own model, but I get an error. My inputs and outputs are sequences. Here is my code:

    import ot
    import torch
    import torch.nn as nn

    ab = (torch.ones(batch_size) / batch_size).to(device)
    sgd = torch.optim.Adam(model.parameters(), lr=0.001)
    CE_loss = nn.CrossEntropyLoss(ignore_index = 41)
    for epoch in range(1000):
      logits, c_emb, t_emb = model(phonetic, linguistic, transcript)
      # print(logits.shape) #batch x classes x time
      # print(c_emb.shape) #batch x time x features
      # print(t_emb.shape) #batch x time x features
      M = ot.dist(c_emb, t_emb)
      loss_W = ot.emd2(ab, ab, M).to(device)
      loss_CE = CE_loss(logits, output)
      loss = loss_W + loss_CE
      loss.backward()
      sgd.step()
      sgd.zero_grad()

The error:

        M = ot.dist(c_emb, t_emb)
      File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 307, in dist
        return euclidean_distances(x1, x2, squared=True)
      File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 253, in euclidean_distances
        a2 = nx.einsum('ij,ij->i', X, X)
      File "/opt/conda/lib/python3.8/site-packages/ot/backend.py", line 1897, in einsum
        return torch.einsum(subscripts, *operands)
      File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
        return _VF.einsum(equation, operands) # type: ignore[attr-defined]
    RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
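Reading the traceback: `ot.dist` (default metric `sqeuclidean`) calls `euclidean_distances`, which computes squared row norms with `einsum('ij,ij->i', X, X)`. That pattern only fits a 2-D `(n_samples, n_features)` array, so a 3-D `(batch, time, features)` tensor is rejected. The NumPy sketch below is an illustrative re-implementation of that computation, not POT's actual code. It reproduces the failure and shows one possible workaround, assuming each whole sequence should count as a single sample: flatten `time x features` into one feature axis so the cost matrix becomes `(batch, batch)`, matching the length of `ab`. The random arrays are stand-ins for the model outputs.

```python
import numpy as np

def sq_euclidean(x1, x2):
    """Pairwise squared Euclidean distances between rows of two 2-D arrays.

    Illustrative stand-in for ot.dist(x1, x2) with metric='sqeuclidean':
    x1 is (n, d), x2 is (m, d), and the result is (n, m).
    """
    a2 = np.einsum('ij,ij->i', x1, x1)  # squared norm of each row of x1
    b2 = np.einsum('ij,ij->i', x2, x2)  # squared norm of each row of x2
    d = a2[:, None] + b2[None, :] - 2.0 * x1 @ x2.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by rounding

batch, time, feat = 4, 7, 3
rng = np.random.default_rng(0)
c_emb = rng.normal(size=(batch, time, feat))  # stand-in for c_emb
t_emb = rng.normal(size=(batch, time, feat))  # stand-in for t_emb

# 3-D input reproduces the failure: 'ij' has 2 subscripts, the array has 3 dims.
try:
    sq_euclidean(c_emb, t_emb)
except ValueError as err:
    print("3-D input rejected:", err)

# Flatten each sequence into one vector of length time * feat.
M = sq_euclidean(c_emb.reshape(batch, -1), t_emb.reshape(batch, -1))
print(M.shape)  # (4, 4): one row/column per sequence, same length as `ab`
```

In the training loop, the same idea would be `M = ot.dist(c_emb.reshape(c_emb.shape[0], -1), t_emb.reshape(t_emb.shape[0], -1))`. This assumes both tensors share the same `time` length (for example after padding), since the flattened vectors must have equal dimension.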

How can I use ot.dist correctly with sequence data?
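A different reading of "OT with sequences" treats each sequence as an empirical distribution over its time steps. Under that reading, you would compare `c_emb[b]` and `t_emb[b]`, each a 2-D `(time, features)` array, one pair at a time, using uniform weights of length `time` instead of `batch`. The sketch below is a minimal illustration under that assumption. To stay self-contained, it swaps in SciPy's `linear_sum_assignment` for `ot.emd2`: for two equal-size point clouds with uniform weights, the optimal plan is a permutation, so the assignment cost divided by `time` equals the exact OT value. The helper names are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def pairwise_sq_costs(x1, x2):
    # (n, d) x (m, d) -> (n, m) squared Euclidean cost matrix.
    return ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(axis=-1)

def w2_squared_uniform(x, y):
    """Squared 2-Wasserstein distance between two equal-size point clouds
    with uniform weights 1/n. In this special case the optimal plan is a
    permutation, so an assignment solver gives the exact OT cost."""
    assert x.shape == y.shape
    M = pairwise_sq_costs(x, y)  # (time, time) cost for one sequence pair
    rows, cols = linear_sum_assignment(M)
    return M[rows, cols].sum() / x.shape[0]

batch, time, feat = 4, 7, 3
rng = np.random.default_rng(0)
c_emb = rng.normal(size=(batch, time, feat))  # stand-in for c_emb
t_emb = rng.normal(size=(batch, time, feat))  # stand-in for t_emb

# One 2-D (time, feat) slice per batch element, then average over the batch.
per_seq = [w2_squared_uniform(c_emb[b], t_emb[b]) for b in range(batch)]
loss_W = float(np.mean(per_seq))
print(loss_W)
```

Note that this treats the time steps of a sequence as an unordered set, so temporal order is ignored. Whether that is appropriate is a modeling choice. In PyTorch, the analogous loop would call `ot.dist(c_emb[b], t_emb[b])` and `ot.emd2(w, w, M_b)` with `w = torch.ones(time, device=device) / time`, so gradients can flow through the cost matrix as in the minibatch GAN example.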
