diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index d1740e78ae..1383ee1df1 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -794,7 +794,7 @@ def impl(*outer_inputs): return impl - cache_version = 1 + cache_version = 2 if scalar_cache_key is None: key = None else: diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index c5dc338da8..dfd12ed1b3 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -964,32 +964,30 @@ def codegen(ctx, builder, sig, args): if spec is None: continue indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec} - n_indexed = len(indexed_axes) n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec) - if n_indexed == n_index_loop_dims: - iter_shapes[p] = list(iter_shapes[p]) - for idx_k, axis in spec: - iter_shapes[p][axis] = idx_shapes[idx_k][0] - else: - batch_ndim = len(input_bc_patterns[p]) - max_axis = max(a for _, a in spec) - new_shape = [] - new_bc = [] - src_d = 0 - idx_d = 0 - for loop_d in range(batch_ndim): - if src_d in indexed_axes and idx_d < n_index_loop_dims: - new_shape.append(one) - new_bc.append(True) - idx_d += 1 - if idx_d >= n_index_loop_dims: - src_d = max_axis + 1 - else: - new_shape.append(in_shapes[p][src_d]) - new_bc.append(iter_bc[p][loop_d]) - src_d += 1 - iter_shapes[p] = new_shape - iter_bc[p] = tuple(new_bc) + # An indexed input imposes no constraint on the loop dims produced by + # its indices: those are pinned by the index arrays' own iter_shape + # entries below. Mark each such loop dim broadcastable and copy the + # source shape for the non-indexed dims. + batch_ndim = len(input_bc_patterns[p]) + max_axis = max(a for _, a in spec) + new_shape = [] + new_bc = [] + src_d = 0 + idx_d = 0 + for loop_d in range(batch_ndim): + if src_d in indexed_axes and idx_d < n_index_loop_dims: + new_shape.append(one) + new_bc.append(True) + idx_d += 1 + if idx_d >= n_index_loop_dims: + src_d = max_axis + 1 + else: + new_shape.append(in_shapes[p][src_d]) + new_bc.append(iter_bc[p][loop_d]) + src_d += 1 + iter_shapes[p] = new_shape + iter_bc[p] = tuple(new_bc) # Each index array participates in iter_shape validation. # Write indices can broadcast against each other, but if ALL write diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 505961d1c1..22ffbe5927 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -3,9 +3,8 @@ import numpy as np import pytest -import pytensor import pytensor.tensor as pt -from pytensor.compile.mode import Mode, get_mode +from pytensor import Mode, function, get_mode from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, @@ -21,8 +20,8 @@ def fused_and_unfused(inputs, output): """Compile fused and unfused versions of a graph.""" - fn = pytensor.function(inputs, output, mode=NUMBA_MODE, trust_input=True) - fn_u = pytensor.function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True) + fn = function(inputs, output, mode=NUMBA_MODE, trust_input=True) + fn_u = function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True) return fn, fn_u @@ -336,8 +335,8 @@ def test_write_with_non_indexed_leading_dims(self, target_shape, val_shape): mode_u = NUMBA_NO_FUSION.excluding( "local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1" ) - fn = pytensor.function([val, target], out, mode=mode, trust_input=True) - fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True) + fn = function([val, target], out, mode=mode, trust_input=True) + fn_u = function([val, target], out, mode=mode_u, trust_input=True) assert_fused(fn) valv = rng.normal(size=val_shape) tv = rng.normal(size=target_shape) @@ -403,7 +402,7 @@ def test_target_not_modified_when_non_inplace(self): y = pt.vector("y", shape=(919,)) t = pt.vector("t", shape=(85,)) out = t[idx].inc(x[idx] + y) - fn = pytensor.function([x, y, t], out, mode=NUMBA_MODE, trust_input=True) + fn = function([x, y, t], out, mode=NUMBA_MODE, trust_input=True) xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,)) tv = rng.normal(size=(85,)) tv_copy = tv.copy() @@ -491,7 +490,7 @@ def test_non_inplace_aliasing_write_preserves_input(self): t = pt.vector("t", shape=(None,)) idx = np.array([0, 2, 5], dtype=np.int64) out = t[idx].set(t[idx] * 2.0) - fn = pytensor.function([t], out, mode=NUMBA_MODE, trust_input=True) + fn = function([t], out, mode=NUMBA_MODE, trust_input=True) # The non-inplace write fuses fully -- no external AdvancedIncSubtensor1. assert_fused(fn) @@ -511,7 +510,7 @@ def test_non_inplace_aliasing_write_preserves_input(self): # Run the inner graph via perform; the input buffer must be left untouched. perform_outs = node.op(*node.inputs, return_list=True) - f_perform = pytensor.function( + f_perform = function( [t], perform_outs, mode=Mode(linker="py", optimizer=None), @@ -581,7 +580,7 @@ def test_mismatched_index_and_direct_input(self): y = pt.vector("y", shape=(None,)) idx = pt.vector("idx", dtype="int64", shape=(None,)) out = x[idx] + y - fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + fn = function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) assert_fused(fn) # Matching: idx=50, y=50 — should work fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(50)) @@ -595,7 +594,7 @@ def test_runtime_broadcast_on_index_dim(self): y = pt.vector("y", shape=(None,)) idx = pt.vector("idx", dtype="int64", shape=(None,)) out = x[idx] + y - fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) + fn = function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True) assert_fused(fn) # Both idx and y have length 1 — should work (both agree on dim 0) result = fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(1)) @@ -603,3 +602,31 @@ def test_runtime_broadcast_on_index_dim(self): # idx=1, y=5 — should error (shape mismatch, no static broadcast info) with pytest.raises(Exception): fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5)) + + def test_loop_shape_regression(self): + """ + Regression test for https://github.com/pymc-devs/pytensor/issues/2201 + + Stale loop shape branch would only look at the first dimension of indices. + In this test example it would think we are iterating over a [24, 24] loop instead of a [24, 3] + """ + rng = np.random.default_rng(2201) + + resp_idx = rng.integers(0, 6, size=24).astype("int32") + item_idx = rng.integers(0, 5, size=(24, 3)).astype("int32") + mask = pt.dmatrix("mask", shape=(24, 3)) + beta = pt.dmatrix("beta", shape=(6, 5)) + + u = beta[resp_idx[:, None], item_idx] + u_masked = pt.where(mask, u, -1e10) + out = u_masked.sum() + + f = function([beta, mask], out, mode=NUMBA_MODE) + assert_fused(f) + + ref_f = function([beta, mask], out, mode=Mode(linker="py", optimizer=None)) + test_beta = np.zeros((6, 5)) + test_mask = np.ones((24, 3), dtype="bool") + res = f(beta=test_beta, mask=test_mask) + ref_res = ref_f(beta=test_beta, mask=test_mask) + np.testing.assert_allclose(res, ref_res, strict=True)