Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from hackable_diffusion.lib.architecture.mlp_blocks import MLP
from hackable_diffusion.lib.architecture.normalization import NormalizationLayer
from hackable_diffusion.lib.architecture.normalization import NormalizationLayerFactory
from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone
from hackable_diffusion.lib.architecture.sequence_embedders import RandomFourierSequenceEmbedding
from hackable_diffusion.lib.architecture.sequence_embedders import RoPESequenceEmbedding
from hackable_diffusion.lib.architecture.sequence_embedders import SinusoidalSequenceEmbedding
Expand Down
47 changes: 47 additions & 0 deletions hackable_diffusion/lib/architecture/riemannian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Riemannian Flow Matching architectures."""

import flax.linen as nn
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.architecture import arch_typing

################################################################################
# MARK: Riemannian Conditional Backbone
################################################################################

ConditionalBackbone = arch_typing.ConditionalBackbone


class RiemannianConditionalBackbone(ConditionalBackbone):
"""Velocity model for Riemannian Flow Matching.

Projects the output of a backbone network to the tangent space of a manifold.
"""

backbone: ConditionalBackbone
manifold: manifolds.Manifold

@nn.compact
def __call__(self, x, conditioning_embeddings, is_training=True):

v = self.backbone(x, conditioning_embeddings, is_training=is_training)

# Project v to tangent space at xt.
if isinstance(v, dict) and 'velocity' in v:
v = v['velocity']

v_proj = self.manifold.project(x, v)
return v_proj
64 changes: 64 additions & 0 deletions hackable_diffusion/lib/architecture/riemannian_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026 Hackable Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Riemannian Flow Matching architectures."""

from absl.testing import absltest
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.architecture import arch_typing
from hackable_diffusion.lib.architecture import mlp
from hackable_diffusion.lib.architecture import riemannian
import jax
import jax.numpy as jnp


class RiemannianArchitectureTest(absltest.TestCase):

def test_riemannian_backbone_projection(self):
manifold = manifolds.Sphere()
backbone = mlp.ConditionalMLP(
hidden_sizes_preprocess=(16,),
hidden_sizes_postprocess=(16,),
activation='relu',
zero_init_output=True,
dropout_rate=0.0,
conditioning_mechanism=mlp.ConditioningMechanism.CONCATENATE,
)
model = riemannian.RiemannianConditionalBackbone(
backbone=backbone,
manifold=manifold,
)

key = jax.random.PRNGKey(0)
xt = manifold.random_uniform(key, (4, 3))
time_emb = jnp.array([[0.5], [0.5], [0.5], [0.5]])

# conditioning_embeddings must be a dict keyed by ConditioningMechanism.
conditioning_embeddings = {
arch_typing.ConditioningMechanism.CONCATENATE: time_emb,
}

variables = model.init(key, xt, conditioning_embeddings, is_training=False)
v = model.apply(variables, xt, conditioning_embeddings, is_training=False)

self.assertEqual(v.shape, (4, 3))

# Check that v is tangent to xt
inner_products = jnp.sum(xt * v, axis=-1)
# Project should ensure dot(xt, v) = 0 for sphere
self.assertAlmostEqual(jnp.max(jnp.abs(inner_products)), 0.0, places=5)


if __name__ == '__main__':
absltest.main()
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/corruption/riemannian.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def corrupt(
alpha_dot_t = utils.bcast_right(self.schedule.alpha_dot(time), x0.ndim)

# x_t = geodesic(x0, x1, alpha(t)).
xt = self.manifold.exp(x0, alpha_t * self.manifold.log(x0, x1))
xt = manifolds.geodesic(self.manifold, x1, x0, alpha_t)

# By chain rule: d/dt x_t = alpha'(t) * velocity(x0, x1, alpha(t)).
vel = alpha_dot_t * self.manifold.velocity(x0, x1, alpha_t)
vel = alpha_dot_t * self.manifold.velocity(x1, x0, alpha_t)

target_info = {
'x0': x0,
Expand Down
30 changes: 18 additions & 12 deletions hackable_diffusion/lib/corruption/riemannian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,21 @@ def test_corrupt(self):
np.testing.assert_allclose(inner_products, 0.0, atol=1e-5)

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
manifold = manifolds.Sphere()
process = _make_process(manifold)
key = jax.random.PRNGKey(0)

x0 = jnp.array([[1.0, 0.0, 0.0]])
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)
x1_sampled = target1['x1']
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)

# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
v_log = manifold.log(x1_sampled, x0)
np.testing.assert_allclose(v1, -v_log, atol=1e-5)


Expand Down Expand Up @@ -94,19 +96,21 @@ def test_corrupt(self):
self.assertEqual(vel.shape, (batch_size, 3, 3))

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
manifold = manifolds.SO3()
process = _make_process(manifold)
key = jax.random.PRNGKey(1)

x0 = jnp.eye(3)[None, ...] # (1, 3, 3)
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)
x1_sampled = target1['x1']
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)

# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
v_log = manifold.log(x1_sampled, x0)
np.testing.assert_allclose(v1, -v_log, atol=1e-4)


Expand All @@ -132,19 +136,21 @@ def test_corrupt(self):
self.assertEqual(vel.shape, (batch_size, dim))

def test_velocity_at_t1(self):
"""At t=1, alpha=0 so xt = x0 and velocity = -log(x0, x1)."""
"""At t=1, alpha=0 so xt = x1 and velocity = -log(x1, x0)."""
manifold = manifolds.Torus()
process = _make_process(manifold)
key = jax.random.PRNGKey(2)

x0 = jnp.array([[0.1, 0.5, 0.9]])
t1 = jnp.array([1.0])
xt1, target1 = process.corrupt(key, x0, t1)
np.testing.assert_allclose(xt1, x0, atol=1e-5)
x1_sampled = target1['x1']
# At t=1, alpha=0: geodesic(x1, x0, 0) = x1.
np.testing.assert_allclose(xt1, x1_sampled, atol=1e-5)

# velocity = alpha_dot(1) * velocity(x1, x0, 0) = -1 * log(x1, x0).
v1 = target1['velocity']
x1_sampled = target1['x1']
v_log = manifold.log(x0, x1_sampled)
v_log = manifold.log(x1_sampled, x0)
np.testing.assert_allclose(v1, -v_log, atol=1e-5)


Expand Down
4 changes: 1 addition & 3 deletions hackable_diffusion/lib/sampling/riemannian_sampling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _make_sampler(manifold):
class RiemannianFlowSamplerStepTest(absltest.TestCase):

def test_update_sphere(self):
"""Euler step on moves along geodesic."""
"""Euler step on S2 moves along geodesic."""
manifold = manifolds.Sphere()
sampler = _make_sampler(manifold)
key = jax.random.PRNGKey(0)
Expand Down Expand Up @@ -69,7 +69,6 @@ def test_update_so3(self):
key = jax.random.PRNGKey(1)

xt = jnp.eye(3)[None, ...] # Identity rotation (1, 3, 3).
# Tangent vector at identity is a skew-symmetric matrix.
v = jnp.array([[[0.0, -0.1, 0.0], [0.1, 0.0, 0.0], [0.0, 0.0, 0.0]]])

current_step = base.DiffusionStep(
Expand Down Expand Up @@ -109,7 +108,6 @@ def test_update_torus(self):
# dt = 1.0, so next_xt = exp(xt, v) = (xt + v) % 1.0.
expected_xt = jnp.array([[(0.9 + 0.5) % 1.0, (0.1 - 0.5) % 1.0, 0.5]])
np.testing.assert_allclose(next_step.xt, expected_xt, atol=1e-5)
# Result stays in [0, 1).
self.assertTrue(jnp.all(next_step.xt >= 0.0))
self.assertTrue(jnp.all(next_step.xt < 1.0))

Expand Down
Loading
Loading