From cee7141b3ba130f1d935a14ab59a1b16b7184412 Mon Sep 17 00:00:00 2001 From: Jona JOACHIM Date: Mon, 13 Apr 2026 14:13:23 +0200 Subject: [PATCH 1/3] JAX: fix Alloc failing under jax.jit when shape inputs are constants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit jax_funcify_Alloc returned a closure that received shape dimensions as JAX arrays at runtime. When the compiled pytensor function was re-traced by an outer jax.jit or jax.value_and_grad call — as done by downstream libraries that extract f.vm.jit_fn and differentiate through it — those JAX arrays were promoted to JitTracers. jnp.broadcast_to requires a concrete tuple of Python ints for its shape argument and raised: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer(int64[]),). Fix: at compile time in jax_funcify_Alloc, attempt to resolve each shape input to a concrete Python int using get_scalar_constant_value. When all shape dimensions are constant (the common case), bake the resulting tuple directly into the closure so that no JAX array is ever passed as a shape argument at runtime. Dynamic shape dimensions fall back to the previous behaviour. This mirrors the existing fix already applied to jax_funcify_ARange in the same file for the same class of problem. --- pytensor/link/jax/dispatch/tensor_basic.py | 30 +++++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index e70fd67b72..eb885c99ea 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -41,10 +41,32 @@ def allocempty(*shape): @jax_funcify.register(Alloc) def jax_funcify_Alloc(op, node, **kwargs): - def alloc(x, *shape): - res = jnp.broadcast_to(x, shape) - Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) - return res + # Extract concrete shape values at compile time where possible. + # Shape inputs that are constants must be baked in as Python ints so that + # jnp.broadcast_to receives a concrete tuple even when the surrounding + # function is traced by jax.jit / jax.value_and_grad (where JAX array + # constants would otherwise become JitTracers, which broadcast_to rejects). + static_shapes = [] + for shape_input in node.inputs[1:]: + try: + static_shapes.append(int(get_scalar_constant_value(shape_input))) + except NotScalarConstantError: + static_shapes.append(None) + + if all(s is not None for s in static_shapes): + concrete_shape = tuple(static_shapes) + + def alloc(x, *shape): + res = jnp.broadcast_to(x, concrete_shape) + Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) + return res + + else: + + def alloc(x, *shape): + res = jnp.broadcast_to(x, shape) + Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) + return res return alloc From 109b55bcf706947dd1efa43bd2aae11973c154ed Mon Sep 17 00:00:00 2001 From: Jona JOACHIM Date: Mon, 13 Apr 2026 16:07:48 +0200 Subject: [PATCH 2/3] Simplify the fix: break for-loop early, remove additional if statement, only define the closure once. --- pytensor/link/jax/dispatch/tensor_basic.py | 26 ++++++---------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index eb885c99ea..f3edb623a8 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -41,32 +41,20 @@ def allocempty(*shape): @jax_funcify.register(Alloc) def jax_funcify_Alloc(op, node, **kwargs): - # Extract concrete shape values at compile time where possible. - # Shape inputs that are constants must be baked in as Python ints so that - # jnp.broadcast_to receives a concrete tuple even when the surrounding - # function is traced by jax.jit / jax.value_and_grad (where JAX array - # constants would otherwise become JitTracers, which broadcast_to rejects). static_shapes = [] for shape_input in node.inputs[1:]: try: static_shapes.append(int(get_scalar_constant_value(shape_input))) except NotScalarConstantError: - static_shapes.append(None) - - if all(s is not None for s in static_shapes): - concrete_shape = tuple(static_shapes) - - def alloc(x, *shape): - res = jnp.broadcast_to(x, concrete_shape) - Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) - return res - + concrete_shape = None + break else: + concrete_shape = tuple(static_shapes) - def alloc(x, *shape): - res = jnp.broadcast_to(x, shape) - Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) - return res + def alloc(x, *shape): + res = jnp.broadcast_to(x, concrete_shape if concrete_shape is not None else shape) + Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) + return res return alloc From 108df008a13e3f2345b9b31e706e461a75807e63 Mon Sep 17 00:00:00 2001 From: Jona JOACHIM Date: Tue, 14 Apr 2026 10:43:06 +0200 Subject: [PATCH 3/3] fix linting --- pytensor/link/jax/dispatch/tensor_basic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index f3edb623a8..cc0385bf27 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -52,7 +52,9 @@ def jax_funcify_Alloc(op, node, **kwargs): concrete_shape = tuple(static_shapes) def alloc(x, *shape): - res = jnp.broadcast_to(x, concrete_shape if concrete_shape is not None else shape) + res = jnp.broadcast_to( + x, concrete_shape if concrete_shape is not None else shape + ) Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) return res