diff --git a/docs/architecture.md b/docs/architecture.md index ef254a6..bbdc6a1 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -140,14 +140,57 @@ print(f"Output shape: {output.shape}") For simpler, non-image data, a `ConditionalMLP` backbone is provided. It processes the input `x`, combines it with conditioning embeddings, and passes it -through a series of dense layers. This module is mainly use for testing +through a series of dense layers. This module is mainly used for testing purposes. +### `RiemannianConditionalBackbone` + +(`lib/architecture/riemannian.py`) + +A specialized wrapper for any `ConditionalBackbone` that handles Riemannian +manifold constraints. Its primary role is to ensure that the model's output +`velocity` is a valid **tangent vector** at the point `xt`. + +This is achieved by applying the manifold's **`project`** operator to the raw +output of the underlying backbone. + +#### Riemannian Projections + +Each manifold defines a `project(x, v)` method that ensures the output $$v$$ is a +valid tangent vector at point $$x$$. + +* **Sphere ($$S^d$$)**: The projection is $$v_{\text{tangent}} = v - \langle x, v \rangle x$$, which removes the component of $$v$$ parallel to $$x$$. +* **SO(3)**: The projection maps a $$3 \times 3$$ matrix $$V$$ to the tangent space + $$T_R SO(3)$$ by computing the skew-symmetric part of the relative velocity in + the Lie algebra: $$R \cdot \text{skew}(R^T V)$$, where + $$\text{skew}(\Omega) = 0.5(\Omega - \Omega^T)$$. + +By wrapping a standard neural network (e.g., a UNet) in this backbone, we can +learn complex velocity fields on manifolds using standard architectures. + +#### Example Usage + +```python +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib.architecture.riemannian import RiemannianConditionalBackbone +from hackable_diffusion.lib.architecture.mlp import ConditionalMLP + +# 1. Choose a manifold +manifold = manifolds.Sphere() + +# 2. 
Create a standard backbone +mlp = ConditionalMLP(num_features=256, num_layers=3) + +# 3. Wrap it in a RiemannianConditionalBackbone +model = RiemannianConditionalBackbone( + backbone=mlp, + manifold=manifold, +) +``` + The conditioning mechanism is simpler here, limited to `SUM` or `CONCATENATE` of the conditioning embeddings with the intermediate representation of `x`. -## Attention - ### `MultiHeadAttention` (`lib/architecture/attention.py`) diff --git a/docs/corruption.md b/docs/corruption.md index 73c0d9c..86ebbf1 100644 --- a/docs/corruption.md +++ b/docs/corruption.md @@ -16,13 +16,14 @@ various corruption processes for both continuous and discrete data. The main components are: - * **`CorruptionProcess` Protocol**: An interface that standardizes how +* **`CorruptionProcess` Protocol**: An interface that standardizes how corruption is applied. - * **Schedules**: Functions that define the rate and nature of corruption over +* **Schedules**: Functions that define the rate and nature of corruption over time `t`. - * **Process Implementations**: Concrete classes like `GaussianProcess` for - continuous data (e.g., images) and `CategoricalProcess` for discrete data - (e.g., labels, tokens). +* **Process Implementations**: Concrete classes like `GaussianProcess` for + continuous data (e.g., images), `CategoricalProcess` for discrete data + (e.g., labels, tokens), and `RiemannianProcess` for data on Riemannian + manifolds. ## `CorruptionProcess` Protocol @@ -230,3 +231,122 @@ print(f"Logits target shape: {target_info['logits'].shape}") * The model prediction for discrete data is expected to be logits over the categories. `convert_predictions` will then convert these logits to a predicted `x0` (via argmax). + +## `RiemannianProcess` + +(`lib/corruption/riemannian.py`) + +This process implements **Riemannian Flow Matching (RFM)**, a generalization of +Flow Matching to smooth Riemannian manifolds. 
Unlike standard diffusion, which +relies on Gaussian noise, RFM uses the manifold's intrinsic geometry to +interpolate between data and noise distributions. + +### Mathematical Foundations: Continuous-time Flow Matching + +Let $$(\mathcal{M}, g)$$ be a $$d$$-dimensional smooth Riemannian manifold. A +probability path $$p_t$$ on $$\mathcal{M}$$ can be defined via the **Continuity +Equation**: + +$$\frac{\partial p_t}{\partial t} + \operatorname{div}_g (p_t v_t) = 0$$ + +where $$\operatorname{div}_g$$ is the Riemannian divergence operator and $$v_t \in T_x \mathcal{M}$$ is a time-dependent vector field. Riemannian Flow Matching aims to find a vector field $$v_\theta(x, t)$$ that generates a path $$p_t$$ such that $$p_0$$ is the data distribution and $$p_1$$ is an invariant noise distribution. + +### Riemannian Concepts: Exp, Log, and Geodesics + +The geometry of the manifold is abstracted through three key operations implemented in `lib/manifolds.py`: + +#### 1. Exponential Mapping ($$\text{Exp}_x$$) + +The exponential map $$\text{Exp}_x : T_x \mathcal{M} \to \mathcal{M}$$ provides a way to "map" a tangent vector $$v \in T_x \mathcal{M}$$ back onto the manifold. Intuitively, if you start at point $$x$$ and walk in the direction of $$v$$ for unit time along the unique "straightest" path (geodesic), you arrive at $$\text{Exp}_x(v)$$. + +In the library, this is used during **sampling** (to move from $$x_t$$ to $$x_{t-dt}$$) and to construct geodesics. + +#### 2. Logarithm Mapping ($$\text{Log}_x$$) + +The logarithm map $$\text{Log}_x : \mathcal{M} \to T_x \mathcal{M}$$ is the inverse of the exponential map (where defined). Given two points $$x, y \in \mathcal{M}$$, $$\text{Log}_x(y)$$ returns the tangent vector at $$x$$ that points toward $$y$$ along the shortest geodesic. The length of this vector equals the Riemannian distance between the two points: $$\|\text{Log}_x(y)\|_g = d_g(x, y)$$. 
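To make the Exp/Log pair concrete, here is a minimal, self-contained NumPy sketch of both maps on the unit sphere. This is illustrative only; the library's `Manifold` implementations in `lib/manifolds.py` may differ in details such as batching and the stability tricks described below.

```python
import numpy as np

def sphere_exp(x, v):
    """Exponential map on S^d: walk from x along tangent vector v for unit time."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return x
    return np.cos(norm_v) * x + (np.sin(norm_v) / norm_v) * v

def sphere_log(x, y):
    """Logarithm map on S^d: tangent vector at x pointing toward y."""
    cos_theta = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-12:
        return np.zeros_like(x)
    return (theta / np.sin(theta)) * (y - cos_theta * x)

x = np.array([1.0, 0.0, 0.0])
y = np.array([0.0, 1.0, 0.0])
v = sphere_log(x, y)

print(np.linalg.norm(v))    # geodesic distance d_g(x, y) = pi/2
print(np.dot(x, v))         # 0: v lies in the tangent space at x
print(sphere_exp(x, v))     # recovers y: Exp_x(Log_x(y)) = y
```

Note that the round trip `Exp_x(Log_x(y)) = y` is exactly the inverse relationship described above.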
+ +In the library, this is used during **training** to find the direction of the conditional flow between noise and data. + +#### 3. Geodesic Mapping ($$\gamma$$) + +A geodesic is the generalization of a straight line to curved spaces. The unique geodesic path starting at $$x$$ and ending at $$y$$ can be parameterized by $$t \in [0, 1]$$ as: + +$$\gamma(t) = \text{Exp}_x(t \cdot \text{Log}_x(y))$$ + +This mapping ensures that the interpolation between distributions stays on the manifold and follows the shortest possible paths, which is the cornerstone of Riemannian Flow Matching. + +### The Riemannian Flow Matching loss + +$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], x_0 \sim p_0, x_1 \sim p_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$ + +where the conditional velocity field $$u_t(x|x_0, x_1)$$ is derived from a +conditional probability path $$p_t(x|x_0, x_1)$$ that satisfies the continuity +equation. In this library, we use **geodesic paths** for the conditional +interpolation: + +1. **Conditional Path**: $$x_t = \text{Exp}_{x_1}(\alpha(t) \text{Log}_{x_1}(x_0))$$ +2. **Conditional Velocity**: $$u_t(x_t | x_0, x_1) = \dot{\alpha}(t) \cdot \frac{d}{ds} \text{Exp}_{x_1}(s \text{Log}_{x_1}(x_0)) \big|_{s=\alpha(t)}$$ + +For the standard `LinearRiemannianSchedule`, $$\alpha(t) = 1 - t$$, meaning the +path flows from data ($$t=0, \alpha=1, x_{t=0}=x_0$$) to noise ($$t=1, \alpha=0, +x_{t=1}=x_1$$); that is, $$t=0$$ is clean data and $$t=1$$ is noise, matching the +sampling convention used elsewhere in the library. + +### Supported Manifolds (`lib/manifolds.py`) + +Each manifold implements the `Manifold` protocol, providing core geometric +operations with an emphasis on numerical stability. + +#### 1. Unit Hypersphere ($$S^d$$) + +Points $$x \in \mathbb{R}^{d+1}$$ such that $$\|x\|_2 = 1$$. 
The tangent space +$$T_x S^d$$ is the subspace $$\{v \in \mathbb{R}^{d+1} \mid \langle x, v \rangle = 0\}$$. + +* **Exp**: $$\text{Exp}_x(v) = \cos(\|v\|)x + \text{sinc}(\|v\|)v$$ +* **Log**: $$\text{Log}_x(y) = \frac{\theta}{\sin \theta}(y - \cos \theta x)$$, where $$\theta = \arccos(\langle x, y \rangle)$$ +* **Velocity**: The time-derivative along the geodesic: + $$u_t = -\theta \sin(\theta t)x_1 + \cos(\theta t) \text{Log}_{x_1}(x_0)$$ + +The implementation uses an **unnormalized sinc trick** ($$\text{sinc}(x) = \frac{\sin x}{x}$$) to handle the singularity at $$\theta=0$$ gracefully. + +#### 2. Special Orthogonal Group ($$SO(3)$$) + +Points $$R$$ are $$3 \times 3$$ rotation matrices. The tangent space $$T_R SO(3)$$ is +isomorphic to the Lie algebra $$\mathfrak{so}(3)$$ of skew-symmetric matrices +via $$R \cdot \omega^\wedge$$. + +* **Exp**: Computed via **Rodrigues' Rotation Formula**: + $$\text{Exp}_R(v) = R (I + \text{sinc}(\theta)\omega^\wedge + \text{cosc}(\theta)(\omega^\wedge)^2)$$, where $$\theta = \|\omega\|$$. +* **Log**: Maps $$R_1^T R_0$$ to its rotation axis and angle $$\theta$$. +* **Velocity**: $$u_t = x_t \cdot \text{Log}(x_1^T x_0)$$. + +The library uses a safe **cosc trick** ($$\text{cosc}(x) = \frac{1 - \cos x}{x^2} = \frac{1}{2} \text{sinc}(\frac{x}{2})^2$$) to ensure numerical stability in the Rodrigues formula. + +#### 3. Flat Torus ($$[0, 1]^d$$) + +The torus is a flat space with periodic boundary conditions. + +* **Metric**: Standard Euclidean metric $$g = I$$. +* **Geodesics**: Straight lines modulo 1. +* **Velocity**: Constant velocity $$u = \text{Log}_{x_1}(x_0) = (x_0 - x_1 + 0.5) \pmod 1 - 0.5$$. + +### Example Usage + +```python +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess +from hackable_diffusion.lib.corruption.schedules import LinearRiemannianSchedule + +# 1. 
Define manifold and process +manifold = manifolds.Sphere() +schedule = LinearRiemannianSchedule() +process = RiemannianProcess(manifold=manifold, schedule=schedule) + +# 2. Corrupt data +x0 = jnp.array([[1.0, 0.0, 0.0]]) # Point on S2 +time = jnp.array([0.5]) +xt, target_info = process.corrupt(subkey, x0, time) + +# target_info['velocity'] is the regression target u_t +``` diff --git a/docs/index.md b/docs/index.md index 11894f1..8449ebc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -66,8 +66,9 @@ system for encoding and injecting conditioning signals via This module defines the **forward process** of diffusion. It includes implementations for corrupting data with noise, such as `GaussianProcess` for -continuous data and `CategoricalProcess` for discrete data. It also defines the -noise `schedules` that govern the corruption over time. +continuous data, `CategoricalProcess` for discrete data, and `RiemannianProcess` +for data on Riemannian manifolds (e.g., Sphere, SO(3), Torus). It also defines +the noise `schedules` that govern the corruption over time. ### [Inference Function](./inference.md) @@ -101,14 +102,18 @@ The `notebooks/` directory contains a set of tutorials that demonstrate how to use the library to train and sample from diffusion models. These serve as excellent starting points for understanding the library's components in action. - * **`2d_training.ipynb`**: A minimal example that trains a diffusion model on +* **`2d_training.ipynb`**: A minimal example that trains a diffusion model on a simple 2D toy dataset. - * **`mnist.ipynb`**: Trains a standard continuous diffusion model (Gaussian +* **`mnist.ipynb`**: Trains a standard continuous diffusion model (Gaussian process) on the MNIST dataset, demonstrating image data handling. - * **`mnist_discrete.ipynb`**: Trains a discrete diffusion model on MNIST, +* **`mnist_discrete.ipynb`**: Trains a discrete diffusion model on MNIST, treating pixel values as categorical data. 
This showcases the use of `CategoricalProcess`. - * **`mnist_multimodal.ipynb`**: A more advanced example that trains a +* **`mnist_multimodal.ipynb`**: A more advanced example that trains a multimodal model to jointly generate MNIST images with discrete and continuous diffusion models, demonstrating the "Nested" design pattern in a practical setting. +* **`riemannian_sphere_training.ipynb`**: Demonstrates Riemannian Flow + Matching on the unit sphere S^2. +* **`riemannian_torus_ode_to_sde.ipynb`**: Shows how to use Riemannian Flow + Matching on the torus manifold for both ODE and SDE sampling. diff --git a/docs/loss.md b/docs/loss.md index 8c35f55..13db094 100644 --- a/docs/loss.md +++ b/docs/loss.md @@ -158,3 +158,39 @@ loss requires a `DiscreteSchedule`. This is a concrete implementation that computes discrete diffusion loss without any weighting (i.e. weight=1). +--- + +## Riemannian Flow Matching Loss + +Training a Riemannian Flow Matching (RFM) model requires a loss function that +respects the intrinsic geometry of the manifold $$(\mathcal{M}, g)$$. + +### Metric-Aware Loss + +The **Riemannian Flow Matching loss** is defined as the squared norm of the +difference between the model's velocity prediction $$v_\theta$$ and the true +geodesic velocity $$u_t$$: + +$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{g}^2 ]$$ + +where the norm is induced by the Riemannian metric $$g$$ at point $$x_t$$: + +$$\| v \|_{g} = \sqrt{g_{x_t}(v, v)}$$ + +### Implementation for Embedded Manifolds + +For many manifolds implemented in this library (like the Sphere $$S^d$$ or $$SO(3)$$), the Riemannian metric is induced by the standard Euclidean metric of the ambient space $$\mathbb{R}^n$$. 
In these cases, the loss simplifies to: + +$$\mathcal{L}(\theta) = \mathbb{E}_{t, x_0, x_1} [ \| v_{\theta}(x_t, t) - u_t(x_t | x_0, x_1) \|_{2}^2 ]$$ + +**Crucially**, this equivalence only holds if $$v_{\theta}$$ and $$u_t$$ are both valid **tangent vectors** (i.e., $$v, u \in T_{x_t} \mathcal{M}$$). The library ensures this via: + +1. **True Target**: The `RiemannianProcess` returns a $$u_t$$ that is + mathematically guaranteed to be tangent to the manifold. +2. **Model Prediction**: The **`RiemannianConditionalBackbone`** (see + [Architecture docs](./architecture.md)) acts as a wrapper that projects the + raw model output onto the tangent space $$T_{x_t} \mathcal{M}$$ before + computing the loss. + +By enforcing the tangent space constraint, the RFM objective can be optimized +using standard MSE loss while remaining geometrically rigorous. diff --git a/docs/sampling.md b/docs/sampling.md index e1004f3..1d7d250 100644 --- a/docs/sampling.md +++ b/docs/sampling.md @@ -21,8 +21,8 @@ key components: 1. **`TimeSchedule`**: Defines the sequence of discrete time steps `{t_N, t_{N-1}, ..., t_0}` for the denoising process. 2. **`InferenceFn`**: The function that calls the trained model to make a - denoising prediction at a single time step (see the [Inference - Function](./inference.md) documentation). + denoising prediction at a single time step (see the + [Inference Function](./inference.md) documentation). 3. **`SamplerStep`**: An implementation of a specific sampling algorithm (e.g., DDIM, SDE) that uses the model's prediction to compute the state at the next time step. @@ -43,10 +43,10 @@ The overall flow for `N` steps is: Two main data structures manage the state of the sampling loop: - * **`StepInfo`**: A static container for all information related to a single +* **`StepInfo`**: A static container for all information related to a single step that can be pre-computed. This includes the step index, the continuous time `t`, and a JAX random key for that step. 
- * **`DiffusionStep`**: The complete, dynamic state of the process at a given +* **`DiffusionStep`**: The complete, dynamic state of the process at a given step. It contains the noisy data `xt` and the `StepInfo` for that step. This is the "state" that is carried over from one iteration of the sampling loop to the next. @@ -58,8 +58,8 @@ Two main data structures manage the state of the sampling loop: The `TimeSchedule` protocol is responsible for discretizing the `[0, 1]` time interval. - * `UniformTimeSchedule`: Creates linearly spaced time steps. - * `EDMTimeSchedule`: Implements the non-uniform time step distribution from +* `UniformTimeSchedule`: Creates linearly spaced time steps. +* `EDMTimeSchedule`: Implements the non-uniform time step distribution from the EDM paper, which can improve sample quality. The `rho` parameter controls the density of steps near `t=0`. @@ -80,14 +80,53 @@ computes the next `DiffusionStep`. Implementations for **Gaussian** processes include: - * **`DDIMStep`**: Implements the popular Denoising Diffusion Implicit Models +* **`DDIMStep`**: Implements the popular Denoising Diffusion Implicit Models sampler. It can be deterministic (`stoch_coeff=0.0`) or stochastic (`stoch_coeff > 0.0`). - * **`SdeStep`**: A stochastic sampler based on discretizing the reverse-time +* **`SdeStep`**: A stochastic sampler based on discretizing the reverse-time Stochastic Differential Equation (SDE). - * **`VelocityStep`**: A sampler that operates using the velocity prediction +* **`VelocityStep`**: A sampler that operates using the velocity prediction from the model. - * **`HeunStep`**: A more accurate second-order solver. +* **`HeunStep`**: A more accurate second-order solver. 
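All of these steppers share the same contract: consume the model's prediction at time `t` and produce the state at the next time step. As a rough, hypothetical sketch (plain NumPy, not the library's actual `SamplerStep` API), a velocity-based Euler loop looks like this:

```python
import numpy as np

def euler_velocity_sampler(velocity_fn, x_init, num_steps):
    """Integrate dx/dt = v(x, t) from t=1 (noise) down to t=0 (data)."""
    dt = 1.0 / num_steps
    x = x_init
    for i in range(num_steps):
        t = 1.0 - i * dt
        # One Euler step toward t=0 using the predicted velocity.
        x = x - dt * velocity_fn(x, t)
    return x

# Toy velocity field: the flow contracts every point toward the origin.
toy_velocity = lambda x, t: x
x0 = euler_velocity_sampler(toy_velocity, np.array([1.0, -2.0]), num_steps=1000)
print(x0)  # close to exp(-1) * [1, -2]
```

Higher-order steppers such as `HeunStep` refine this loop by re-evaluating the velocity at the predicted endpoint, at the cost of extra model evaluations per step.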
+ +### Riemannian Sampling Theory + +Generating samples from a Riemannian Flow Matching model involves solving a +time-dependent Ordinary Differential Equation (ODE) on the manifold +$$\mathcal{M}$$: + +$$\frac{dx_t}{dt} = v_{\theta}(x_t, t), \quad x_1 \sim \text{Invariant}(\mathcal{M})$$ + +where $$v_{\theta}$$ is the learned velocity field. To solve this ODE while +remaining on the manifold, we use specialized integration schemes that respect +the manifold's intrinsic geometry. + +#### Riemannian Euler Integration + +Implementations for **Riemannian** processes include: + +* **`RiemannianFlowSamplerStep`**: Implements **Riemannian Euler + integration**. Instead of a standard additive update $$(x + dt \cdot v)$$, + this step uses the manifold's **exponential map** to move along the tangent + vector $$v$$ while respecting the manifold's curvature: + + $$x_{t-\Delta t} = \text{Exp}_{x_t}(-\Delta t \cdot v_{\theta}(x_t, t))$$ + + This ensures that the updated state $$x_{t-\Delta t}$$ remains perfectly on + the manifold $$\mathcal{M}$$ (e.g., still has unit norm on a sphere) without + needing ad-hoc projection steps. This is mathematically equivalent to moving + along the unique geodesic starting at $$x_t$$ with initial velocity + $$-v_\theta$$. + +#### Why use Riemannian Euler? + +In contrast, a **Euclidean Euler** step followed by a projection, + +1. $$x' = x_t - \Delta t \cdot v_\theta$$ +2. $$x_{t-\Delta t} = \text{Project}(x')$$ + +can lead to numerical drift and artifacts, especially when the manifold is +highly curved or the step size is large. Riemannian Euler is the "natural" +first-order integrator for manifolds as it directly utilizes the Riemannian +metric's shortest paths. Note that in our implementation we assume that one step corresponds to one NFE. 
While this strong assumption allows you to make the identification `num_steps = @@ -171,6 +210,31 @@ print(f"Shape of generated images: {generated_images.shape}") # Shape of generated images: (4, 32, 32, 3) ``` +### Riemannian Sampling Example + +For manifolds like the Sphere or SO(3), the setup is similar but uses +Riemannian-specific components. + +```python +from hackable_diffusion.lib import manifolds +from hackable_diffusion.lib.corruption.riemannian import RiemannianProcess +from hackable_diffusion.lib.sampling.riemannian_sampling import RiemannianFlowSamplerStep + +# 1. Define manifold and process +manifold = manifolds.Sphere() +process = RiemannianProcess(manifold=manifold) + +# 2. Configure Sampler Step +stepper = RiemannianFlowSamplerStep(corruption_process=process) + +# 3. Create the sampler +sampler = DiffusionSampler( + time_schedule=UniformTimeSchedule(), # or EDM + stepper=stepper, + num_steps=50, +) +``` + This modular setup makes it easy to experiment with different samplers (e.g., swapping `DDIMStep` for `SdeStep`), time schedules, or number of steps with minimal code changes.
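To see the geodesic update from the Riemannian sampling section in isolation, here is a self-contained NumPy sketch of a single Riemannian Euler step on the sphere. It is illustrative only; the actual `RiemannianFlowSamplerStep` delegates this operation to the `Manifold` object.

```python
import numpy as np

def sphere_exp(x, v):
    """Exponential map on the unit sphere."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return x
    return np.cos(norm_v) * x + (np.sin(norm_v) / norm_v) * v

def riemannian_euler_step(x_t, velocity, dt):
    """x_{t - dt} = Exp_{x_t}(-dt * v): follow the geodesic with initial velocity -v."""
    return sphere_exp(x_t, -dt * velocity)

x_t = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 2.0, 0.0])         # tangent at x_t since <x_t, v> = 0
x_next = riemannian_euler_step(x_t, v, dt=0.1)

print(np.linalg.norm(x_next))         # 1: the update never leaves the sphere
print(np.linalg.norm(x_t - 0.1 * v))  # > 1: a plain Euclidean step drifts off the manifold
```

The comparison in the last two lines is exactly the drift argument made in the sampling docs: the exponential-map update preserves the manifold constraint by construction, while the additive update requires a corrective projection.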