diff --git a/hackable_diffusion/lib/loss/discrete.py b/hackable_diffusion/lib/loss/discrete.py index 3b47faa..c8b6749 100644 --- a/hackable_diffusion/lib/loss/discrete.py +++ b/hackable_diffusion/lib/loss/discrete.py @@ -184,7 +184,7 @@ def _weight_fn( alpha_der = utils.egrad(schedule.alpha)(time) alpha = utils.flatten_non_batch_dims(alpha) alpha_der = utils.flatten_non_batch_dims(alpha_der) - weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, a_min=1e-12) + weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, min=1e-12) return weight return compute_discrete_diffusion_loss( diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index d49ef82..543ca02 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -708,7 +708,7 @@ def update( # (bsz, *seq_len, M) # Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0). - w_x0 = p_x0 / jnp.clip(q_xt_given_x0, a_min=1e-12) + w_x0 = p_x0 / jnp.clip(q_xt_given_x0, min=1e-12) # (bsz, *seq_len, M) sum_w = jnp.sum(w_x0, axis=-1, keepdims=True) # (bsz, *seq_len, 1) @@ -722,7 +722,7 @@ def update( # (bsz, *seq_len, M) # Convert back to logits for safe categorical sampling - total_logit = jnp.log(jnp.clip(p_xs, a_min=1e-12)) + total_logit = jnp.log(jnp.clip(p_xs, min=1e-12)) # Sample and format the new state new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None]