Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from hackable_diffusion.lib.sampling.discrete_step_sampler import AllCorruptedMaskFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import CorruptedMaskFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteDDIMStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import DiscreteFlowMatchingStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import IntegratedDiscreteDDIMStep
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaskValueCorruptedMaskFn
from hackable_diffusion.lib.sampling.discrete_step_sampler import MaxCappedRemaskingFn
Expand Down
168 changes: 168 additions & 0 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,171 @@ def finalize(
current_step,
last_step_info,
)


################################################################################
# MARK: Discrete Flow Matching Step
################################################################################


@dataclasses.dataclass(frozen=True, kw_only=True)
class DiscreteFlowMatchingStep(SamplerStep):
  """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595.

  This sampler is the simplest variant of Algorithm 1 in Discrete Flow
  Matching, Gat et al., 2024, https://arxiv.org/abs/2407.15595. It implements
  the update rule based on the probability velocity derived for the
  probability path family in (9).

  The update rule is:
  x_{t-dt} ~ (1 - prob_jump) * delta_{x_t} + prob_jump * prediction

  where prob_jump = (alpha_s - alpha_t) / (1 - alpha_t). Note that alpha(t) in
  hackable diffusion is the probability of keeping the original value, which
  corresponds to 1 - kappa(t) in the paper if the time is reversed.

  Attributes:
    corruption_process: The corruption process to use. Must carry a
      DiscreteSchedule (enforced in __post_init__).
    temperature: Softmax temperature applied to the model logits before
      sampling from p_{0|t}; 1.0 leaves the logits unchanged.
    gamma: The corrector term (default 0.0). Higher values introduce more noise
      during the denoising process, which can improve sample quality.
  """

  corruption_process: CategoricalProcess
  temperature: float = 1.0
  gamma: float = 0.0

  def __post_init__(self):
    """DiscreteFlowMatchingStep requires a DiscreteSchedule."""
    if not isinstance(self.corruption_process.schedule, DiscreteSchedule):
      raise ValueError('DiscreteFlowMatchingStep requires a DiscreteSchedule.')

  @property
  def unused_mask_value(self) -> int:
    # Sentinel category marking tokens that must never be resampled; such
    # tokens are restored unchanged at the end of each update.
    return self.corruption_process.unused_mask_value

  @property
  def post_corruption_fn(self) -> discrete.PostCorruptionFn:
    # Post-processing applied to freshly sampled tokens, delegated to the
    # corruption process.
    return self.corruption_process.post_corruption_fn

  @typechecked
  def initialize(
      self,
      initial_noise: DataArray,
      initial_step_info: StepInfo,
  ) -> DiffusionStep:
    """Builds the initial DiffusionStep from pure noise.

    Args:
      initial_noise: Initial categorical sample (trailing axis is the token
        channel, presumably of size 1 — matches the repeat below).
      initial_step_info: Step/time/rng bookkeeping for the first step.

    Returns:
      A DiffusionStep whose aux logits are an all -inf float32 array of shape
      (..., num_categories), acting as a "no prediction yet" placeholder with
      the same structure as the logits produced by `update`.
    """
    # jnp.repeat is only used to obtain the (..., num_categories) shape; the
    # values are immediately overwritten with -inf on the next line.
    init_logits = jnp.repeat(
        initial_noise, self.corruption_process.num_categories, axis=-1
    )
    init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf

    return DiffusionStep(
        xt=initial_noise,
        step_info=initial_step_info,
        aux={'logits': init_logits},
    )

  @typechecked
  def update(
      self,
      prediction: TargetInfo,
      current_step: DiffusionStep,
      next_step_info: StepInfo,
  ) -> DiffusionStep:
    """Performs one reverse (denoising) step of the discretized jump process.

    Args:
      prediction: Model outputs convertible (via the corruption process) to
        logits over the clean-data posterior p_{0|t}.
      current_step: State x_t at the current time t.
      next_step_info: Bookkeeping (time s, rng) for the next step; time is
        expected to move toward the data end so that alpha_s >= alpha_t.

    Returns:
      The DiffusionStep at the next time, with the temperature-scaled logits
      stored in aux.
    """
    current_step_info = current_step.step_info
    xt = current_step.xt

    # Tokens equal to the sentinel value are excluded from resampling and
    # restored at the end of this step.
    unused_mask = xt == self.unused_mask_value

    time = current_step_info.time
    next_time = next_step_info.time
    time_bcast = utils.bcast_right(time, xt.ndim)
    next_time_bcast = utils.bcast_right(next_time, xt.ndim)
    key = next_step_info.rng

    # Sample from p_{0|t}
    logits = self.corruption_process.convert_predictions(
        prediction,
        xt,
        time_bcast,
    )['logits']
    logits = logits / self.temperature

    # NOTE: the key-splitting order (sample, noise, jump) is part of the
    # sampler's deterministic behavior for a given rng; do not reorder.
    _, sample_key, noise_key, jump_key = jax.random.split(key, 4)
    sample = jax.random.categorical(key=sample_key, logits=logits)[..., None]
    noise_sample = self.corruption_process.sample_from_invariant(
        noise_key, data_spec=xt
    )

    # Denoising
    alpha_s = self.corruption_process.schedule.alpha(next_time_bcast)
    alpha_t = self.corruption_process.schedule.alpha(time_bcast)

    # prob_up is the probability of switching from the current state to the
    # predicted data state. Following the paper's formula (24):
    # u_fwd = (dot_kappa / (1 - kappa)) * (p_data - delta_xt)
    # prob_down is the probability of switching back to noise (corrector logic):
    # u_bwd = (dot_kappa / kappa) * (delta_xt - p_noise)
    # Following the paper's formula (26), the combined velocity is:
    # u_bar = (1 + gamma) * u_fwd - gamma * u_bwd.
    # Note that since u_bwd (u^(0) in the paper) involves (delta_xt - p_noise),
    # it has negative jump rates back to noise. Subtracting it (-gamma * u_bwd)
    # results in positive jump probabilities in the discretization.

    # We discretize this as a jump process where each token has probability
    # prob_up of jumping to data and prob_down of jumping to noise.

    # The jnp.maximum(..., 1e-12) guards protect against division by zero at
    # the endpoints (alpha_t -> 1 and alpha_t -> 0 respectively).
    prob_up = (
        (alpha_s - alpha_t)
        / jnp.maximum(1.0 - alpha_t, 1e-12)
        * (1.0 + self.gamma)
    )
    prob_down = (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.gamma

    # Calculate raw, unclipped probabilities
    raw_p_up = jnp.maximum(prob_up, 0.0)
    raw_p_down = jnp.maximum(prob_down, 0.0)
    sum_jumps = raw_p_up + raw_p_down

    # If the sum exceeds 1.0, scale them down proportionally to maintain their
    # ratio
    scale_factor = jnp.maximum(1.0, sum_jumps)

    p_up = raw_p_up / scale_factor
    p_down = raw_p_down / scale_factor
    p_stay = 1.0 - p_up - p_down

    # Sample the per-token jump type from the categorical distribution
    # (stay, up, down); the inner maximum keeps log(0) from producing NaNs.
    probs = jnp.stack([p_stay, p_up, p_down], axis=-1)
    probs = jnp.broadcast_to(probs, xt.shape + (3,))
    jump_type = jax.random.categorical(
        jump_key, logits=jnp.log(jnp.maximum(probs, 1e-12))
    )

    # 0: stay, 1: jump to data, 2: jump to noise
    new_xt = jnp.where(jump_type == 1, sample, xt)
    new_xt = jnp.where(jump_type == 2, noise_sample, new_xt)
    new_xt = self.post_corruption_fn(new_xt)

    # Replace the unused tokens with the unused_mask_value.
    new_xt = jnp.where(unused_mask, self.unused_mask_value, new_xt)

    return DiffusionStep(
        xt=new_xt,
        step_info=next_step_info,
        aux={'logits': logits},
    )

  @typechecked
  def finalize(
      self,
      prediction: TargetInfo,
      current_step: DiffusionStep,
      last_step_info: StepInfo,
  ) -> DiffusionStep:
    """Final step: identical to a regular update at the last step info."""
    return self.update(
        prediction,
        current_step,
        last_step_info,
    )
133 changes: 133 additions & 0 deletions hackable_diffusion/lib/sampling/discrete_step_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,5 +534,138 @@ def test_fail_for_zero_invariant_probs(self):
)


class DiscreteFlowMatchingStepTest(absltest.TestCase):
  """Tests for the DiscreteFlowMatchingStep sampler."""

  def setUp(self):
    super().setUp()
    self.schedule = schedules.LinearDiscreteSchedule()
    self.num_categories = 4
    self.process = CategoricalProcess.uniform_process(
        schedule=self.schedule, num_categories=self.num_categories
    )
    noise_key = jax.random.PRNGKey(0)
    self.initial_noise = jax.random.randint(
        noise_key, (2, 4, 1), 0, self.process.process_num_categories
    )
    self.dfm_step = discrete_step_sampler.DiscreteFlowMatchingStep(
        corruption_process=self.process
    )

  def _dummy_inference_fn(self, xt, conditioning, time):
    """Returns logits whose sample is deterministically category 1."""
    del conditioning, time
    logits_shape = xt.shape[:-1] + (self.process.num_categories,)
    return {'logits': jnp.zeros(logits_shape).at[..., 1].set(10.0)}

  def test_initialize(self):
    step_info = StepInfo(
        step=0,
        time=jnp.full((2, 1, 1), 1.0),
        rng=jax.random.PRNGKey(0),
    )
    first_step = self.dfm_step.initialize(
        initial_noise=self.initial_noise,
        initial_step_info=step_info,
    )

    # The initial aux logits are an all -inf placeholder with one slot per
    # category.
    logits_shape = self.initial_noise.shape[:-1] + (
        self.process.num_categories,
    )
    expected_logits = jnp.full(logits_shape, -jnp.inf, dtype=jnp.float32)

    chex.assert_trees_all_equal(
        first_step,
        DiffusionStep(
            xt=self.initial_noise,
            step_info=step_info,
            aux={'logits': expected_logits},
        ),
    )

  def test_update(self):
    start_info = StepInfo(
        step=0,
        time=jnp.full((2, 1, 1), 0.5),
        rng=jax.random.PRNGKey(0),
    )
    start_step = self.dfm_step.initialize(
        initial_noise=self.initial_noise,
        initial_step_info=start_info,
    )
    prediction = self._dummy_inference_fn(
        xt=start_step.xt,
        conditioning={},
        time=start_step.step_info.time,
    )

    # Case 1: full unmasking (alpha_s=1.0, alpha_t=0.5 -> prob_jump=1.0), so
    # every token jumps to the predicted category 1.
    full_jump_info = StepInfo(
        step=1,
        time=jnp.full((2, 1, 1), 0.0),
        rng=jax.random.PRNGKey(1),
    )
    full_jump_step = self.dfm_step.update(
        prediction=prediction,
        current_step=start_step,
        next_step_info=full_jump_info,
    )
    chex.assert_trees_all_equal(
        full_jump_step.xt, jnp.ones_like(self.initial_noise)
    )

    # Case 2: no time progress (alpha_s=alpha_t -> prob_jump=0.0), so xt must
    # be unchanged.
    frozen_info = StepInfo(
        step=1,
        time=jnp.full((2, 1, 1), 0.5),
        rng=jax.random.PRNGKey(1),
    )
    frozen_step = self.dfm_step.update(
        prediction=prediction,
        current_step=start_step,
        next_step_info=frozen_info,
    )
    chex.assert_trees_all_equal(frozen_step.xt, start_step.xt)

  def test_update_with_gamma(self):
    num_samples = 10
    start_info = StepInfo(
        step=0,
        time=jnp.full((num_samples, 1, 1), 0.5),
        rng=jax.random.PRNGKey(0),
    )
    # Many samples so that at least one corrector jump to noise is observed.
    all_ones_xt = jnp.ones((num_samples, 4, 1), dtype=jnp.int32)
    start_step = self.dfm_step.initialize(
        initial_noise=all_ones_xt,
        initial_step_info=start_info,
    )

    # The prediction matches the current state (category 1), so any change to
    # xt must come from the corrector (gamma) term.
    prediction = self._dummy_inference_fn(
        xt=start_step.xt,
        conditioning={},
        time=start_step.step_info.time,
    )

    # gamma chosen so the jump probabilities are not clipped.
    corrector_sampler = discrete_step_sampler.DiscreteFlowMatchingStep(
        corruption_process=self.process, gamma=1.0
    )

    next_info = StepInfo(
        step=1,
        time=jnp.full((num_samples, 1, 1), 0.4),
        rng=jax.random.PRNGKey(1),
    )
    result_step = corrector_sampler.update(
        prediction=prediction,
        current_step=start_step,
        next_step_info=next_info,
    )

    # Some tokens should have jumped back to noise (i.e. away from 1).
    self.assertTrue(jnp.any(result_step.xt != 1))


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()