diff --git a/hackable_diffusion/lib/corruption/__init__.py b/hackable_diffusion/lib/corruption/__init__.py index 0ce9786..57ee789 100644 --- a/hackable_diffusion/lib/corruption/__init__.py +++ b/hackable_diffusion/lib/corruption/__init__.py @@ -22,6 +22,7 @@ from hackable_diffusion.lib.corruption.discrete import PostCorruptionFn from hackable_diffusion.lib.corruption.discrete import SymmetricPostCorruptionFn from hackable_diffusion.lib.corruption.gaussian import GaussianProcess +from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess from hackable_diffusion.lib.corruption.schedules import CosineDiscreteSchedule from hackable_diffusion.lib.corruption.schedules import CosineSchedule from hackable_diffusion.lib.corruption.schedules import DiscreteSchedule @@ -31,8 +32,10 @@ from hackable_diffusion.lib.corruption.schedules import InverseCosineSchedule from hackable_diffusion.lib.corruption.schedules import LinearDiffusionSchedule from hackable_diffusion.lib.corruption.schedules import LinearDiscreteSchedule +from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule from hackable_diffusion.lib.corruption.schedules import PolynomialDiscreteSchedule from hackable_diffusion.lib.corruption.schedules import RFSchedule +from hackable_diffusion.lib.corruption.schedules import RiemannianSchedule from hackable_diffusion.lib.corruption.schedules import Schedule from hackable_diffusion.lib.corruption.schedules import ShiftedSchedule from hackable_diffusion.lib.corruption.schedules import SquareCosineDiscreteSchedule diff --git a/hackable_diffusion/lib/corruption/riemannian.py b/hackable_diffusion/lib/corruption/riemannian.py new file mode 100644 index 0000000..3016706 --- /dev/null +++ b/hackable_diffusion/lib/corruption/riemannian.py @@ -0,0 +1,99 @@ +# Copyright 2026 Hackable Diffusion Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Riemannian Flow Matching corruption process.""" + +import dataclasses +from typing import Any + +from hackable_diffusion.lib import hd_typing +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib import utils +from hackable_diffusion.lib.corruption import base +from hackable_diffusion.lib.corruption import schedules +import kauldron.ktyping as kt + +PRNGKey = hd_typing.PRNGKey +DataArray = hd_typing.DataArray +TimeArray = hd_typing.TimeArray +TargetInfo = hd_typing.TargetInfo + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class RiemannianProcess(base.CorruptionProcess): + """Riemannian Flow Matching corruption process. + + This is based on https://arxiv.org/abs/2302.03660. + + Given a schedule with interpolation parameter alpha(t): + x_t = geodesic(x_0, x_1, alpha(t)) + target = alpha'(t) * velocity(x_0, x_1, alpha(t)) + """ + + manifold: manifolds.Manifold + schedule: schedules.RiemannianSchedule + + @kt.typechecked + def sample_from_invariant( + self, + key: PRNGKey, + data_spec: DataArray, + ) -> DataArray: + """Sample from the base distribution (uniform) on the manifold.""" + return self.manifold.random_uniform(key, data_spec.shape) + + @kt.typechecked + def corrupt( + self, + key: PRNGKey, + x0: DataArray, + time: TimeArray, + ) -> tuple[DataArray, TargetInfo]: + x1 = self.sample_from_invariant(key, data_spec=x0) + + # Evaluate schedule: alpha(t) is the geodesic interpolation parameter. 
+ alpha_t = utils.bcast_right(self.schedule.alpha(time), x0.ndim) + alpha_dot_t = utils.bcast_right(self.schedule.alpha_dot(time), x0.ndim) + + # x_t = geodesic(x0, x1, alpha(t)). + xt = self.manifold.exp(x0, alpha_t * self.manifold.log(x0, x1)) + + # By chain rule: d/dt x_t = alpha'(t) * velocity(x0, x1, alpha(t)). + vel = alpha_dot_t * self.manifold.velocity(x0, x1, alpha_t) + + target_info = { + 'x0': x0, + 'x1': x1, + 'velocity': vel, + } + + return xt, target_info + + @kt.typechecked + def convert_predictions( + self, + prediction: TargetInfo, + xt: DataArray, + time: TimeArray, + ) -> TargetInfo: + """Convert predictions to velocity parameterization.""" + if 'velocity' in prediction: + return prediction + raise NotImplementedError( + 'Only velocity prediction is supported for RFM currently.' + ) + + @kt.typechecked + def get_schedule_info(self, time: TimeArray) -> dict[str, Any]: + return self.schedule.evaluate(time) diff --git a/hackable_diffusion/lib/corruption/riemannian_test.py b/hackable_diffusion/lib/corruption/riemannian_test.py new file mode 100644 index 0000000..f410a04 --- /dev/null +++ b/hackable_diffusion/lib/corruption/riemannian_test.py @@ -0,0 +1,152 @@ +# Copyright 2026 Hackable Diffusion Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Riemannian Flow Matching corruption process.""" + +from absl.testing import absltest +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib.corruption import riemannian +from hackable_diffusion.lib.corruption import schedules +import jax +import jax.numpy as jnp +import numpy as np + + +def _make_process(manifold): + return riemannian.RiemannianProcess( + manifold=manifold, + schedule=schedules.LinearRiemannianSchedule(), + ) + + +class SphereCorruptionTest(absltest.TestCase): + + def test_corrupt(self): + manifold = manifolds.Sphere() + process = _make_process(manifold) + key, data_key = jax.random.split(jax.random.PRNGKey(0)) + + batch_size = 8 + x0 = manifold.random_uniform(data_key, (batch_size, 3)) + time = jnp.linspace(0, 1, batch_size) + + xt, target_info = process.corrupt(key, x0, time) + + # xt should be on the sphere. + norms = jnp.linalg.norm(xt, axis=-1) + np.testing.assert_allclose(norms, 1.0, atol=1e-5) + + # Velocity should be tangent to the sphere at xt, i.e. dot(xt, vel) = 0. + vel = target_info['velocity'] + self.assertEqual(vel.shape, (batch_size, 3)) + inner_products = jnp.sum(xt * vel, axis=-1) + np.testing.assert_allclose(inner_products, 0.0, atol=1e-5) + + def test_velocity_at_t1(self): + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" + manifold = manifolds.Sphere() + process = _make_process(manifold) + key = jax.random.PRNGKey(0) + + x0 = jnp.array([[1.0, 0.0, 0.0]]) + t1 = jnp.array([1.0]) + xt1, target1 = process.corrupt(key, x0, t1) + np.testing.assert_allclose(xt1, x0, atol=1e-5) + + v1 = target1['velocity'] + x1_sampled = target1['x1'] + v_log = manifold.log(x0, x1_sampled) + np.testing.assert_allclose(v1, -v_log, atol=1e-5) + + +class SO3CorruptionTest(absltest.TestCase): + + def test_corrupt(self): + manifold = manifolds.SO3() + process = _make_process(manifold) + key, data_key = jax.random.split(jax.random.PRNGKey(1)) + + batch_size = 8 + x0 = manifold.random_uniform(data_key, (batch_size, 3, 3)) + time = jnp.linspace(0, 1, batch_size) + + xt, target_info = 
process.corrupt(key, x0, time) + + # xt should be a valid rotation: R^T R = I and det(R) = 1. + rtrt = jnp.matmul(jnp.swapaxes(xt, -2, -1), xt) + eyes = jnp.broadcast_to(jnp.eye(3), rtrt.shape) + np.testing.assert_allclose(rtrt, eyes, atol=1e-5) + np.testing.assert_allclose(jnp.linalg.det(xt), 1.0, atol=1e-5) + + # Velocity should be in the tangent space: x^T v is skew-symmetric. + vel = target_info['velocity'] + self.assertEqual(vel.shape, (batch_size, 3, 3)) + + def test_velocity_at_t1(self): + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" + manifold = manifolds.SO3() + process = _make_process(manifold) + key = jax.random.PRNGKey(1) + + x0 = jnp.eye(3)[None, ...] # (1, 3, 3) + t1 = jnp.array([1.0]) + xt1, target1 = process.corrupt(key, x0, t1) + np.testing.assert_allclose(xt1, x0, atol=1e-5) + + v1 = target1['velocity'] + x1_sampled = target1['x1'] + v_log = manifold.log(x0, x1_sampled) + np.testing.assert_allclose(v1, -v_log, atol=1e-4) + + +class TorusCorruptionTest(absltest.TestCase): + + def test_corrupt(self): + manifold = manifolds.Torus() + process = _make_process(manifold) + key, data_key = jax.random.split(jax.random.PRNGKey(2)) + + batch_size = 8 + dim = 4 + x0 = manifold.random_uniform(data_key, (batch_size, dim)) + time = jnp.linspace(0, 1, batch_size) + + xt, target_info = process.corrupt(key, x0, time) + + # xt should be in [0, 1). 
+ self.assertTrue(jnp.all(xt >= 0.0)) + self.assertTrue(jnp.all(xt < 1.0)) + + vel = target_info['velocity'] + self.assertEqual(vel.shape, (batch_size, dim)) + + def test_velocity_at_t1(self): + """At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1).""" + manifold = manifolds.Torus() + process = _make_process(manifold) + key = jax.random.PRNGKey(2) + + x0 = jnp.array([[0.1, 0.5, 0.9]]) + t1 = jnp.array([1.0]) + xt1, target1 = process.corrupt(key, x0, t1) + np.testing.assert_allclose(xt1, x0, atol=1e-5) + + v1 = target1['velocity'] + x1_sampled = target1['x1'] + v_log = manifold.log(x0, x1_sampled) + np.testing.assert_allclose(v1, -v_log, atol=1e-5) + + +if __name__ == '__main__': + absltest.main() diff --git a/hackable_diffusion/lib/corruption/schedules.py b/hackable_diffusion/lib/corruption/schedules.py index fe5fa71..cf7a58a 100644 --- a/hackable_diffusion/lib/corruption/schedules.py +++ b/hackable_diffusion/lib/corruption/schedules.py @@ -109,6 +109,59 @@ def evaluate(self, time: TimeArray) -> dict[str, TimeArray]: SimplicialSchedule = DiscreteSchedule +################################################################################ +# MARK: Riemannian Schedules +################################################################################ + + +class RiemannianSchedule(abc.ABC, Schedule): + """Base class for Riemannian schedules. + + Controls the geodesic interpolation via alpha(t): + x_t = geodesic(x_0, x_1, alpha(t)) + v_t = alpha'(t) * velocity(x_0, x_1, alpha(t)) + + Subclasses must implement `alpha`. + """ + + @abc.abstractmethod + def alpha(self, time: TimeArray) -> TimeArray: + """The geodesic interpolation parameter at time t.""" + + def alpha_dot(self, time: TimeArray) -> TimeArray: + """Time derivative of alpha. 
Defaults to autodiff.""" + return utils.egrad(self.alpha)(time) + + @kt.typechecked + def evaluate(self, time: TimeArray) -> dict[str, TimeArray]: + return { + 'time': time, + 'alpha': self.alpha(time), + 'alpha_dot': self.alpha_dot(time), + } + + +class LinearRiemannianSchedule(RiemannianSchedule): + """Linear Riemannian schedule: alpha(t) = 1.0 - t. + + This is the standard flow matching schedule where the geodesic interpolation + parameter is linear in time. + Note that contrary to the original Riemannian Flow Matching, we assume that at + time t=0, the process is close to the data distribution, and at time t=1, + the process is close to the target distribution. + Hence, we use alpha(t) = 1.0 - t, and alpha_dot(t) = -1.0, instead of + alpha(t) = t, and alpha_dot(t) = 1.0. + """ + + @kt.typechecked + def alpha(self, time: TimeArray) -> TimeArray: + return 1.0 - time + + @kt.typechecked + def alpha_dot(self, time: TimeArray) -> TimeArray: + return -jnp.ones_like(time) + + ################################################################################ # MARK: Gaussian Schedules ################################################################################ diff --git a/hackable_diffusion/lib/corruption/schedules_test.py b/hackable_diffusion/lib/corruption/schedules_test.py index c6be243..f1417a9 100644 --- a/hackable_diffusion/lib/corruption/schedules_test.py +++ b/hackable_diffusion/lib/corruption/schedules_test.py @@ -373,5 +373,65 @@ def test_shifted_cosine_schedule(self): ) +################################################################################ +# MARK: Riemannian Schedule Tests +################################################################################ + + +class LinearRiemannianScheduleTest(absltest.TestCase): + + def test_alpha_boundary_values(self): + schedule = schedules.LinearRiemannianSchedule() + # At t=0, alpha(0) = 1.0 (data). 
+ self.assertAlmostEqual( + schedule.alpha(jnp.array([0.0])).item(), 1.0, places=6 + ) + # At t=1, alpha(1) = 0.0 (noise / base distribution). + self.assertAlmostEqual( + schedule.alpha(jnp.array([1.0])).item(), 0.0, places=6 + ) + + def test_alpha_intermediate(self): + schedule = schedules.LinearRiemannianSchedule() + self.assertAlmostEqual( + schedule.alpha(jnp.array([0.4])).item(), 0.6, places=6 + ) + + def test_alpha_dot(self): + schedule = schedules.LinearRiemannianSchedule() + t = jnp.linspace(0.01, 0.99, 50) + alpha_dot = schedule.alpha_dot(t) + # For linear schedule, alpha_dot(t) = -1.0 for all t. + self.assertTrue(jnp.allclose(alpha_dot, -jnp.ones_like(t), atol=1e-6)) + + def test_alpha_dot_matches_autodiff(self): + schedule = schedules.LinearRiemannianSchedule() + t = jnp.linspace(0.01, 0.99, 50) + alpha_dot = schedule.alpha_dot(t) + alpha_dot_auto = utils.egrad(schedule.alpha)(t) + self.assertTrue( + jnp.allclose(alpha_dot, alpha_dot_auto, atol=1e-6, rtol=1e-6), + 'alpha_dot() does not match autodiff: Absolute difference:' + f' {jnp.max(jnp.abs(alpha_dot - alpha_dot_auto))}', + ) + + def test_alpha_bounds(self): + schedule = schedules.LinearRiemannianSchedule() + t = jnp.linspace(0.0, 1.0, 100) + alpha = schedule.alpha(t) + self.assertTrue(jnp.all(alpha >= 0.0)) + self.assertTrue(jnp.all(alpha <= 1.0)) + + def test_evaluate(self): + schedule = schedules.LinearRiemannianSchedule() + t = jnp.array([0.3]) + result = schedule.evaluate(t) + self.assertIn('time', result) + self.assertIn('alpha', result) + self.assertIn('alpha_dot', result) + self.assertAlmostEqual(result['alpha'].item(), 0.7, places=6) + self.assertAlmostEqual(result['alpha_dot'].item(), -1.0, places=6) + + if __name__ == '__main__': absltest.main() diff --git a/hackable_diffusion/lib/manifolds.py b/hackable_diffusion/lib/manifolds.py index e861644..f5d0800 100644 --- a/hackable_diffusion/lib/manifolds.py +++ b/hackable_diffusion/lib/manifolds.py @@ -19,28 +19,37 @@ import jax import jax.numpy 
as jnp +################################################################################ +# MARK: Type Aliases +################################################################################ + +DataArray = hd_typing.DataArray +TimeArray = hd_typing.TimeArray +PRNGKey = hd_typing.PRNGKey +LossOutput = hd_typing.LossOutput +Array = hd_typing.Array + ################################################################################ # MARK: Constants ################################################################################ EPSILON = 1e-9 - ################################################################################ # MARK: Utility functions ################################################################################ def unnormalized_sinc( - x: hd_typing.Array['*batch'], -) -> hd_typing.Array['*batch']: + x: Array['*batch'], +) -> Array['*batch']: """Safe sinc(x).""" return jnp.sinc(x / jnp.pi) def unnormalized_cosc( - x: hd_typing.Array['*batch'], -) -> hd_typing.Array['*batch']: + x: Array['*batch'], +) -> Array['*batch']: """Safe (1-cos(x))/x^2. Leverages the sinc trick to compute (1-cos(x))/x^2 safely. Using the identity @@ -57,11 +66,11 @@ def unnormalized_cosc( def safe_norm( - x: hd_typing.Array, + x: Array, axis: tuple[int, ...] = (-1,), keepdims: bool = True, eps: float = 1e-9, -) -> hd_typing.Array: +) -> Array: """Computes norm safely to avoid NaN gradients at zero.""" is_zero = jnp.all(x == 0, axis=axis, keepdims=keepdims) safe_x = jnp.where(is_zero, eps, x) @@ -69,7 +78,7 @@ def safe_norm( return jnp.where(is_zero, 0.0, n) -def transpose(x: hd_typing.DataArray) -> hd_typing.DataArray: +def transpose(x: DataArray) -> DataArray: """Transpose of a tensor on the last two dimensions.""" return jnp.swapaxes(x, -1, -2) @@ -116,42 +125,32 @@ class Manifold(Protocol): Flow Matching. 
""" - def exp( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def exp(self, x: DataArray, v: DataArray) -> DataArray: """Exponential map.""" ... - def log( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.DataArray: + def log(self, x: DataArray, y: DataArray) -> DataArray: """Logarithm map.""" ... - def dist( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.LossOutput: + def dist(self, x: DataArray, y: DataArray) -> LossOutput: """Riemannian distance.""" ... - def project( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def project(self, x: DataArray, v: DataArray) -> DataArray: """Project vector v to tangent space at x.""" ... - def random_uniform( - self, key: hd_typing.PRNGKey, shape: tuple[int, ...] - ) -> hd_typing.DataArray: + def random_uniform(self, key: PRNGKey, shape: tuple[int, ...]) -> DataArray: """Sample from uniform distribution on the manifold.""" ... def velocity( self, - x: hd_typing.DataArray, - y: hd_typing.DataArray, - t: hd_typing.TimeArray, - ) -> hd_typing.DataArray: + x: DataArray, + y: DataArray, + t: TimeArray, + ) -> DataArray: """Velocity of the geodesic between x and y at time t.""" ... @@ -161,19 +160,17 @@ def velocity( ################################################################################ -def dist_sq( - manifold: Manifold, x: hd_typing.DataArray, y: hd_typing.DataArray -) -> hd_typing.LossOutput: +def dist_sq(manifold: Manifold, x: DataArray, y: DataArray) -> LossOutput: """Squared Riemannian distance.""" return jnp.square(manifold.dist(x, y)) def geodesic( manifold: Manifold, - x: hd_typing.DataArray, - y: hd_typing.DataArray, - t: hd_typing.TimeArray, -) -> hd_typing.DataArray: + x: DataArray, + y: DataArray, + t: TimeArray, +) -> DataArray: """Geodesic between x and y at time t in [0, 1]. A geodesic is the generalization of a straight line to a curved manifold. 
@@ -198,8 +195,6 @@ def geodesic( Returns: Geodesic between x and y at time t. """ - if jnp.max(t) > 1.0 or jnp.min(t) < 0.0: - raise ValueError('Time t must be in [0, 1].') return manifold.exp(x, t * manifold.log(x, y)) @@ -232,9 +227,7 @@ class Sphere(Manifold): more details. """ - def exp( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def exp(self, x: DataArray, v: DataArray) -> DataArray: """Exponential map on S^d. Compute the exponential map on the sphere according to the formula: @@ -253,9 +246,7 @@ def exp( # recall that unnormalized_sinc(x) = sin(x) / x return jnp.cos(v_norm) * x + unnormalized_sinc(v_norm) * v - def log( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.DataArray: + def log(self, x: DataArray, y: DataArray) -> DataArray: """Logarithm map on S^d. Compute the logarithm map on the sphere according to the formula: @@ -283,9 +274,7 @@ def log( theta = jnp.arccos(cos_theta) return (y - cos_theta * x) / unnormalized_sinc(theta) - def dist( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.LossOutput: + def dist(self, x: DataArray, y: DataArray) -> LossOutput: """Distance on S^d. Compute the distance on the sphere according to the formula: @@ -304,9 +293,7 @@ def dist( cos_theta = jnp.clip(cos_theta, -1.0 + EPSILON, 1.0 - EPSILON) return jnp.arccos(cos_theta) - def project( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def project(self, x: DataArray, v: DataArray) -> DataArray: """Project vector v to tangent space at x. Compute the projection of v onto the tangent space at x according to the @@ -324,9 +311,7 @@ def project( non_batch_axes = tuple(range(1, x.ndim)) return v - jnp.sum(x * v, axis=non_batch_axes, keepdims=True) * x - def random_uniform( - self, key: hd_typing.PRNGKey, shape: tuple[int, ...] 
- ) -> hd_typing.DataArray: + def random_uniform(self, key: PRNGKey, shape: tuple[int, ...]) -> DataArray: # Samples from N(0, I) and normalizes on the sphere. non_batch_axes = tuple(range(1, len(shape))) z = jax.random.normal(key, shape) @@ -335,10 +320,10 @@ def random_uniform( def velocity( self, - x: hd_typing.DataArray, - y: hd_typing.DataArray, - t: hd_typing.TimeArray, - ) -> hd_typing.DataArray: + x: DataArray, + y: DataArray, + t: TimeArray, + ) -> DataArray: """Velocity of the geodesic between x and y at time t. Compute the velocity of the geodesic between x and y at time t according to @@ -374,7 +359,7 @@ def velocity( ################################################################################ -def _hat(v: hd_typing.Array['*batch 3']) -> hd_typing.Array['*batch 3 3']: +def _hat(v: Array['*batch 3']) -> Array['*batch 3 3']: """Hat map: R^3 -> so(3). Maps a 3D vector to a skew-symmetric matrix. Note that the operation is vectorized over the first dimension. @@ -406,7 +391,7 @@ def _hat(v: hd_typing.Array['*batch 3']) -> hd_typing.Array['*batch 3 3']: ) -def _vee(omega: hd_typing.Array['*batch 3 3']) -> hd_typing.Array['*batch 3']: +def _vee(omega: Array['*batch 3 3']) -> Array['*batch 3']: """Vee map: so(3) -> R^3. Maps a skew-symmetric matrix back to a 3D vector. Note that the operation is vectorized over the first dimension. @@ -442,9 +427,7 @@ class SO3(Manifold): skew-symmetric matrix in the Lie algebra so(3). """ - def exp( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def exp(self, x: DataArray, v: DataArray) -> DataArray: """Exponential map on SO(3). Computes the exact exponential map using Rodrigues' rotation formula. @@ -490,9 +473,7 @@ def exp( ) return jnp.matmul(x, exp_mat) - def log( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.DataArray: + def log(self, x: DataArray, y: DataArray) -> DataArray: """Logarithm map on SO(3). This is the inverse of the exponential map. 
@@ -545,9 +526,7 @@ def log( ) return jnp.matmul(x, omega_mat) - def dist( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.DataArray: + def dist(self, x: DataArray, y: DataArray) -> DataArray: """Computes the shortest geodesic distance between rotations x and y. Let y = exp(x, v) = x @ exp(Omega). Then the distance is given by theta, @@ -570,9 +549,7 @@ def dist( # theta away from pi, since in those cases the distance is ill-defined. return jnp.arccos(cos_theta) - def project( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def project(self, x: DataArray, v: DataArray) -> DataArray: """Project ambient matrix v to tangent space at x. Project the ambient matrix v to the tangent space at x by first shifting v @@ -599,9 +576,7 @@ def project( skew_omega_mat = 0.5 * (omega_mat - jnp.swapaxes(omega_mat, -1, -2)) return jnp.matmul(x, skew_omega_mat) - def random_uniform( - self, key: hd_typing.PRNGKey, shape: tuple[int, ...] - ) -> hd_typing.DataArray: + def random_uniform(self, key: PRNGKey, shape: tuple[int, ...]) -> DataArray: """Haar measure on SO(3). Samples rotation matrices uniformly via quaternions. @@ -644,10 +619,10 @@ def random_uniform( def velocity( self, - x: hd_typing.DataArray, - y: hd_typing.DataArray, - t: hd_typing.TimeArray, - ) -> hd_typing.DataArray: + x: DataArray, + y: DataArray, + t: TimeArray, + ) -> DataArray: """Velocity of the geodesic between x and y at time t. 
This computes the time derivative (tangent vector) along the shortest path @@ -712,38 +687,28 @@ def velocity( class Torus(Manifold): """T-dimensional Torus [0, 1]^d with periodic boundary conditions.""" - def exp( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def exp(self, x: DataArray, v: DataArray) -> DataArray: return (x + v) % 1.0 - def log( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.DataArray: + def log(self, x: DataArray, y: DataArray) -> DataArray: """Shortest displacement on the torus.""" return (y - x + 0.5) % 1.0 - 0.5 - def dist( - self, x: hd_typing.DataArray, y: hd_typing.DataArray - ) -> hd_typing.LossOutput: + def dist(self, x: DataArray, y: DataArray) -> LossOutput: return jnp.linalg.norm(self.log(x, y), axis=-1) - def project( - self, x: hd_typing.DataArray, v: hd_typing.DataArray - ) -> hd_typing.DataArray: + def project(self, x: DataArray, v: DataArray) -> DataArray: return v # Tangent space is R^d - def random_uniform( - self, key: hd_typing.DataArray, shape: tuple[int, ...] - ) -> hd_typing.DataArray: + def random_uniform(self, key: PRNGKey, shape: tuple[int, ...]) -> DataArray: return jax.random.uniform(key, shape) def velocity( self, - x: hd_typing.DataArray, - y: hd_typing.DataArray, - t: hd_typing.DataArray, - ) -> hd_typing.DataArray: + x: DataArray, + y: DataArray, + t: TimeArray, + ) -> DataArray: # Geodesics on the flat torus are straight lines (with periodic wrapping), # so the velocity is constant and independent of time t. del t # Unused. 
diff --git a/hackable_diffusion/lib/manifolds_test.py b/hackable_diffusion/lib/manifolds_test.py index 980f3c3..a522694 100644 --- a/hackable_diffusion/lib/manifolds_test.py +++ b/hackable_diffusion/lib/manifolds_test.py @@ -149,22 +149,6 @@ def test_geodesic_stays_on_manifold(self, t_val): norms = jnp.linalg.norm(gt, axis=-1) np.testing.assert_allclose(norms, 1.0, atol=1e-5) - def test_geodesic_raises_for_t_above_one(self): - """geodesic should raise ValueError when t > 1.""" - x = jnp.array([[1.0, 0.0, 0.0]]) - y = jnp.array([[0.0, 1.0, 0.0]]) - t = jnp.array([[1.5]]) - with self.assertRaises(ValueError): - manifolds.geodesic(self.manifold, x, y, t) - - def test_geodesic_raises_for_t_below_zero(self): - """geodesic should raise ValueError when t < 0.""" - x = jnp.array([[1.0, 0.0, 0.0]]) - y = jnp.array([[0.0, 1.0, 0.0]]) - t = jnp.array([[-0.1]]) - with self.assertRaises(ValueError): - manifolds.geodesic(self.manifold, x, y, t) - def test_exp_zero_tangent(self): """exp(x, 0) should return x.""" key = jax.random.PRNGKey(3) diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py index c7563ee..8a38e6c 100644 --- a/hackable_diffusion/lib/sampling/__init__.py +++ b/hackable_diffusion/lib/sampling/__init__.py @@ -36,6 +36,7 @@ from hackable_diffusion.lib.sampling.gaussian_step_sampler import HeunStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import SdeStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import VelocityStep +from hackable_diffusion.lib.sampling.riemannian_sampling import RiemannianFlowSamplerStep from hackable_diffusion.lib.sampling.sampling import DiffusionSampler from hackable_diffusion.lib.sampling.sampling import SampleFn from hackable_diffusion.lib.sampling.simplicial_step_sampler import SimplicialDDIMStep diff --git a/hackable_diffusion/lib/sampling/riemannian_sampling.py b/hackable_diffusion/lib/sampling/riemannian_sampling.py new file mode 100644 index 0000000..a1f61b9 
--- /dev/null +++ b/hackable_diffusion/lib/sampling/riemannian_sampling.py @@ -0,0 +1,84 @@ +# Copyright 2026 Hackable Diffusion Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Riemannian Flow Matching sampler step.""" + +import dataclasses +from hackable_diffusion.lib import hd_typing +from hackable_diffusion.lib.corruption import riemannian +from hackable_diffusion.lib.sampling import base +import kauldron.ktyping as kt + +################################################################################ +# MARK: Type Aliases +################################################################################ + +DataTree = hd_typing.DataTree +TargetInfoTree = hd_typing.TargetInfoTree + +################################################################################ +# MARK: Sampler Step +################################################################################ + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class RiemannianFlowSamplerStep(base.SamplerStep): + """Euler integration on Riemannian manifold for Flow Matching.""" + + corruption_process: riemannian.RiemannianProcess + + @kt.typechecked + def initialize( + self, + initial_noise: DataTree, + initial_step_info: base.StepInfoTree, + ) -> base.DiffusionStepTree: + return base.DiffusionStep( + xt=initial_noise, + step_info=initial_step_info, + aux={}, + ) + + @kt.typechecked + def update( + self, + prediction: TargetInfoTree, + current_step: base.DiffusionStep, + next_step_info: 
base.StepInfoTree, + ) -> base.DiffusionStepTree: + xt = current_step.xt + t = current_step.step_info.time + next_t = next_step_info.time + dt = next_t - t + + v = prediction['velocity'] + + # Riemannian Euler integration step. The exponential map generalizes the + # Euclidean update x_{t+dt} = x_t + dt * v to manifolds. + next_xt = self.corruption_process.manifold.exp(xt, dt * v) + + return base.DiffusionStep( + xt=next_xt, + step_info=next_step_info, + aux={}, + ) + + @kt.typechecked + def finalize( + self, + prediction: TargetInfoTree, + current_step: base.DiffusionStep, + last_step_info: base.StepInfoTree, + ) -> base.DiffusionStepTree: + return current_step diff --git a/hackable_diffusion/lib/sampling/riemannian_sampling_test.py b/hackable_diffusion/lib/sampling/riemannian_sampling_test.py new file mode 100644 index 0000000..c1a6616 --- /dev/null +++ b/hackable_diffusion/lib/sampling/riemannian_sampling_test.py @@ -0,0 +1,150 @@ +# Copyright 2026 Hackable Diffusion Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Riemannian Flow Matching sampler step.""" + +from absl.testing import absltest +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib.corruption import riemannian +from hackable_diffusion.lib.corruption import schedules +from hackable_diffusion.lib.sampling import base +from hackable_diffusion.lib.sampling import riemannian_sampling +import jax +import jax.numpy as jnp +import numpy as np + + +def _make_sampler(manifold): + process = riemannian.RiemannianProcess( + manifold=manifold, + schedule=schedules.LinearRiemannianSchedule(), + ) + return riemannian_sampling.RiemannianFlowSamplerStep( + corruption_process=process, + ) + + +class RiemannianFlowSamplerStepTest(absltest.TestCase): + + def test_update_sphere(self): + """Euler step on S² moves along geodesic.""" + manifold = manifolds.Sphere() + sampler = _make_sampler(manifold) + key = jax.random.PRNGKey(0) + + xt = jnp.array([[1.0, 0.0, 0.0]]) + v = jnp.array([[0.0, 1.0, 0.0]]) # Tangent vector + + current_step = base.DiffusionStep( + xt=xt, + step_info=base.StepInfo(step=0, time=jnp.array([0.0]), rng=key), + aux={}, + ) + next_step_info = base.StepInfo(step=1, time=jnp.array([0.1]), rng=key) + + prediction = {"velocity": v} + next_step = sampler.update(prediction, current_step, next_step_info) + + # dt = 0.1, so next_xt = exp_xt(0.1 * v) = [cos(0.1), sin(0.1), 0]. + expected_xt = jnp.array([[jnp.cos(0.1), jnp.sin(0.1), 0.0]]) + np.testing.assert_allclose(next_step.xt, expected_xt, atol=1e-5) + # Result stays on the sphere. + np.testing.assert_allclose(jnp.linalg.norm(next_step.xt), 1.0, atol=1e-5) + + def test_update_so3(self): + """Euler step on SO(3) produces a valid rotation matrix.""" + manifold = manifolds.SO3() + sampler = _make_sampler(manifold) + key = jax.random.PRNGKey(1) + + xt = jnp.eye(3)[None, ...] # Identity rotation (1, 3, 3). + # Tangent vector at identity is a skew-symmetric matrix. 
+ v = jnp.array([[[0.0, -0.1, 0.0], [0.1, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + current_step = base.DiffusionStep( + xt=xt, + step_info=base.StepInfo(step=0, time=jnp.array([0.0]), rng=key), + aux={}, + ) + next_step_info = base.StepInfo(step=1, time=jnp.array([0.1]), rng=key) + + prediction = {"velocity": v} + next_step = sampler.update(prediction, current_step, next_step_info) + + # Check result is a valid rotation: R^T R = I and det(R) = 1. + rtrt = jnp.matmul(jnp.swapaxes(next_step.xt, -2, -1), next_step.xt) + np.testing.assert_allclose(rtrt, jnp.eye(3)[None, ...], atol=1e-5) + np.testing.assert_allclose(jnp.linalg.det(next_step.xt), 1.0, atol=1e-5) + + def test_update_torus(self): + """Euler step on Torus wraps around [0, 1).""" + manifold = manifolds.Torus() + sampler = _make_sampler(manifold) + key = jax.random.PRNGKey(2) + + xt = jnp.array([[0.9, 0.1, 0.5]]) + v = jnp.array([[0.5, -0.5, 0.0]]) + + current_step = base.DiffusionStep( + xt=xt, + step_info=base.StepInfo(step=0, time=jnp.array([0.0]), rng=key), + aux={}, + ) + next_step_info = base.StepInfo(step=1, time=jnp.array([1.0]), rng=key) + + prediction = {"velocity": v} + next_step = sampler.update(prediction, current_step, next_step_info) + + # dt = 1.0, so next_xt = exp(xt, v) = (xt + v) % 1.0. + expected_xt = jnp.array([[(0.9 + 0.5) % 1.0, (0.1 - 0.5) % 1.0, 0.5]]) + np.testing.assert_allclose(next_step.xt, expected_xt, atol=1e-5) + # Result stays in [0, 1). 
+ self.assertTrue(jnp.all(next_step.xt >= 0.0)) + self.assertTrue(jnp.all(next_step.xt < 1.0)) + + def test_initialize(self): + """Initialize returns a DiffusionStep with the given noise and step info.""" + manifold = manifolds.Sphere() + sampler = _make_sampler(manifold) + key = jax.random.PRNGKey(0) + + initial_noise = manifold.random_uniform(key, (4, 3)) + initial_step_info = base.StepInfo(step=0, time=jnp.array([0.0]), rng=key) + + step = sampler.initialize(initial_noise, initial_step_info) + + np.testing.assert_array_equal(step.xt, initial_noise) + self.assertEqual(step.step_info.step, 0) + + def test_finalize(self): + """Finalize returns the current step unchanged.""" + manifold = manifolds.Sphere() + sampler = _make_sampler(manifold) + key = jax.random.PRNGKey(0) + + xt = jnp.array([[1.0, 0.0, 0.0]]) + current_step = base.DiffusionStep( + xt=xt, + step_info=base.StepInfo(step=5, time=jnp.array([1.0]), rng=key), + aux={}, + ) + last_step_info = base.StepInfo(step=6, time=jnp.array([1.0]), rng=key) + prediction = {"velocity": jnp.zeros_like(xt)} + + result = sampler.finalize(prediction, current_step, last_step_info) + np.testing.assert_array_equal(result.xt, xt) + + +if __name__ == "__main__": + absltest.main()