fix: two silent math bugs in IntegratedDiscreteDDIMStep and QK-norm attention scale #29
Open
Dhakshin2007 wants to merge 2 commits into google:main from
Conversation
Bug: In IntegratedDiscreteDDIMStep.update(), q_xt_given_xs and q_xt_given_x0 were computed using pi_xt (a scalar: the invariant probability of the current token xt only) instead of pi (the full invariant probability vector over all M categories).
The forward process kernel is defined as:
p(x_t | x_s) = (alpha_t/alpha_s) * delta_{x_s}(x_t) + (1 - alpha_t/alpha_s) * pi(x_t)
Here pi(x_t) is a vector over all categories, not a scalar for the current token. Using pi_xt caused both q_xt_given_xs and q_xt_given_x0 to have incorrect probability distributions, silently producing wrong marginals and breaking the DDIM sampling math.
Fix: Replace pi_xt with pi (the full vector) in both lines and remove the now-unused pi_xt definition.
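As a minimal sketch of why the scalar breaks the distribution (illustrative values; `alpha_t`, `alpha_s`, and the batch shape are assumptions, not the actual code):

```python
import jax
import jax.numpy as jnp

M = 4
pi = jnp.array([0.1, 0.2, 0.3, 0.4])   # full invariant vector over M categories
xt = jnp.array([[2]])                   # current token ids, shape [batch, 1]
alpha_t, alpha_s = 0.5, 0.9
r = alpha_t / alpha_s

onehot_xt = jax.nn.one_hot(xt[..., 0], M)   # delta_{x_s}(x_t), shape [batch, M]

# Buggy: pi_xt is only the probability of the current token (a per-position scalar).
pi_xt = pi[xt[..., 0]][..., None]           # shape [batch, 1]
q_buggy = r * onehot_xt + (1 - r) * pi_xt   # rows do NOT sum to 1

# Fixed: broadcast the full invariant vector pi, shape [M].
q_fixed = r * onehot_xt + (1 - r) * pi      # valid distribution over M categories
```

With a non-uniform pi the buggy rows sum to something other than 1, which is exactly the silent failure described above.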
…arameter Bug: In MultiHeadAttention, when normalize_qk=True, the learned scale parameter 'norm_qk_scale' was initialized using nn.initializers.constant(jnp.log2(seq_len_kv**2 - ...)), storing a log2 value, but then used directly as a linear scale factor passed to _dot_product_attention. This means:
- The initial scale is ≈ 2*log2(seq_len_kv) instead of seq_len_kv^2 as intended.
- The parameter semantics are broken: gradient updates act on the wrong manifold.
- For seq_len_kv=64, the actual initial scale is ~12, not 4096.
Fix:
1. Change the initializer to nn.initializers.zeros_init() so the parameter represents a log2-space scale (exp2(0) = 1 is a sensible default).
2. Add scale = jnp.exp2(scale) after the self.param() call to correctly convert to linear space before use.
This matches the intent of storing the scale in log space for unconstrained optimization while ensuring the linear scale is always positive.
Summary
This PR fixes two independent silent math bugs discovered in recently landed commits. Both bugs produce wrong numerical results with no error at runtime, making them hazardous.
Bug 1 — IntegratedDiscreteDDIMStep: wrong invariant distribution shape

File: hackable_diffusion/lib/sampling/discrete_step_sampler.py
Introduced in commit: dbcbbec (Add IntegratedDiscreteDDIMStep)

Root cause

In IntegratedDiscreteDDIMStep.update(), the forward-process kernels q(x_t | x_s) and q(x_t | x_0) are defined in the D3PM paper (https://arxiv.org/abs/2107.03006) as:

q(x_t | x_s) = (alpha_t / alpha_s) * delta_{x_s}(x_t) + (1 - alpha_t / alpha_s) * pi(x_t)
q(x_t | x_0) = alpha_t * delta_{x_0}(x_t) + (1 - alpha_t) * pi(x_t)

where pi(x_t) is the full invariant probability vector over all M categories — shape [M].

The code mistakenly used pi_xt = pi[xt[..., 0]][..., None], which is a scalar per position (the invariant probability of the current token only) — shape [batch, *seq_len, 1].

This caused q_xt_given_xs and q_xt_given_x0 to hold wrong distributions, silently corrupting the integration weights w(x_0, x_t) = p(x_0 | x_t) / q(x_t | x_0) and the final marginal p(x_s | x_t).

Fix

Replace pi_xt with the full vector pi in both kernel computations and remove the now-unused pi_xt definition.
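With the full vector pi, the two kernels compose correctly. A quick Chapman–Kolmogorov consistency check (the values of pi and the alphas are illustrative assumptions):

```python
import jax.numpy as jnp

M = 4
pi = jnp.array([0.1, 0.2, 0.3, 0.4])   # invariant vector, shape [M]
alpha_s, alpha_t = 0.9, 0.5            # alpha_t < alpha_s (more noise at t)


def kernel(r):
  # Row i is the distribution q(. | x = i): r * delta_i + (1 - r) * pi
  return r * jnp.eye(M) + (1.0 - r) * pi[None, :]


Q_s0 = kernel(alpha_s)             # q(x_s | x_0)
Q_ts = kernel(alpha_t / alpha_s)   # q(x_t | x_s)
Q_t0 = kernel(alpha_t)             # q(x_t | x_0)

composed = Q_s0 @ Q_ts             # marginalize over x_s
```

Marginalizing q(x_t | x_s) against q(x_s | x_0) recovers q(x_t | x_0) exactly; the scalar pi_xt variant breaks this identity.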
Bug 2 — MultiHeadAttention: QK-norm scale parameter used in linear space instead of log space

File: hackable_diffusion/lib/architecture/attention.py
Introduced in commit: b0ca328 (Add attention mask)

Root cause

When normalize_qk=True, a learned attention scale norm_qk_scale is created with an initializer that stores a log2 value (~12 for seq_len_kv = 64), but the parameter is then used as a direct linear scale multiplied into the attention logits. The effective initial scale is therefore 2*log2(seq_len_kv) ≈ 12 instead of the intended seq_len_kv^2 ≈ 4096, and gradient updates act on the wrong manifold.

Fix
This ensures the scale is always positive and that the parameter is stored and optimized in log space, which is the standard approach for learned attention scales (see https://arxiv.org/abs/2010.04245).
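A quick numeric sanity check of the initialization mismatch (seq_len_kv = 64, as in the example above):

```python
import jax.numpy as jnp

seq_len_kv = 64

# Buggy: the stored log2 value was used directly as a linear scale.
stored = jnp.log2(jnp.float32(seq_len_kv) ** 2)   # 12.0, not the intended 4096

# Fixed: zeros init in log2 space, converted with exp2 before use.
initial_scale = jnp.exp2(jnp.zeros(()))           # 1.0, always positive
```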
Testing
Both fixes are minimal (1–3 line changes) and do not alter any public API.

- discrete_step_sampler_test.py with a non-uniform invariant_probs distribution: the integrated marginal will now match the closed-form D3PM posterior.
- scale is always positive after initialization, and the effective initial rescale is exp2(0) = 1.0.
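The closed-form posterior used in that comparison can be checked directly. A sketch (token ids, pi, and the alphas are illustrative assumptions, not the test's actual fixtures):

```python
import jax.numpy as jnp

M, x0, xt = 4, 1, 2                      # illustrative token ids
pi = jnp.array([0.1, 0.2, 0.3, 0.4])     # a non-uniform invariant_probs
alpha_s, alpha_t = 0.9, 0.5


def kernel(r):
  return r * jnp.eye(M) + (1.0 - r) * pi[None, :]


Q_s0, Q_ts = kernel(alpha_s), kernel(alpha_t / alpha_s)

# D3PM posterior: q(x_s | x_t, x_0) is proportional to q(x_t | x_s) * q(x_s | x_0)
unnorm = Q_ts[:, xt] * Q_s0[x0, :]
posterior = unnorm / unnorm.sum()
```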