diff --git a/hackable_diffusion/lib/architecture/attention.py b/hackable_diffusion/lib/architecture/attention.py index 64e506b..280cff2 100644 --- a/hackable_diffusion/lib/architecture/attention.py +++ b/hackable_diffusion/lib/architecture/attention.py @@ -300,11 +300,10 @@ def __call__( if self.normalize_qk: scale = self.param( "norm_qk_scale", - nn.initializers.constant( - jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON) - ), - (1, 1, 1, 1), + nn.initializers.zeros_init(), + (1, 1, 1, 1), ) + scale = jnp.exp2(scale) norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True) norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True) diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index 49c54e7..0838b95 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -696,15 +696,13 @@ def update( # Extract invariant probabilities. pi = self.invariant_probs_vec - pi_xt = pi[xt[..., 0]][..., None] # The prior prob of the current token - # (bsz, *seq_len, 1) # Calculate q(x_t | x_s). - q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi_xt + q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi # (bsz, *seq_len, M) # Calculate q(x_t | x_0)' - q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi_xt + q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi # (bsz, *seq_len, M) # Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0).