Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hackable_diffusion/lib/corruption/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from hackable_diffusion.lib.corruption.discrete import PostCorruptionFn
from hackable_diffusion.lib.corruption.discrete import SymmetricPostCorruptionFn
from hackable_diffusion.lib.corruption.gaussian import GaussianProcess
from hackable_diffusion.lib.corruption.riemannian import RiemannianFlowMatching
from hackable_diffusion.lib.corruption.schedules import CosineDiscreteSchedule
from hackable_diffusion.lib.corruption.schedules import CosineSchedule
from hackable_diffusion.lib.corruption.schedules import DiscreteSchedule
Expand All @@ -31,8 +32,10 @@
from hackable_diffusion.lib.corruption.schedules import InverseCosineSchedule
from hackable_diffusion.lib.corruption.schedules import LinearDiffusionSchedule
from hackable_diffusion.lib.corruption.schedules import LinearDiscreteSchedule
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule
from hackable_diffusion.lib.corruption.schedules import PolynomialDiscreteSchedule
from hackable_diffusion.lib.corruption.schedules import RFSchedule
from hackable_diffusion.lib.corruption.schedules import RiemannianSchedule
from hackable_diffusion.lib.corruption.schedules import Schedule
from hackable_diffusion.lib.corruption.schedules import ShiftedSchedule
from hackable_diffusion.lib.corruption.schedules import SquareCosineDiscreteSchedule
Expand Down
99 changes: 99 additions & 0 deletions hackable_diffusion/lib/corruption/riemannian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Riemannian Flow Matching corruption process."""

import dataclasses
from typing import Any

from hackable_diffusion.lib import hd_typing
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib import utils
from hackable_diffusion.lib.corruption import base
from hackable_diffusion.lib.corruption import schedules
import kauldron.ktyping as kt

PRNGKey = hd_typing.PRNGKey
DataArray = hd_typing.DataArray
TimeArray = hd_typing.TimeArray
TargetInfo = hd_typing.TargetInfo


@dataclasses.dataclass(kw_only=True, frozen=True)
class RiemannianFlowMatching(base.CorruptionProcess):
"""Riemannian Flow Matching corruption process.

This is based on https://arxiv.org/abs/2302.03660.

Given a schedule with interpolation parameter alpha(t):
x_t = geodesic(x_0, x_1, alpha(t))
target = alpha'(t) * velocity(x_0, x_1, alpha(t))
"""

manifold: manifolds.Manifold
schedule: schedules.RiemannianSchedule

@kt.typechecked
def sample_from_invariant(
self,
key: PRNGKey,
data_spec: DataArray,
) -> DataArray:
"""Sample from the base distribution (uniform) on the manifold."""
return self.manifold.random_uniform(key, data_spec.shape)

@kt.typechecked
def corrupt(
self,
key: PRNGKey,
x0: DataArray,
time: TimeArray,
) -> tuple[DataArray, TargetInfo]:
x1 = self.sample_from_invariant(key, data_spec=x0)

# Evaluate schedule: alpha(t) is the geodesic interpolation parameter.
alpha_t = utils.bcast_right(self.schedule.alpha(time), x0.ndim)
alpha_dot_t = utils.bcast_right(self.schedule.alpha_dot(time), x0.ndim)

# x_t = geodesic(x0, x1, alpha(t)).
xt = self.manifold.exp(x0, alpha_t * self.manifold.log(x0, x1))

# By chain rule: d/dt x_t = alpha'(t) * velocity(x0, x1, alpha(t)).
vel = alpha_dot_t * self.manifold.velocity(x0, x1, alpha_t)

target_info = {
'x0': x0,
'x1': x1,
'velocity': vel,
}

return xt, target_info

@kt.typechecked
def convert_predictions(
self,
prediction: TargetInfo,
xt: DataArray,
time: TimeArray,
) -> TargetInfo:
"""Convert predictions to velocity parameterization."""
if 'velocity' in prediction:
return prediction
raise NotImplementedError(
'Only velocity prediction is supported for RFM currently.'
)

@kt.typechecked
def get_schedule_info(self, time: TimeArray) -> dict[str, Any]:
return self.schedule.evaluate(time)
152 changes: 152 additions & 0 deletions hackable_diffusion/lib/corruption/riemannian_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Riemannian Flow Matching corruption process."""

from absl.testing import absltest
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.corruption import riemannian
from hackable_diffusion.lib.corruption import schedules
import jax
import jax.numpy as jnp
import numpy as np


def _make_process(manifold):
return riemannian.RiemannianFlowMatching(
manifold=manifold,
schedule=schedules.LinearRiemannianSchedule(),
)


class SphereCorruptionTest(absltest.TestCase):

def test_corrupt(self):
manifold = manifolds.Sphere()
process = _make_process(manifold)
key = jax.random.PRNGKey(0)

batch_size = 8
x0 = manifold.random_uniform(key, (batch_size, 3))
time = jnp.linspace(0, 1, batch_size)

xt, target_info = process.corrupt(key, x0, time)

# xt should be on the sphere.
norms = jnp.linalg.norm(xt, axis=-1)
np.testing.assert_allclose(norms, 1.0, atol=1e-5)

# Velocity should be tangent to the sphere at xt, i.e. <xt, vel> = 0.
vel = target_info['velocity']
self.assertEqual(vel.shape, (batch_size, 3))
inner_products = jnp.sum(xt * vel, axis=-1)
np.testing.assert_allclose(inner_products, 0.0, atol=1e-5)

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
manifold = manifolds.Sphere()
process = _make_process(manifold)
key = jax.random.PRNGKey(0)

x0 = jnp.array([[1.0, 0.0, 0.0]])
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)

v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
np.testing.assert_allclose(v1, -v_log, atol=1e-5)


class SO3CorruptionTest(absltest.TestCase):

def test_corrupt(self):
manifold = manifolds.SO3()
process = _make_process(manifold)
key = jax.random.PRNGKey(1)

batch_size = 8
x0 = manifold.random_uniform(key, (batch_size, 3, 3))
time = jnp.linspace(0, 1, batch_size)

xt, target_info = process.corrupt(key, x0, time)

# xt should be a valid rotation: R^T R = I and det(R) = 1.
rtrt = jnp.matmul(jnp.swapaxes(xt, -2, -1), xt)
eyes = jnp.broadcast_to(jnp.eye(3), rtrt.shape)
np.testing.assert_allclose(rtrt, eyes, atol=1e-5)
np.testing.assert_allclose(jnp.linalg.det(xt), 1.0, atol=1e-5)

# Velocity should be in the tangent space: x^T v is skew-symmetric.
vel = target_info['velocity']
self.assertEqual(vel.shape, (batch_size, 3, 3))

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
manifold = manifolds.SO3()
process = _make_process(manifold)
key = jax.random.PRNGKey(1)

x0 = jnp.eye(3)[None, ...] # (1, 3, 3)
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)

v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
np.testing.assert_allclose(v1, -v_log, atol=1e-4)


class TorusCorruptionTest(absltest.TestCase):

def test_corrupt(self):
manifold = manifolds.Torus()
process = _make_process(manifold)
key = jax.random.PRNGKey(2)

batch_size = 8
dim = 4
x0 = manifold.random_uniform(key, (batch_size, dim))
time = jnp.linspace(0, 1, batch_size)

xt, target_info = process.corrupt(key, x0, time)

# xt should be in [0, 1).
self.assertTrue(jnp.all(xt >= 0.0))
self.assertTrue(jnp.all(xt < 1.0))

vel = target_info['velocity']
self.assertEqual(vel.shape, (batch_size, dim))

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
manifold = manifolds.Torus()
process = _make_process(manifold)
key = jax.random.PRNGKey(2)

x0 = jnp.array([[0.1, 0.5, 0.9]])
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)

v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
np.testing.assert_allclose(v1, -v_log, atol=1e-5)


if __name__ == '__main__':
absltest.main()
53 changes: 53 additions & 0 deletions hackable_diffusion/lib/corruption/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,59 @@ def evaluate(self, time: TimeArray) -> dict[str, TimeArray]:
SimplicialSchedule = DiscreteSchedule


################################################################################
# MARK: Riemannian Schedules
################################################################################


class RiemannianSchedule(abc.ABC, Schedule):
"""Base class for Riemannian schedules.

Controls the geodesic interpolation via alpha(t):
x_t = geodesic(x_0, x_1, alpha(t))
v_t = alpha'(t) * velocity(x_0, x_1, alpha(t))

Subclasses must implement `alpha`.
"""

@abc.abstractmethod
def alpha(self, time: TimeArray) -> TimeArray:
"""The geodesic interpolation parameter at time t."""

def alpha_dot(self, time: TimeArray) -> TimeArray:
"""Time derivative of alpha. Defaults to autodiff."""
return utils.egrad(self.alpha)(time)

@kt.typechecked
def evaluate(self, time: TimeArray) -> dict[str, TimeArray]:
return {
'time': time,
'alpha': self.alpha(time),
'alpha_dot': self.alpha_dot(time),
}


class LinearRiemannianSchedule(RiemannianSchedule):
"""Linear Riemannian schedule: alpha(t) = 1.0 - t.

This is the standard flow matching schedule where the geodesic interpolation
parameter equals time directly.
Note that contrary to the original Riemannian Flow Matching, we assume that at
time t=0, the process is close to the data distribution, and at time t=1,
the process is close to the target distribution.
Hence, we use alpha(t) = 1.0 - t, and alpha_dot(t) = -1.0m instead of
alpha(t) = t, and alpha_dot(t) = 1.0.
"""

@kt.typechecked
def alpha(self, time: TimeArray) -> TimeArray:
return 1.0 - time

@kt.typechecked
def alpha_dot(self, time: TimeArray) -> TimeArray:
return -jnp.ones_like(time)


################################################################################
# MARK: Gaussian Schedules
################################################################################
Expand Down
1 change: 1 addition & 0 deletions hackable_diffusion/lib/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from hackable_diffusion.lib.sampling.gaussian_step_sampler import HeunStep
from hackable_diffusion.lib.sampling.gaussian_step_sampler import SdeStep
from hackable_diffusion.lib.sampling.gaussian_step_sampler import VelocityStep
from hackable_diffusion.lib.sampling.riemannian_sampling import RiemannianFlowSamplerStep
from hackable_diffusion.lib.sampling.sampling import DiffusionSampler
from hackable_diffusion.lib.sampling.sampling import SampleFn
from hackable_diffusion.lib.sampling.simplicial_step_sampler import SimplicialDDIMStep
Expand Down
Loading