49 changes: 46 additions & 3 deletions docs/architecture.md

For simpler, non-image data, a `ConditionalMLP` backbone is provided. It
processes the input `x`, combines it with conditioning embeddings, and passes it
through a series of dense layers. This module is mainly used for testing
purposes.

### `RiemannianConditionalBackbone`

(`lib/architecture/riemannian.py`)

A specialized wrapper for any `ConditionalBackbone` that handles Riemannian
manifold constraints. Its primary role is to ensure that the model's output
`velocity` is a valid **tangent vector** at the point `xt`.

This is achieved by applying the manifold's **`project`** operator to the raw
output of the underlying backbone.

#### Riemannian Projections

Each manifold defines a `project(x, v)` method that ensures the output $$v$$ is a
valid tangent vector at point $$x$$.

* **Sphere ($$S^d$$)**: The projection is $$v_{\text{tangent}} = v - \langle x, v \rangle x$$, which removes the component of $$v$$ parallel to $$x$$.
* **SO(3)**: The projection maps a $$3 \times 3$$ matrix $$V$$ to the tangent space
$$T_R SO(3)$$ by computing the skew-symmetric part of the relative velocity in
the Lie algebra: $$R \cdot \text{skew}(R^T V)$$, where
$$\text{skew}(\Omega) = 0.5(\Omega - \Omega^T)$$.
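As an illustrative sketch (function names are mine, not the library's API), the two projection formulas above can be written directly in numpy:

```python
import numpy as np

def sphere_project(x, v):
    # Remove the component of v parallel to x: v - <x, v> x
    return v - np.dot(x, v) * x

def so3_project(R, V):
    # R * skew(R^T V), with skew(O) = 0.5 (O - O^T)
    O = R.T @ V
    return R @ (0.5 * (O - O.T))

# Sphere: project an arbitrary ambient vector at the north pole of S^2
x = np.array([0.0, 0.0, 1.0])
v = np.array([1.0, 2.0, 3.0])
vt = sphere_project(x, v)   # [1, 2, 0]: the normal component is removed

# SO(3): project an arbitrary 3x3 matrix at R = I
R = np.eye(3)
V = np.arange(9.0).reshape(3, 3)
Vt = so3_project(R, V)      # R^T Vt is skew-symmetric, i.e. Vt lies in T_R SO(3)
```

The tangency conditions are easy to verify: `np.dot(x, vt)` is zero, and `R.T @ Vt` is skew-symmetric.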

By wrapping a standard neural network (e.g., a UNet) in this backbone, we can
learn complex velocity fields on manifolds using standard architectures.

#### Example Usage

```python
from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone
from hackable_diffusion.lib.architecture.mlp import ConditionalMLP

# 1. Choose a manifold
manifold = manifolds.Sphere()

# 2. Create a standard backbone
mlp = ConditionalMLP(num_features=256, num_layers=3)

# 3. Wrap it in a RiemannianConditionalBackbone
model = RiemannianConditionalBackbone(
backbone=mlp,
manifold=manifold,
)
```

The conditioning mechanism is simpler here, limited to `SUM` or `CONCATENATE` of
the conditioning embeddings with the intermediate representation of `x`.

## Attention

### `MultiHeadAttention`

(`lib/architecture/attention.py`)
130 changes: 125 additions & 5 deletions docs/corruption.md

The main components are:

* **`CorruptionProcess` Protocol**: An interface that standardizes how
  corruption is applied.
* **Schedules**: Functions that define the rate and nature of corruption over
  time `t`.
* **Process Implementations**: Concrete classes like `GaussianProcess` for
  continuous data (e.g., images), `CategoricalProcess` for discrete data
  (e.g., labels, tokens), and `RiemannianProcess` for data on Riemannian
  manifolds.

## `CorruptionProcess` Protocol

* The model prediction for discrete data is expected to be logits over the
categories. `convert_predictions` will then convert these logits to a
predicted `x0` (via argmax).

## `RiemannianProcess`

(`lib/corruption/riemannian.py`)

This process implements **Riemannian Flow Matching (RFM)**, a generalization of
Flow Matching to smooth Riemannian manifolds. Unlike standard diffusion, which
relies on Gaussian noise, RFM uses the manifold's intrinsic geometry to
interpolate between data and noise distributions.

### Mathematical Foundations: Continuous-time Flow Matching

Let $$(\mathcal{M}, g)$$ be a $$d$$-dimensional smooth Riemannian manifold. A
probability path $$p_t$$ on $$\mathcal{M}$$ can be defined via the **Continuity
Equation**:

$$\frac{\partial p_t}{\partial t} + \operatorname{div}_g (p_t v_t) = 0$$

where $$\operatorname{div}_g$$ is the Riemannian divergence operator and $$v_t \in T_x \mathcal{M}$$ is a time-dependent vector field. Riemannian Flow Matching aims to find a vector field $$v_\theta(x, t)$$ that generates a path $$p_t$$ such that $$p_0$$ is the data distribution and $$p_1$$ is an invariant noise distribution.

### Riemannian Concepts: Exp, Log, and Geodesics

The geometry of the manifold is abstracted through three key operations implemented in `lib/manifolds.py`:

#### 1. Exponential Mapping ($$\text{Exp}_x$$)

The exponential map $$\text{Exp}_x : T_x \mathcal{M} \to \mathcal{M}$$ provides a way to "map" a tangent vector $$v \in T_x \mathcal{M}$$ back onto the manifold. Intuitively, if you start at point $$x$$ and walk in the direction of $$v$$ for unit time along the unique "straightest" path (geodesic), you arrive at $$\text{Exp}_x(v)$$.

In the library, this is used during **sampling** (to move from $$x_t$$ to $$x_{t-dt}$$) and to construct geodesics.

#### 2. Logarithm Mapping ($$\text{Log}_x$$)

The logarithm map $$\text{Log}_x : \mathcal{M} \to T_x \mathcal{M}$$ is the inverse of the exponential map (where defined). Given two points $$x, y \in \mathcal{M}$$, $$\text{Log}_x(y)$$ returns the tangent vector at $$x$$ that points toward $$y$$ along the shortest geodesic. The length of this vector equals the Riemannian distance between the two points: $$\|\text{Log}_x(y)\|_g = d_g(x, y)$$.

In the library, this is used during **training** to find the direction of the conditional flow between noise and data.

#### 3. Geodesic Mapping ($$\gamma$$)

A geodesic is the generalization of a straight line to curved spaces. The unique geodesic path starting at $$x$$ and ending at $$y$$ can be parameterized by $$t \in [0, 1]$$ as:

$$\gamma(t) = \text{Exp}_x(t \cdot \text{Log}_x(y))$$

This mapping ensures that the interpolation between distributions stays on the manifold and follows the shortest possible paths, which is the cornerstone of Riemannian Flow Matching.
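To make the recipe $$\gamma(t) = \text{Exp}_x(t \cdot \text{Log}_x(y))$$ concrete, here is a minimal sketch on a toy flat "manifold" where $$\text{Exp}_x(v) = x + v$$ and $$\text{Log}_x(y) = y - x$$ (the function names mirror the protocol described here; the concrete API in `lib/manifolds.py` may differ). In flat space the formula reduces, as expected, to linear interpolation:

```python
import numpy as np

# Toy flat "manifold": geodesics are straight lines.
def exp(x, v):
    return x + v

def log(x, y):
    return y - x

def geodesic(x, y, t):
    # gamma(t) = Exp_x(t * Log_x(y))
    return exp(x, t * log(x, y))

a, b = np.array([0.0, 0.0]), np.array([2.0, 4.0])
mid = geodesic(a, b, 0.5)   # [1.0, 2.0]: the Euclidean midpoint
```

On a curved manifold, only `exp` and `log` change; the geodesic formula stays the same.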

### The Riemannian Flow Matching loss

$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], x_0 \sim p_0, x_1 \sim p_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$

where the conditional velocity field $$u_t(x|x_0, x_1)$$ is derived from a
conditional probability path $$p_t(x|x_0, x_1)$$ that satisfies the continuity
equation. In this library, we use **geodesic paths** for the conditional
interpolation:

1. **Conditional Path**: $$x_t = \text{Exp}_{x_1}(\alpha(t) \text{Log}_{x_1}(x_0))$$
2. **Conditional Velocity**: $$u_t(x_t | x_0, x_1) = \dot{\alpha}(t) \cdot \frac{d}{ds} \text{Exp}_{x_1}(s \text{Log}_{x_1}(x_0)) \big|_{s=\alpha(t)}$$

For the standard `LinearRiemannianSchedule`, $$\alpha(t) = 1 - t$$: at $$t=0$$
we have $$\alpha=1$$ and $$x_{t=0} = \text{Exp}_{x_1}(\text{Log}_{x_1}(x_0)) = x_0$$
(clean data), while at $$t=1$$ we have $$\alpha=0$$ and $$x_{t=1} = x_1$$ (noise).
The path therefore flows from data to noise as $$t$$ increases, matching the
library's convention that $$t=0$$ is clean data and $$t=1$$ is noise.

### Supported Manifolds (`lib/manifolds.py`)

Each manifold implements the `Manifold` protocol, providing core geometric
operations with an emphasis on numerical stability.

#### 1. Unit Hypersphere ($$S^d$$)

Points $$x \in \mathbb{R}^{d+1}$$ such that $$\|x\|_2 = 1$$. The tangent space
$$T_x S^d$$ is the subspace $$\{v \in \mathbb{R}^{d+1} \mid \langle x, v \rangle = 0\}$$.

* **Exp**: $$\text{Exp}_x(v) = \cos(\|v\|)x + \text{sinc}(\|v\|)v$$
* **Log**: $$\text{Log}_x(y) = \frac{\theta}{\sin \theta}(y - \cos \theta x)$$, where $$\theta = \arccos(\langle x, y \rangle)$$
* **Velocity**: The time-derivative along the geodesic:
$$u_t = -\theta \sin(\theta t)x_1 + \cos(\theta t) \text{Log}_{x_1}(x_0)$$

The implementation uses an **unnormalized sinc trick** ($$\text{sinc}(x) = \frac{\sin x}{x}$$) to handle the singularity at $$\theta=0$$ gracefully.
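A minimal numpy sketch of the sphere's Exp and Log maps (helper names are mine; `np.sinc` is the *normalized* sinc, so it is called with the argument divided by $$\pi$$ to recover $$\sin x / x$$):

```python
import numpy as np

def sphere_exp(x, v):
    # Exp_x(v) = cos(||v||) x + sinc(||v||) v, with unnormalized sinc
    n = np.linalg.norm(v)
    return np.cos(n) * x + np.sinc(n / np.pi) * v

def sphere_log(x, y):
    # Log_x(y) = theta / sin(theta) * (y - cos(theta) x)
    c = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(c)
    if theta < 1e-8:
        return y - c * x  # small-angle limit: theta/sin(theta) -> 1
    return theta / np.sin(theta) * (y - c * x)

x = np.array([1.0, 0.0, 0.0])
y = np.array([0.0, 1.0, 0.0])
v = sphere_log(x, y)   # tangent at x pointing to y; ||v|| = pi/2 = d_g(x, y)
z = sphere_exp(x, v)   # walking along the geodesic for unit time recovers y
```

Note that `v` is orthogonal to `x`, as a tangent vector at `x` must be, and its norm equals the great-circle distance between the two points.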

#### 2. Special Orthogonal Group ($$SO(3)$$)

Points $$R$$ are $$3 \times 3$$ rotation matrices. The tangent space $$T_R SO(3)$$ is
isomorphic to the Lie Algebra $$\mathfrak{so}(3)$$ of skew-symmetric matrices
via $$R \cdot \omega^\wedge$$.

* **Exp**: Computed via **Rodrigues' Rotation Formula**:
$$\text{Exp}_R(v) = R (I + \text{sinc}(\theta)\omega^\wedge + \text{cosc}(\theta)(\omega^\wedge)^2)$$, where $$\theta = \|\omega\|$$.
* **Log**: Maps $$R_1^T R_0$$ to its rotation axis and angle $$\theta$$.
* **Velocity**: $$u_t = x_t \cdot \text{Log}(x_1^T x_0)$$.

The library uses a safe **cosc trick** ($$\text{cosc}(x) = \frac{1 - \cos x}{x^2} = \frac{1}{2} \text{sinc}(\frac{x}{2})^2$$) to ensure numerical stability in the Rodrigues formula.
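The Rodrigues formula with both tricks can be sketched as follows. This is an illustrative implementation, not the library's: here the tangent input is given directly as an axis-angle vector $$\omega$$ (the library's `Exp` may instead take the tangent matrix $$R\,\omega^\wedge$$):

```python
import numpy as np

def hat(w):
    # Maps omega in R^3 to the skew-symmetric matrix omega^wedge
    return np.array([[0.0, -w[2], w[1]],
                     [w[2], 0.0, -w[0]],
                     [-w[1], w[0], 0.0]])

def sinc(x):
    return np.sinc(x / np.pi)  # sin(x)/x, finite at x = 0

def cosc(x):
    # (1 - cos x)/x^2 = 0.5 * sinc(x/2)^2, finite at x = 0
    return 0.5 * sinc(x / 2.0) ** 2

def so3_exp(R, omega):
    # Rodrigues: R (I + sinc(t) W + cosc(t) W^2), W = omega^wedge, t = ||omega||
    t = np.linalg.norm(omega)
    W = hat(omega)
    return R @ (np.eye(3) + sinc(t) * W + cosc(t) * (W @ W))

R = np.eye(3)
omega = np.array([0.0, 0.0, np.pi / 2])  # 90-degree rotation about z
R2 = so3_exp(R, omega)                    # a valid rotation matrix
```

Because `sinc` and `cosc` are evaluated through their regularized forms, the same code path is exact at $$\theta = 0$$, where the formula degenerates to `R @ (I + W)` truncated terms.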

#### 3. Flat Torus ($$[0, 1]^d$$)

The torus is a flat space with periodic boundary conditions.

* **Metric**: Standard Euclidean metric $$g = I$$.
* **Geodesics**: Straight lines modulo 1.
* **Velocity**: Constant velocity $$u = \text{Log}_{x_1}(x_0) = (x_0 - x_1 + 0.5) \pmod 1 - 0.5$$.
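The wrapped difference and the resulting "straight line modulo 1" geodesic are short enough to sketch directly (helper names are mine):

```python
import numpy as np

def torus_log(x1, x0):
    # Shortest wrapped displacement on [0, 1)^d, landing in [-0.5, 0.5)
    return np.mod(x0 - x1 + 0.5, 1.0) - 0.5

def torus_geodesic(x1, x0, t):
    # Straight line modulo 1, following the shortest wrap
    return np.mod(x1 + t * torus_log(x1, x0), 1.0)

x1 = np.array([0.9])
x0 = np.array([0.1])
u = torus_log(x1, x0)              # 0.2: wraps through the seam at 1.0
xt = torus_geodesic(x1, x0, 0.5)   # 0.0: midpoint across the seam
```

Note that the naive difference `x0 - x1 = -0.8` would take the long way around; the wrapped log picks the shorter path of length 0.2.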

### Example Usage

```python
import jax
import jax.numpy as jnp

from hackable_diffusion.lib import manifolds
from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess
from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule

# 1. Define manifold and process
manifold = manifolds.Sphere()
schedule = LinearRiemannianSchedule()
process = RiemannianProcess(manifold=manifold, schedule=schedule)

# 2. Corrupt data
key, subkey = jax.random.split(jax.random.PRNGKey(0))
x0 = jnp.array([[1.0, 0.0, 0.0]])  # Point on S2
time = jnp.array([0.5])
xt, target_info = process.corrupt(subkey, x0, time)

# target_info['velocity'] is the regression target u_t
```
17 changes: 11 additions & 6 deletions docs/index.md

This module defines the **forward process** of diffusion. It includes
implementations for corrupting data with noise, such as `GaussianProcess` for
continuous data, `CategoricalProcess` for discrete data, and `RiemannianProcess`
for data on Riemannian manifolds (e.g., Sphere, SO(3), Torus). It also defines
the noise `schedules` that govern the corruption over time.

### [Inference Function](./inference.md)

The `notebooks/` directory contains a set of tutorials that demonstrate how to
use the library to train and sample from diffusion models. These serve as
excellent starting points for understanding the library's components in action.

* **`2d_training.ipynb`**: A minimal example that trains a diffusion model on
  a simple 2D toy dataset.
* **`mnist.ipynb`**: Trains a standard continuous diffusion model (Gaussian
  process) on the MNIST dataset, demonstrating image data handling.
* **`mnist_discrete.ipynb`**: Trains a discrete diffusion model on MNIST,
  treating pixel values as categorical data. This showcases the use of
  `CategoricalProcess`.
* **`mnist_multimodal.ipynb`**: A more advanced example that trains a
  multimodal model to jointly generate MNIST images with discrete and
  continuous diffusion models, demonstrating the "Nested" design pattern in a
  practical setting.
* **`riemannian_sphere_training.ipynb`**: Demonstrates Riemannian Flow
Matching on the unit sphere S^2.
* **`riemannian_torus_ode_to_sde.ipynb`**: Shows how to use Riemannian Flow
Matching on the torus manifold for both ODE and SDE sampling.
36 changes: 36 additions & 0 deletions docs/loss.md

This is a concrete implementation that computes discrete diffusion loss without
any weighting (i.e. weight=1).

---

## Riemannian Flow Matching Loss

Training a Riemannian Flow Matching (RFM) model requires a loss function that
respects the intrinsic geometry of the manifold $$(\mathcal{M}, g)$$.

### Metric-Aware Loss

The **Riemannian Flow Matching loss** is defined as the squared norm of the
difference between the model's velocity prediction $$v_\theta$$ and the true
geodesic velocity $$u_t$$:

$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$

where the norm is induced by the Riemannian metric $$g$$ at point $$x_t$$:

$$\| v \|_{g} = \sqrt{g_{x_t}(v, v)}$$

### Implementation for Embedded Manifolds

For many manifolds implemented in this library (like the Sphere $$S^d$$ or $$SO(3)$$), the Riemannian metric is induced by the standard Euclidean metric of the ambient space $$\mathbb{R}^n$$. In these cases, the loss simplifies to:

$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{2}^2 ]$$

**Crucially**, this equivalence only holds if $$v_{\theta}$$ and $$u_t$$ are both valid **tangent vectors** (i.e., $$v, u \in T_{x_t} \mathcal{M}$$). The library ensures this via:

1. **True Target**: The `RiemannianProcess` returns a $$u_t$$ that is
mathematically guaranteed to be tangent to the manifold.
2. **Model Prediction**: The **`RiemannianConditionalBackbone`** (see
[Architecture docs](./architecture.md)) acts as a wrapper that projects the
raw model output onto the tangent space $$T_{x_t} \mathcal{M}$$ before
computing the loss.

By enforcing the tangent space constraint, the RFM objective can be optimized
using standard MSE loss while remaining geometrically rigorous.
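The two-step recipe above can be sketched on the sphere: project the raw network output onto the tangent space, then take a plain squared error against the tangent target. This is a hedged illustration with hand-rolled helpers, not the library's loss implementation:

```python
import numpy as np

def sphere_project(x, v):
    # Tangent projection at x: v - <x, v> x
    return v - np.dot(x, v) * x

def rfm_loss(xt, raw_prediction, u_target):
    # Project the raw output onto T_xt M, then plain squared error.
    # Valid because the sphere's metric is the ambient Euclidean one.
    v_theta = sphere_project(xt, raw_prediction)
    return np.sum((v_theta - u_target) ** 2)

xt = np.array([0.0, 0.0, 1.0])   # point on S^2
u = np.array([1.0, 0.0, 0.0])    # a tangent target at xt
raw = np.array([1.0, 0.0, 5.0])  # raw output with a spurious normal component
loss = rfm_loss(xt, raw, u)      # 0.0: the projection discards the 5.0
```

The normal component of the raw prediction never enters the loss, which is exactly why the Euclidean MSE agrees with the metric-aware loss once both arguments are tangent vectors.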