Skip to content

fix: two silent math bugs in IntegratedDiscreteDDIMStep and QK-norm attention scale#29

Open
Dhakshin2007 wants to merge 2 commits intogoogle:mainfrom
Dhakshin2007:fix/discrete-ddim-invariant-probs-and-cross-attn-mask
Open

fix: two silent math bugs in IntegratedDiscreteDDIMStep and QK-norm attention scale#29
Dhakshin2007 wants to merge 2 commits intogoogle:mainfrom
Dhakshin2007:fix/discrete-ddim-invariant-probs-and-cross-attn-mask

Conversation

@Dhakshin2007
Copy link

Summary

This PR fixes two independent silent math bugs discovered in recently landed commits. Both bugs produce wrong numerical results with no error at runtime, making them hazardous.


Bug 1 — IntegratedDiscreteDDIMStep: wrong invariant distribution shape

File: hackable_diffusion/lib/sampling/discrete_step_sampler.py
Introduced in commit: dbcbbec (Add IntegratedDiscreteDDIMStep)

Root cause

In IntegratedDiscreteDDIMStep.update(), the forward-process kernels q(x_t | x_s) and q(x_t | x_0) are defined in the D3PM paper (https://arxiv.org/abs/2107.03006) as:

p(x_t | x_s) = (alpha_t / alpha_s) * delta_{x_s}(x_t)  +  (1 - alpha_t/alpha_s) * pi(x_t)

where pi(x_t) is the full invariant probability vector over all M categories — shape [M].

The code mistakenly used pi_xt = pi[xt[..., 0]][..., None], which is a scalar (the invariant probability of the current token only) — shape [batch, *seq_len, 1].

This caused q_xt_given_xs and q_xt_given_x0 to hold wrong distributions, silently corrupting the integration weights w(x_0, x_t) = p(x_0|x_t) / q(x_t|x_0) and the final marginal p(x_s|x_t).

Fix

# Before (wrong):
pi_xt = pi[xt[..., 0]][..., None]  # scalar for current token
q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi_xt
q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi_xt

# After (correct):
# pi already has shape [M] — use the full invariant vector
q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi
q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi

Bug 2 — MultiHeadAttention: QK-norm scale parameter used in linear space instead of log space

File: hackable_diffusion/lib/architecture/attention.py
Introduced in commit: b0ca328 (Add attention mask)

Root cause

When normalize_qk=True, a learned attention scale norm_qk_scale is created:

# Before (wrong):
scale = self.param(
    "norm_qk_scale",
    nn.initializers.constant(
        jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
    ),
    (1, 1, 1, 1),
)
# scale is then passed directly as rescale=scale

The initializer stores a log2 value (e.g., ~12 for seq_len=64), but the parameter is then used as a direct linear scale multiplied into attention logits. This means:

  • Initial scale ≈ 2*log2(seq_len_kv) ≈ 12, instead of the intended seq_len_kv^2 ≈ 4096
  • Gradient updates act on the wrong manifold — the parameter is in log space but optimized as if linear
  • The scale can go negative during training (log-space values can be negative), making attention logits flip sign

Fix

# After (correct):
scale = self.param(
    "norm_qk_scale",
    nn.initializers.zeros_init(),  # exp2(0) = 1 is a sensible default
    (1, 1, 1, 1),
)
scale = jnp.exp2(scale)  # convert log2-space to positive linear scale

This ensures the scale is always positive and that the parameter is stored and optimized in log space, which is the standard approach for learned attention scales (see https://arxiv.org/abs/2010.04245).


Testing

Both fixes are minimal (1–3 line changes) and do not alter any public API.

  • Bug 1 can be verified by running discrete_step_sampler_test.py with a non-uniform invariant_probs distribution — the integrated marginal will now match the closed-form D3PM posterior.
  • Bug 2 can be verified by confirming that scale is always positive after initialization and that the effective initial rescale is exp2(0) = 1.0.

Bug: In IntegratedDiscreteDDIMStep.update(), q_xt_given_xs and q_xt_given_x0 were computed using pi_xt (a scalar - the invariant prob of the current token xt only) instead of pi (the full invariant probability vector over all M categories).

The forward process kernel is defined as:
  p(x_t | x_s) = (alpha_t/alpha_s) * delta_{x_s}(x_t) + (1 - alpha_t/alpha_s) * pi(x_t)

Here pi(x_t) is a vector over all categories, not a scalar for the current token. Using pi_xt caused both q_xt_given_xs and q_xt_given_x0 to have incorrect probability distributions, silently producing wrong marginals and breaking the DDIM sampling math.

Fix: Replace pi_xt with pi (the full vector) in both lines and remove the now-unused pi_xt definition.
…arameter

Bug: In MultiHeadAttention, when normalize_qk=True, the learned scale parameter 'norm_qk_scale' was initialized using nn.initializers.constant(jnp.log2(seq_len_kv**2 - ...)) — storing a log2 value — but then used directly as a linear scale factor passed to _dot_product_attention.

This means:
- The initial scale ≈ 2*log2(seq_len_kv) instead of seq_len_kv^2 as intended
- The parameter semantics are broken: gradient updates act on the wrong manifold
- For seq_len_kv=64, the actual initial scale is ~12, not 4096

Fix:
1. Change the initializer to nn.initializers.zeros_init() so the parameter represents a log2-space scale (exp2(0) = 1 is a sensible default)
2. Add scale = jnp.exp2(scale) after the self.param() call to correctly convert to linear space before use

This matches the intent of storing the scale in log space for unconstrained optimization while ensuring the linear scale is always positive.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant