From ac74c700526180de454675915aa13c97c0be6ea6 Mon Sep 17 00:00:00 2001 From: Kotha Dhakshin <179742818+Dhakshin2007@users.noreply.github.com> Date: Sun, 15 Mar 2026 19:51:13 +0530 Subject: [PATCH 1/2] fix: use full invariant prob vector pi in IntegratedDiscreteDDIMStep Bug: In IntegratedDiscreteDDIMStep.update(), q_xt_given_xs and q_xt_given_x0 were computed using pi_xt (a scalar - the invariant prob of the current token xt only) instead of pi (the full invariant probability vector over all M categories). The forward process kernel is defined as: p(x_t | x_s) = (alpha_t/alpha_s) * delta_{x_s}(x_t) + (1 - alpha_t/alpha_s) * pi(x_t) Here pi(x_t) is a vector over all categories, not a scalar for the current token. Using pi_xt caused both q_xt_given_xs and q_xt_given_x0 to have incorrect probability distributions, silently producing wrong marginals and breaking the DDIM sampling math. Fix: Replace pi_xt with pi (the full vector) in both lines and remove the now-unused pi_xt definition. --- hackable_diffusion/lib/sampling/discrete_step_sampler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index 49c54e7..0838b95 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -696,15 +696,13 @@ def update( # Extract invariant probabilities. pi = self.invariant_probs_vec - pi_xt = pi[xt[..., 0]][..., None] # The prior prob of the current token - # (bsz, *seq_len, 1) # Calculate q(x_t | x_s). - q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi_xt + q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi # (bsz, *seq_len, M) # Calculate q(x_t | x_0)' - q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi_xt + q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi # (bsz, *seq_len, M) # Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0). From 683fc91604ec011890ced4626ff33a3eaf3f34a3 Mon Sep 17 00:00:00 2001 From: Kotha Dhakshin <179742818+Dhakshin2007@users.noreply.github.com> Date: Sun, 15 Mar 2026 19:57:46 +0530 Subject: [PATCH 2/2] fix: use exp2 for norm_qk_scale to correctly exponentiate log-space parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug: In MultiHeadAttention, when normalize_qk=True, the learned scale parameter 'norm_qk_scale' was initialized using nn.initializers.constant(jnp.log2(seq_len_kv**2 - ...)) — storing a log2 value — but then used directly as a linear scale factor passed to _dot_product_attention. This means: - The initial scale ≈ 2*log2(seq_len_kv) instead of seq_len_kv^2 as intended - The parameter semantics are broken: gradient updates act on the wrong manifold - For seq_len_kv=64, the actual initial scale is ~12, not 4096 Fix: 1. Change the initializer to nn.initializers.zeros_init() so the parameter represents a log2-space scale (exp2(0) = 1 is a sensible default) 2. Add scale = jnp.exp2(scale) after the self.param() call to correctly convert to linear space before use This matches the intent of storing the scale in log space for unconstrained optimization while ensuring the linear scale is always positive. --- hackable_diffusion/lib/architecture/attention.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/hackable_diffusion/lib/architecture/attention.py b/hackable_diffusion/lib/architecture/attention.py index 64e506b..280cff2 100644 --- a/hackable_diffusion/lib/architecture/attention.py +++ b/hackable_diffusion/lib/architecture/attention.py @@ -300,11 +300,10 @@ def __call__( if self.normalize_qk: scale = self.param( "norm_qk_scale", - nn.initializers.constant( - jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON) - ), - (1, 1, 1, 1), + nn.initializers.zeros_init(), +(1, 1, 1, 1), ) + scale = jnp.exp2(scale) norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True) norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)