diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index e70fd67b72..cc0385bf27 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -41,8 +41,20 @@ def allocempty(*shape): @jax_funcify.register(Alloc) def jax_funcify_Alloc(op, node, **kwargs): + static_shapes = [] + for shape_input in node.inputs[1:]: + try: + static_shapes.append(int(get_scalar_constant_value(shape_input))) + except NotScalarConstantError: + concrete_shape = None + break + else: + concrete_shape = tuple(static_shapes) + def alloc(x, *shape): - res = jnp.broadcast_to(x, shape) + res = jnp.broadcast_to( + x, concrete_shape if concrete_shape is not None else shape + ) Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) return res