Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hackable_diffusion/lib/loss/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _weight_fn(
alpha_der = utils.egrad(schedule.alpha)(time)
alpha = utils.flatten_non_batch_dims(alpha)
alpha_der = utils.flatten_non_batch_dims(alpha_der)
weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, a_min=1e-12)
weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, min=1e-12)
return weight

return compute_discrete_diffusion_loss(
Expand Down
4 changes: 2 additions & 2 deletions hackable_diffusion/lib/sampling/discrete_step_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def update(
# (bsz, *seq_len, M)

# Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0).
w_x0 = p_x0 / jnp.clip(q_xt_given_x0, a_min=1e-12)
w_x0 = p_x0 / jnp.clip(q_xt_given_x0, min=1e-12)
# (bsz, *seq_len, M)
sum_w = jnp.sum(w_x0, axis=-1, keepdims=True)
# (bsz, *seq_len, 1)
Expand All @@ -722,7 +722,7 @@ def update(
# (bsz, *seq_len, M)

# Convert back to logits for safe categorical sampling
total_logit = jnp.log(jnp.clip(p_xs, a_min=1e-12))
total_logit = jnp.log(jnp.clip(p_xs, min=1e-12))

# Sample and format the new state
new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None]
Expand Down
Loading