
Question regarding applying a transport plan #651

@sancelot


Hi,
I would like to compute a transport plan from two known grids (an input grid and an output grid) and then apply it to a new grid. However, it does not work:

import numpy as np
import ot

def compute_transport_plan(tensor1, tensor2):
    """
    Compute the optimal transport plan from tensor1 to tensor2.
    """
    tensor1 = tensor1.flatten()
    tensor2 = tensor2.flatten()
    
    # Assuming uniform weights for the discrete distributions
    a = np.ones(len(tensor1)) / len(tensor1)
    b = np.ones(len(tensor2)) / len(tensor2)
    
    M = ot.dist(tensor1[:, None], tensor2[:, None], metric='euclidean')
    
    # Compute the optimal transport plan
    transport_plan = ot.emd(a, b, M)
    return transport_plan

def apply_transport_plan(tensor, transport_plan, target_shape):
    """
    Apply the transport plan to transform the tensor.
    """
    tensor = tensor.flatten()

    # An (n, n) plan times the flattened (n,) tensor gives an (n,) vector
    transformed_tensor = np.dot(transport_plan, tensor)
    # Restore the grid shape if the number of dimensions differs
    if len(transformed_tensor.shape) != len(target_shape):
        transformed_tensor = transformed_tensor.reshape(target_shape)

    return transformed_tensor

# Example usage
tensor1 = np.array([[1, 2, 3], [4, 5, 6]])
tensor2 = np.array([[7, 8, 9], [10, 11, 12]])
new_tensor = np.array([[13, 14, 15], [16, 17, 18]])


# Compute Optimal Transport Plan
transport_plan = compute_transport_plan(tensor1, tensor2)
print(f"Optimal Transport Plan: \n{transport_plan}")

# Apply the Transport Plan to a New Tensor
transformed_tensor = apply_transport_plan(new_tensor, transport_plan, tensor2.shape)
print(f"Transformed Tensor: \n{transformed_tensor}")
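A possible reason the plan looks like an arbitrary permutation: with metric='euclidean' the cost between two flattened values is |x - y|, and because every value of tensor2 is larger than every value of tensor1, every one-to-one matching has the same total unweighted cost, sum(y) - sum(x) = 57 - 21 = 36. ot.emd can therefore return any of them. The brute-force check below (plain NumPy, no POT needed; the helper name matching_cost is only for illustration) compares this with a squared-Euclidean cost, under which only the sorted matching is optimal. POT's ot.dist uses metric='sqeuclidean' by default.

```python
import itertools

import numpy as np

# Flattened values from the example grids.
x = np.array([1, 2, 3, 4, 5, 6], dtype=float)     # tensor1
y = np.array([7, 8, 9, 10, 11, 12], dtype=float)  # tensor2


def matching_cost(perm, p):
    """Total unweighted cost of sending x[i] to y[perm[i]] with cost |x - y| ** p."""
    return float(np.sum(np.abs(x - y[list(perm)]) ** p))


# p = 1 is the 'euclidean' cost used in compute_transport_plan.
costs_p1 = {matching_cost(perm, 1) for perm in itertools.permutations(range(6))}

# p = 2 is the squared-Euclidean cost.
costs_p2 = [matching_cost(perm, 2) for perm in itertools.permutations(range(6))]
best_p2 = min(costs_p2)
n_best_p2 = sum(c == best_p2 for c in costs_p2)

print(costs_p1)            # {36.0}: all 720 matchings tie
print(best_p2, n_best_p2)  # 216.0 1: only the sorted matching is optimal
```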

I get this result:
Optimal Transport Plan:
[[0. 0. 0. 0. 0.16666667 0. ]
[0. 0.16666667 0. 0. 0. 0. ]
[0. 0. 0.16666667 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.16666667]
[0.16666667 0. 0. 0. 0. 0. ]
[0. 0. 0. 0.16666667 0. 0. ]]

Transformed Tensor:
[[2.83333333 2.33333333 2.5 ]
[3. 2.16666667 2.66666667]]
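These numbers can be reproduced from the printed plan: each row has a single entry 1/6, so np.dot(transport_plan, tensor) just picks one value of new_tensor per row and divides it by 6 (for example 17 / 6 = 2.8333). The sketch below rebuilds that plan by hand (the permutation sigma is read off the printed plan) and shows that dividing by the row sums, i.e. the barycentric projection, sends tensor1's entries onto tensor2's values but only reorders new_tensor's values.

```python
import numpy as np

new_tensor = np.array([[13, 14, 15], [16, 17, 18]])
y = np.array([7, 8, 9, 10, 11, 12], dtype=float)  # flattened tensor2

# Row i (a tensor1 entry) sends all of its mass 1/6 to column sigma[i] (a tensor2 entry).
sigma = [4, 1, 2, 5, 0, 3]
plan = np.zeros((6, 6))
plan[np.arange(6), sigma] = 1 / 6

# What apply_transport_plan computes: one new_tensor value per row, times 1/6.
transformed = plan @ new_tensor.flatten()
print(transformed.reshape(2, 3))  # same values as the "Transformed Tensor" output

# Barycentric projection: divide by the source weights (row sums, each 1/6).
row_mass = plan.sum(axis=1)
targets = (plan @ y) / row_mass            # tensor2 value each tensor1 entry goes to: 11, 8, 9, 12, 7, 10
reordered = transformed / row_mass         # just new_tensor permuted: 17, 14, 15, 18, 13, 16
print(targets.reshape(2, 3))
print(reordered.reshape(2, 3))
```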

I expected [[19, 20, 21], [22, 23, 24]], i.e. new_tensor shifted by +6, the same shift that takes tensor1 to tensor2.
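A transport plan only couples the samples it was computed on, so it has no entry for values such as 13 to 18. One way to get the expected result, sketched here under stated assumptions, is to turn the plan into a displacement for each training sample (barycentric projection) and move every new value by the displacement of its nearest training sample. The helpers monotone_plan_1d and apply_out_of_sample are hypothetical names for this illustration; monotone_plan_1d assumes equal-sized 1D samples with uniform weights and a squared-Euclidean cost, where matching the sorted values is optimal. As far as I know, POT's ot.da.EMDTransport uses a similar nearest-neighbour rule when transform is called on samples not seen in fit, but check its documentation before relying on that.

```python
import numpy as np


def monotone_plan_1d(x, y):
    """Hypothetical helper: optimal plan for 1D samples of equal size, uniform
    weights and squared-Euclidean cost (k-th smallest x -> k-th smallest y)."""
    n = len(x)
    plan = np.zeros((n, n))
    plan[np.argsort(x), np.argsort(y)] = 1.0 / n
    return plan


def apply_out_of_sample(new_values, x, y, plan):
    """Hypothetical helper: shift each new value by the displacement of its
    nearest training sample x_i (assumes the mapping is locally a translation)."""
    # Barycentric projection: where each x_i is sent, on average, by the plan.
    mapped_x = (plan @ y) / plan.sum(axis=1)
    displacement = mapped_x - x
    # Index of the nearest training sample for every new value.
    nearest = np.argmin(np.abs(new_values[:, None] - x[None, :]), axis=1)
    return new_values + displacement[nearest]


tensor1 = np.array([[1, 2, 3], [4, 5, 6]])
tensor2 = np.array([[7, 8, 9], [10, 11, 12]])
new_tensor = np.array([[13, 14, 15], [16, 17, 18]])

x = tensor1.flatten().astype(float)
y = tensor2.flatten().astype(float)
plan = monotone_plan_1d(x, y)
result = apply_out_of_sample(new_tensor.flatten().astype(float), x, y, plan)
result = result.reshape(new_tensor.shape)
print(result)
```

With the grids from the issue this gives [[19, 20, 21], [22, 23, 24]] up to floating-point rounding. It works here because every training sample has the same displacement (+6); nearest-neighbour extrapolation is a heuristic and would not recover, for example, a scaling. For general data, monotone_plan_1d could be replaced by ot.emd(a, b, ot.dist(x[:, None], y[:, None])), keeping ot.dist's default squared-Euclidean metric.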
