
How is the transport plan (coupling) retrieved from ot.da.LinearGWTransport()? #653

@ttsesm


Hi,

I am trying to use the ot.da.LinearGWTransport() class, as described in the POT documentation, as follows:

import ot

def match_gaussians(Cs, Ms, ct, mt, verbose=False):
    """Match 3D Gaussians to 2D Gaussians using Linear Gromov-Wasserstein Transport."""
    # Compute cost matrices
    C3D, C2D = compute_cost_matrices(Cs, Ms, ct, mt)

    # Initialize and fit the LinearGWTransport model
    gw = ot.da.LinearGWTransport(log=verbose)
    gw.fit(Xs=Ms, Xt=mt, ys=C3D, yt=C2D)

    # Get the transport plan
    transport_plan = gw.coupling # <------------------------------------- This doesn't exist

    # Get the linear operator (projection matrix)
    projection_matrix = gw.L # <------------------------------------- This doesn't exist

    return projection_matrix, transport_plan

where Cs, Ms are the covariances and means of my 3D Gaussians, and ct, mt are the covariances and means of my 2D Gaussians.
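For reference, the relation between an affine map and the Gaussians it acts on can be sketched in pure Python. This assumes, as in POT's LinearTransport.transform, that the fitted map acts on row vectors as x ↦ x·A + B, so a 3D Gaussian N(m, C) is sent to the 2D Gaussian N(m·A + B, Aᵀ·C·A). The helper names below are hypothetical and not part of POT.

```python
def transpose(M):
    """Transpose a list-of-lists matrix."""
    return [list(row) for row in zip(*M)]


def matmul(X, Y):
    """Multiply two list-of-lists matrices (X is n x k, Y is k x m)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def push_forward_gaussian(mean, cov, A, B):
    """Image of N(mean, cov) under the affine map x -> x @ A + B (row vectors).

    A has shape (d_source, d_target); B has length d_target.
    Returns the mean m @ A + B and the covariance A^T @ cov @ A.
    """
    new_mean = [v + b for v, b in zip(matmul([mean], A)[0], B)]
    new_cov = matmul(matmul(transpose(A), cov), A)
    return new_mean, new_cov


# Example: project a 3D Gaussian onto its first two coordinates and shift it.
A = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
B = [0.5, -0.5]
m2d, C2d = push_forward_gaussian([1.0, 2.0, 3.0],
                                 [[1.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 9.0]],
                                 A, B)
print(m2d, C2d)
```

Under this convention, A plays the role of the 3×2 "projection matrix" and B the translation.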

However, I am not sure how to retrieve the transport plan and the projection matrix from the fitted model. From what I have noticed, I can access the A and B matrices, but I am not sure how these relate to the transport plan and the projection matrix.
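One way to see how A and B could relate to a transport plan: an affine map is deterministic, so the coupling it induces between discrete samples is concentrated on the graph of the map, with each source sample sending all of its weight to its own image. Below is a minimal sketch of that idea. The helpers affine_map and plan_from_map are hypothetical (not POT API), and the sketch assumes the images land exactly on target samples.

```python
def affine_map(A, B):
    """Return x -> x @ A + B for a row vector x (A has shape d_source x d_target)."""
    def transform(x):
        return [sum(xi * A[i][j] for i, xi in enumerate(x)) + B[j]
                for j in range(len(B))]
    return transform


def plan_from_map(xs, weights, yt, transform, tol=1e-9):
    """Coupling induced by a deterministic map between two discrete sample sets.

    plan[i][j] = weights[i] if transform(xs[i]) coincides with yt[j] (within tol),
    else 0. Each row therefore has exactly one nonzero entry.
    """
    plan = [[0.0] * len(yt) for _ in xs]
    for i, x in enumerate(xs):
        tx = transform(x)
        for j, y in enumerate(yt):
            if all(abs(a - b) <= tol for a, b in zip(tx, y)):
                plan[i][j] = weights[i]
                break
        else:
            raise ValueError(f"image of source sample {i} is not a target sample")
    return plan


# Drop the z coordinate and shift x by +1.
T = affine_map([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], [1.0, 0.0])
plan = plan_from_map([[0.0, 0.0, 0.0], [1.0, 0.0, 5.0]], [0.5, 0.5],
                     [[2.0, 0.0], [1.0, 0.0]], T)
print(plan)
```

In practice the images will rarely coincide with the target samples. If a full n×m coupling between the actual samples is needed, a discrete solver such as ot.gromov.gromov_wasserstein(C1, C2, p, q) returns one, although it solves a different problem than the Gaussian linear map.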

I would appreciate it if someone has an idea or could provide some feedback.

Thanks.
