Skip to content

Typos in the formulas of gwggrad and solve_gromov_linesearch #541

@mr-gomez

Description

@mr-gomez

Describe the bug

The formulas implemented in gwggrad and solve_gromov_linesearch contain typos and do not match the cited references [12] and [24]. I also derived the gradient by hand, and the result confirms that the POT implementation has typos.

For concreteness, I'm using the following versions of the cited papers:
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
IN: https://proceedings.mlr.press/v48/peyre16.pdf

[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
IN: https://arxiv.org/pdf/1805.09114.pdf

To Reproduce

  1. Run the code below.

Code sample

Notes:

  1. The code below is a modification of plot_gromov.py from the examples gallery. I computed the GW distance twice: once using POT as is, and once with my corrections implemented in gwggrad_mod and solve_gromov_linesearch_mod.
  2. The typos did not affect the result of the Gromov-Wasserstein distance in my example, but I wonder if making sub-optimal choices in line-search will affect the speed of convergence in more complicated calculations.
import scipy as sp
import numpy as np
import ot

# Import functions required in ot.gromov._gw
from ot.utils import list_to_array
from ot.optim import cg, solve_1d_linesearch_quad
from ot.backend import get_backend, NumpyBackend

from ot.gromov._utils import init_matrix, gwloss, gwggrad
from ot.gromov._gw import solve_gromov_linesearch

#############################################################################
#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#############################################################################

n_samples = 30  # number of points in each of the two clouds

# Mean and covariance (identity) of the 2D source distribution.
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

# Mean and covariance (identity) of the 3D target distribution.
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# Seed NumPy's global RNG before sampling. The statements below must keep
# their order: reordering the draws would change the sampled points.
np.random.seed(0)
# Source cloud from POT's 2D Gaussian sampler.
# NOTE(review): presumably this draws from NumPy's global RNG seeded above
# -- confirm against ot.datasets.
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
# Target cloud: standard-normal draws of shape (n_samples, 3), multiplied by
# the matrix square root of cov_t and shifted by mu_t.
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t

# Pairwise distances within each cloud, using cdist's default metric.
C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)

# Divide each matrix in place by its largest entry.
C1 /= C1.max()
C2 /= C2.max()

#############################################################################
#
# Parameters for dGW
# ---------------------------------------------
#############################################################################

# Weights returned by ot.unif for each cloud, and their outer product as the
# initial coupling handed to the solver.
p = ot.unif(n_samples)
q = ot.unif(n_samples)
G0 = np.outer(p, q)

# Solver settings. Within this script, loss_fun is passed to init_matrix and
# max_iter / tol_rel / tol_abs are passed to cg as numItermax / stopThr /
# stopThr2; symmetric, log and armijo are defined but not read again.
loss_fun = 'square_loss'
symmetric = None
log = True
armijo = False
max_iter = 1e4
tol_rel = 1e-9
tol_abs = 1e-9

#############################################################################
# 
# gwggrad and solve_gromov_linesearch with typos corrected
# ---------------------------------------------
#############################################################################
def gwggrad_mod(constC, hC1, hC2, T, nx=None):
    """Return ``constC - 2 * hC1 @ T @ hC2.T``.

    This is the issue author's proposed replacement for ``gwggrad``.

    Parameters
    ----------
    constC, hC1, hC2 : array-like
        Matrices combined in the expression above (in this script they come
        from ``init_matrix``).
    T : array-like
        Matrix at which the expression is evaluated.
    nx : backend, optional
        Object providing ``dot``. When ``None``, the inputs are passed
        through ``list_to_array`` and the backend is obtained with
        ``get_backend``.

    Returns
    -------
    Same array type as the inputs: ``constC - 2 * dot(dot(hC1, T), hC2.T)``.
    """
    backend = nx
    if backend is None:
        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
        backend = get_backend(constC, hC1, hC2, T)

    # hC1 @ T @ hC2^T, evaluated left to right.
    left_product = backend.dot(hC1, T)
    quadratic_term = backend.dot(left_product, hC2.T)
    return constC - 2 * quadratic_term

def solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M, reg,
                            alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """Line search along ``deltaG`` using a quadratic ``a*alpha**2 + b*alpha``.

    This is the issue author's proposed replacement for
    ``solve_gromov_linesearch``; unlike that function it also takes
    ``constC``.

    Parameters
    ----------
    G : array-like
        Current matrix.
    deltaG : array-like
        Direction along which the step is taken.
    cost_G : float
        Cost at ``G``; the returned cost is this value plus the quadratic
        evaluated at the chosen step.
    constC : array-like
        Matrix entering the linear coefficient through ``sum(constC * deltaG)``.
    C1, C2 : array-like
        Matrices used in the products ``C1 @ X @ C2.T``.
    M : array-like or scalar
        Term entering the linear coefficient through ``sum(M * deltaG)``.
    reg : float
        Weight applied to the quadratic part.
    alpha_min, alpha_max : float, optional
        Bounds for the step; clipping is applied when at least one is given.
    nx : backend, optional
        Object providing ``dot`` and ``sum``. When ``None`` it is obtained
        with ``get_backend`` (``M`` is left out of that call if it is a
        Python ``int`` or ``float``).
    **kwargs
        Accepted and ignored.

    Returns
    -------
    tuple
        ``(alpha, 1, cost_G)``: the step, the constant ``1`` (presumably the
        function-call count expected by the caller -- confirm against
        ``ot.optim.cg``), and the updated cost.
    """
    if nx is None:
        G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

        # A plain Python scalar M is not passed to the backend lookup.
        backend_inputs = [G, deltaG, C1, C2]
        if not isinstance(M, (int, float)):
            backend_inputs.append(M)
        nx = get_backend(*backend_inputs)

    # C1 @ X @ C2^T for the direction and for the current matrix.
    c1_dg_c2t = nx.dot(nx.dot(C1, deltaG), C2.T)
    c1_g_c2t = nx.dot(nx.dot(C1, G), C2.T)

    # Coefficient of alpha**2.
    a = -2 * reg * nx.sum(c1_dg_c2t * deltaG)

    # Coefficient of alpha.
    linear_part = nx.sum(M * deltaG)
    const_part = nx.sum(constC * deltaG)
    cross_dg_g = nx.sum(c1_dg_c2t * G)
    cross_g_dg = nx.sum(c1_g_c2t * deltaG)
    b = linear_part + reg * (const_part - 2 * cross_dg_g - 2 * cross_g_dg)

    alpha = solve_1d_linesearch_quad(a, b)
    if not (alpha_min is None and alpha_max is None):
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # The new cost follows from the quadratic evaluated at the chosen step.
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G

#############################################################################
#
# Compute Gromov-Wasserstein with modified functions
# ---------------------------------------------
#############################################################################
# cg for GW is implemented using numpy on CPU
# Explicit NumPy backend handed to the helper functions below.
np_ = NumpyBackend()

# Backend matching the original inputs; used later to convert results back.
nx = get_backend(C1, C2, p, q)

# Additional names bound to the original inputs.
p0 = p
q0 = q
C10 = C1
C20 = C2

# Matrices consumed by gwloss / gwggrad / the line search.
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)

######################################################################
# Define loss function, gradient and linesearch
# ---------------------------------------------
# NOTE: Using modified gwggrad and line_search
def f(G):
    """Return ``gwloss`` evaluated at ``G`` with the module-level matrices."""
    loss = gwloss(constC, hC1, hC2, G, np_)
    return loss

def df(G):
    """Return ``gwggrad_mod`` evaluated at ``G`` with the module-level matrices."""
    gradient = gwggrad_mod(constC, hC1, hC2, G, np_)
    return gradient

def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
    """Forward to ``solve_gromov_linesearch_mod`` with ``M=0.`` and ``reg=1.``.

    ``cost`` and ``Mi`` are accepted but not used; extra keyword arguments
    are forwarded unchanged.
    """
    return solve_gromov_linesearch_mod(
        G, deltaG, cost_G, constC, C1, C2,
        M=0., reg=1., nx=np_, **kwargs
    )
######################################################################

# Run the solver with the modified gradient and line search.
res_mod, log_mod = cg(
    p, q, 0., 1., f, df, G0, line_search,
    log=True,
    numItermax=max_iter,
    stopThr=tol_rel,
    stopThr2=tol_abs,
)

# Store the last recorded loss as the distance, then convert the 'u' and 'v'
# log entries with the inputs' backend.
log_mod['gw_dist'] = nx.from_numpy(log_mod['loss'][-1], type_as=C1)
log_mod.update({
    name: nx.from_numpy(log_mod[name], type_as=C1) for name in ('u', 'v')
})

gw_mod = nx.from_numpy(res_mod, type_as=C1)


# Compute GW with the original function
gw0, log0 = ot.gromov.gromov_wasserstein(
    C1, C2, p, q, 'square_loss', verbose=True, log=True
)

#############################################################################
#
# Compare gwggrad and solve_gromov_linesearch with their modified versions
# ---------------------------------------------
#############################################################################
# Evaluate both implementations at the initial coupling, along a direction
# drawn from NumPy's global RNG, starting from a zero cost.
G = G0
deltaG = np.random.rand(*G.shape)
cost_G = 0

# Gradient: modified version vs. the POT function, same arguments.
grad_mod = gwggrad_mod(constC, hC1, hC2, G, np_)
grad = gwggrad(constC, hC1, hC2, G, np_)

# Line search: the modified version takes constC as an extra argument.
linesearch_mod = solve_gromov_linesearch_mod(G, deltaG, cost_G, constC, C1, C2, M=0., reg=1., nx=np_)
linesearch = solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_)

# Report both distances; the comparison uses exact equality (no tolerance).
print()
print(f"dGW with func: {log0['gw_dist']}")
print(f"dGW with mods: {log_mod['gw_dist']}")
print("GW-distances agree:", log0['gw_dist'] == log_mod['gw_dist'])

# Exact comparisons of the gradients and of the line-search return tuples.
print()
print('Gradients agree:', np.array_equal(grad_mod, grad))
print('Line-search results agree:', linesearch_mod == linesearch)

Expected behavior

The functions gwggrad and solve_gromov_linesearch should output the result of gwggrad_mod and solve_gromov_linesearch_mod, respectively.

Environment

Output of the following code snippet:

# Print platform and library versions. Each import stays directly before its
# print so that a failing import still leaves the earlier lines printed.
import platform
print(platform.platform())
import sys
print("Python", sys.version)
import numpy
print("NumPy", numpy.__version__)
import scipy
print("SciPy", scipy.__version__)
import ot
print("POT", ot.__version__)
Linux-6.5.7-100.fc37.x86_64-x86_64-with-glibc2.36
Python 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
NumPy 1.24.3
SciPy 1.11.1
POT 0.9.1

Additional context

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions