################################################################################
# MARK: Discrete Flow Matching Step
################################################################################


@dataclasses.dataclass(frozen=True, kw_only=True)
class DiscreteFlowMatchingStep(SamplerStep):
  """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595.

  Implements the simplest variant of Algorithm 1 in Discrete Flow Matching,
  Gat et. al., 2024, https://arxiv.org/abs/2407.15595: the update rule derived
  from the probability velocity of the probability-path family in (9),

      x_{t-dt} ~ (1 - prob_jump) * delta_{x_t} + prob_jump * prediction,

  with prob_jump = (alpha_s - alpha_t) / (1 - alpha_t). In hackable diffusion
  alpha(t) is the probability of keeping the original value, i.e. 1 - kappa(t)
  in the paper with time reversed.

  Attributes:
    corruption_process: The corruption process to use.
    temperature: The temperature to use.
    gamma: The corrector term (default 0.0). Higher values introduce more noise
      during the denoising process, which can improve sample quality.
  """

  corruption_process: CategoricalProcess
  temperature: float = 1.0
  gamma: float = 0.0

  def __post_init__(self):
    """DiscreteFlowMatchingStep requires a DiscreteSchedule."""
    schedule = self.corruption_process.schedule
    if not isinstance(schedule, DiscreteSchedule):
      raise ValueError('DiscreteFlowMatchingStep requires a DiscreteSchedule.')

  @property
  def unused_mask_value(self) -> int:
    # Sentinel token id for positions excluded from denoising.
    return self.corruption_process.unused_mask_value

  @property
  def post_corruption_fn(self) -> discrete.PostCorruptionFn:
    # Hook applied to the state after every jump update.
    return self.corruption_process.post_corruption_fn

  @typechecked
  def initialize(
      self,
      initial_noise: DataArray,
      initial_step_info: StepInfo,
  ) -> DiffusionStep:
    """Builds the first DiffusionStep; aux logits start at -inf everywhere."""
    # Broadcast the token array to (..., num_categories) and fill with -inf so
    # the aux pytree has a stable shape/dtype before the first model call.
    broadcast = jnp.repeat(
        initial_noise, self.corruption_process.num_categories, axis=-1
    )
    init_logits = jnp.full_like(broadcast, -jnp.inf, dtype=jnp.float32)

    return DiffusionStep(
        xt=initial_noise,
        step_info=initial_step_info,
        aux={'logits': init_logits},
    )

  @typechecked
  def update(
      self,
      prediction: TargetInfo,
      current_step: DiffusionStep,
      next_step_info: StepInfo,
  ) -> DiffusionStep:
    """One denoising step: per token, stay / jump to data / jump to noise."""
    xt = current_step.xt
    is_unused = xt == self.unused_mask_value

    t = utils.bcast_right(current_step.step_info.time, xt.ndim)
    s = utils.bcast_right(next_step_info.time, xt.ndim)

    # Sample a data proposal from p_{0|t}.
    logits = self.corruption_process.convert_predictions(
        prediction,
        xt,
        t,
    )['logits']
    logits = logits / self.temperature

    _, sample_key, noise_key, jump_key = jax.random.split(
        next_step_info.rng, 4
    )
    data_proposal = jax.random.categorical(key=sample_key, logits=logits)[
        ..., None
    ]
    noise_proposal = self.corruption_process.sample_from_invariant(
        noise_key, data_spec=xt
    )

    # Denoising
    schedule = self.corruption_process.schedule
    alpha_s = schedule.alpha(s)
    alpha_t = schedule.alpha(t)

    # prob_up is the probability of switching from the current state to the
    # predicted data state. Following the paper's formula (24):
    #   u_fwd = (dot_kappa / (1 - kappa)) * (p_data - delta_xt)
    # prob_down is the probability of switching back to noise (corrector
    # logic):
    #   u_bwd = (dot_kappa / kappa) * (delta_xt - p_noise)
    # Following the paper's formula (26), the combined velocity is:
    #   u_bar = (1 + gamma) * u_fwd - gamma * u_bwd.
    # Since u_bwd (u^(0) in the paper) involves (delta_xt - p_noise), it has
    # negative jump rates back to noise; subtracting it (-gamma * u_bwd) yields
    # positive jump probabilities in the discretization. We discretize this as
    # a jump process where each token jumps to data with prob_up and to noise
    # with prob_down.
    prob_up = (
        (alpha_s - alpha_t)
        / jnp.maximum(1.0 - alpha_t, 1e-12)
        * (1.0 + self.gamma)
    )
    prob_down = (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.gamma

    # Unclipped, non-negative jump probabilities.
    p_up_raw = jnp.maximum(prob_up, 0.0)
    p_down_raw = jnp.maximum(prob_down, 0.0)

    # If the two jump probabilities sum past 1.0, rescale proportionally so
    # their ratio is preserved.
    rescale = jnp.maximum(1.0, p_up_raw + p_down_raw)
    p_up = p_up_raw / rescale
    p_down = p_down_raw / rescale
    p_stay = 1.0 - p_up - p_down

    # Categorical over the three outcomes; 0: stay, 1: to data, 2: to noise.
    outcome_probs = jnp.stack([p_stay, p_up, p_down], axis=-1)
    outcome_probs = jnp.broadcast_to(outcome_probs, xt.shape + (3,))
    jump_type = jax.random.categorical(
        jump_key, logits=jnp.log(jnp.maximum(outcome_probs, 1e-12))
    )

    new_xt = jnp.where(jump_type == 1, data_proposal, xt)
    new_xt = jnp.where(jump_type == 2, noise_proposal, new_xt)
    new_xt = self.post_corruption_fn(new_xt)

    # Restore the unused tokens to the unused_mask_value.
    new_xt = jnp.where(is_unused, self.unused_mask_value, new_xt)

    return DiffusionStep(
        xt=new_xt,
        step_info=next_step_info,
        aux={'logits': logits},
    )

  @typechecked
  def finalize(
      self,
      prediction: TargetInfo,
      current_step: DiffusionStep,
      last_step_info: StepInfo,
  ) -> DiffusionStep:
    """The final step is an ordinary update step."""
    return self.update(
        prediction,
        current_step,
        last_step_info,
    )
+ logits = jnp.zeros(xt.shape[:-1] + (self.process.num_categories,)) + logits = logits.at[..., 1].set(10.0) + return {'logits': logits} + + def test_initialize(self): + initial_step_info = StepInfo( + step=0, + time=jnp.array([1.0, 1.0])[:, None, None], + rng=jax.random.PRNGKey(0), + ) + initial_step = self.dfm_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=initial_step_info, + ) + init_logits = jnp.repeat( + self.initial_noise, self.process.num_categories, axis=-1 + ) + init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf + + chex.assert_trees_all_equal( + initial_step, + DiffusionStep( + xt=self.initial_noise, + step_info=initial_step_info, + aux={'logits': init_logits}, + ), + ) + + def test_update(self): + initial_step_info = StepInfo( + step=0, + time=jnp.array([0.5, 0.5])[:, None, None], + rng=jax.random.PRNGKey(0), + ) + initial_step = self.dfm_step.initialize( + initial_noise=self.initial_noise, + initial_step_info=initial_step_info, + ) + prediction = self._dummy_inference_fn( + xt=initial_step.xt, + conditioning={}, + time=initial_step.step_info.time, + ) + + # Test case 1: Full unmasking (alpha_s=1.0, alpha_t=0.5 -> prob_jump=1.0) + next_step_info_full = StepInfo( + step=1, + time=jnp.array([0.0, 0.0])[:, None, None], + rng=jax.random.PRNGKey(1), + ) + next_step_full = self.dfm_step.update( + prediction=prediction, + current_step=initial_step, + next_step_info=next_step_info_full, + ) + expected_xt_full = jnp.ones_like(self.initial_noise) + chex.assert_trees_all_equal(next_step_full.xt, expected_xt_full) + + # Test case 2: No jump (alpha_s=0.5, alpha_t=0.5 -> prob_jump=0.0) + next_step_info_no = StepInfo( + step=1, + time=jnp.array([0.5, 0.5])[:, None, None], + rng=jax.random.PRNGKey(1), + ) + next_step_no = self.dfm_step.update( + prediction=prediction, + current_step=initial_step, + next_step_info=next_step_info_no, + ) + chex.assert_trees_all_equal(next_step_no.xt, initial_step.xt) + + def 
test_update_with_gamma(self): + num_samples = 10 + initial_step_info = StepInfo( + step=0, + time=jnp.array([[0.5]] * num_samples)[:, :, None], + rng=jax.random.PRNGKey(0), + ) + # Start with more samples to ensure noise jump is detected. + initial_xt = jnp.ones((num_samples, 4, 1), dtype=jnp.int32) + initial_step = self.dfm_step.initialize( + initial_noise=initial_xt, + initial_step_info=initial_step_info, + ) + + # Predict category 1 (same as current). + prediction = self._dummy_inference_fn( + xt=initial_step.xt, + conditioning={}, + time=initial_step.step_info.time, + ) + + # Use gamma that won't clip. + dfm_step_gamma = discrete_step_sampler.DiscreteFlowMatchingStep( + corruption_process=self.process, gamma=1.0 + ) + + next_step_info = StepInfo( + step=1, + time=jnp.array([[0.4]] * num_samples)[:, :, None], + rng=jax.random.PRNGKey(1), + ) + next_step = dfm_step_gamma.update( + prediction=prediction, + current_step=initial_step, + next_step_info=next_step_info, + ) + + # Some tokens should have changed to noise (not 1). + self.assertTrue(jnp.any(next_step.xt != 1)) + + if __name__ == '__main__': absltest.main()