Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions hackable_diffusion/lib/architecture/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,10 @@ def __call__(
if self.normalize_qk:
scale = self.param(
"norm_qk_scale",
nn.initializers.constant(
jnp.log2(seq_len_kv**2 - seq_len_kv + SAFETY_EPSILON)
),
(1, 1, 1, 1),
nn.initializers.zeros_init(),
(1, 1, 1, 1),
)
scale = jnp.exp2(scale)

norm_q = jnp.linalg.norm(q, ord=2, axis=-1, keepdims=True)
norm_k = jnp.linalg.norm(k, ord=2, axis=-1, keepdims=True)
Expand Down
6 changes: 2 additions & 4 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,15 +696,13 @@ def update(

# Extract invariant probabilities.
pi = self.invariant_probs_vec
pi_xt = pi[xt[..., 0]][..., None] # The prior prob of the current token
# (bsz, *seq_len, 1)

# Calculate q(x_t | x_s).
q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi_xt
q_xt_given_xs = ratio * xt_oh + (1.0 - ratio) * pi
# (bsz, *seq_len, M)

# Calculate q(x_t | x_0)'
q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi_xt
q_xt_given_x0 = alpha_t * xt_oh + (1.0 - alpha_t) * pi
# (bsz, *seq_len, M)

# Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0).
Expand Down