From e62bf08be09f299966111b706dea4d79cd42d235 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 27 Mar 2026 14:03:53 -0700 Subject: [PATCH] [JAX] Replace jnp.clip(..., a_min=..., a_max=...) with jnp.clip(..., min=..., max=...). a_min and a_max are deprecated parameter names to jax.numpy.clip. PiperOrigin-RevId: 890625997 --- hackable_diffusion/lib/loss/discrete.py | 2 +- hackable_diffusion/lib/sampling/discrete_step_sampler.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hackable_diffusion/lib/loss/discrete.py b/hackable_diffusion/lib/loss/discrete.py index 3b47faa..c8b6749 100644 --- a/hackable_diffusion/lib/loss/discrete.py +++ b/hackable_diffusion/lib/loss/discrete.py @@ -184,7 +184,7 @@ def _weight_fn( alpha_der = utils.egrad(schedule.alpha)(time) alpha = utils.flatten_non_batch_dims(alpha) alpha_der = utils.flatten_non_batch_dims(alpha_der) - weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, a_min=1e-12) + weight = -1.0 * alpha_der / jnp.clip(1.0 - alpha, min=1e-12) return weight return compute_discrete_diffusion_loss( diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index d49ef82..543ca02 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -708,7 +708,7 @@ def update( # (bsz, *seq_len, M) # Calculate integration weights: W(x_0) = p(x_0 | x_t) / q(x_t | x_0). - w_x0 = p_x0 / jnp.clip(q_xt_given_x0, a_min=1e-12) + w_x0 = p_x0 / jnp.clip(q_xt_given_x0, min=1e-12) # (bsz, *seq_len, M) sum_w = jnp.sum(w_x0, axis=-1, keepdims=True) # (bsz, *seq_len, 1) @@ -722,7 +722,7 @@ def update( # (bsz, *seq_len, M) # Convert back to logits for safe categorical sampling - total_logit = jnp.log(jnp.clip(p_xs, a_min=1e-12)) + total_logit = jnp.log(jnp.clip(p_xs, min=1e-12)) # Sample and format the new state new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None]