Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def impl(*outer_inputs):

return impl

cache_version = 1
cache_version = 2
if scalar_cache_key is None:
key = None
else:
Expand Down
48 changes: 23 additions & 25 deletions pytensor/link/numba/dispatch/vectorize_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,32 +964,30 @@ def codegen(ctx, builder, sig, args):
if spec is None:
continue
indexed_axes = {src_axis: idx_k for idx_k, src_axis in spec}
n_indexed = len(indexed_axes)
n_index_loop_dims = max(idx_types[idx_k].ndim for idx_k, _ in spec)
if n_indexed == n_index_loop_dims:
iter_shapes[p] = list(iter_shapes[p])
for idx_k, axis in spec:
iter_shapes[p][axis] = idx_shapes[idx_k][0]
else:
batch_ndim = len(input_bc_patterns[p])
max_axis = max(a for _, a in spec)
new_shape = []
new_bc = []
src_d = 0
idx_d = 0
for loop_d in range(batch_ndim):
if src_d in indexed_axes and idx_d < n_index_loop_dims:
new_shape.append(one)
new_bc.append(True)
idx_d += 1
if idx_d >= n_index_loop_dims:
src_d = max_axis + 1
else:
new_shape.append(in_shapes[p][src_d])
new_bc.append(iter_bc[p][loop_d])
src_d += 1
iter_shapes[p] = new_shape
iter_bc[p] = tuple(new_bc)
# An indexed input imposes no constraint on the loop dims produced by
# its indices: those are pinned by the index arrays' own iter_shape
# entries below. Mark each such loop dim broadcastable and copy the
# source shape for the non-indexed dims.
batch_ndim = len(input_bc_patterns[p])
max_axis = max(a for _, a in spec)
new_shape = []
new_bc = []
src_d = 0
idx_d = 0
for loop_d in range(batch_ndim):
if src_d in indexed_axes and idx_d < n_index_loop_dims:
new_shape.append(one)
new_bc.append(True)
idx_d += 1
if idx_d >= n_index_loop_dims:
src_d = max_axis + 1
else:
new_shape.append(in_shapes[p][src_d])
new_bc.append(iter_bc[p][loop_d])
src_d += 1
iter_shapes[p] = new_shape
iter_bc[p] = tuple(new_bc)

# Each index array participates in iter_shape validation.
# Write indices can broadcast against each other, but if ALL write
Expand Down
49 changes: 38 additions & 11 deletions tests/link/numba/test_indexed_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode, get_mode
from pytensor import Mode, function, get_mode
from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor1,
Expand All @@ -21,8 +20,8 @@

def fused_and_unfused(inputs, output):
"""Compile fused and unfused versions of a graph."""
fn = pytensor.function(inputs, output, mode=NUMBA_MODE, trust_input=True)
fn_u = pytensor.function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True)
fn = function(inputs, output, mode=NUMBA_MODE, trust_input=True)
fn_u = function(inputs, output, mode=NUMBA_NO_FUSION, trust_input=True)
return fn, fn_u


Expand Down Expand Up @@ -336,8 +335,8 @@ def test_write_with_non_indexed_leading_dims(self, target_shape, val_shape):
mode_u = NUMBA_NO_FUSION.excluding(
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1"
)
fn = pytensor.function([val, target], out, mode=mode, trust_input=True)
fn_u = pytensor.function([val, target], out, mode=mode_u, trust_input=True)
fn = function([val, target], out, mode=mode, trust_input=True)
fn_u = function([val, target], out, mode=mode_u, trust_input=True)
assert_fused(fn)
valv = rng.normal(size=val_shape)
tv = rng.normal(size=target_shape)
Expand Down Expand Up @@ -403,7 +402,7 @@ def test_target_not_modified_when_non_inplace(self):
y = pt.vector("y", shape=(919,))
t = pt.vector("t", shape=(85,))
out = t[idx].inc(x[idx] + y)
fn = pytensor.function([x, y, t], out, mode=NUMBA_MODE, trust_input=True)
fn = function([x, y, t], out, mode=NUMBA_MODE, trust_input=True)
xv, yv = rng.normal(size=(85,)), rng.normal(size=(919,))
tv = rng.normal(size=(85,))
tv_copy = tv.copy()
Expand Down Expand Up @@ -491,7 +490,7 @@ def test_non_inplace_aliasing_write_preserves_input(self):
t = pt.vector("t", shape=(None,))
idx = np.array([0, 2, 5], dtype=np.int64)
out = t[idx].set(t[idx] * 2.0)
fn = pytensor.function([t], out, mode=NUMBA_MODE, trust_input=True)
fn = function([t], out, mode=NUMBA_MODE, trust_input=True)

# The non-inplace write fuses fully -- no external AdvancedIncSubtensor1.
assert_fused(fn)
Expand All @@ -511,7 +510,7 @@ def test_non_inplace_aliasing_write_preserves_input(self):

# Run the inner graph via perform; the input buffer must be left untouched.
perform_outs = node.op(*node.inputs, return_list=True)
f_perform = pytensor.function(
f_perform = function(
[t],
perform_outs,
mode=Mode(linker="py", optimizer=None),
Expand Down Expand Up @@ -581,7 +580,7 @@ def test_mismatched_index_and_direct_input(self):
y = pt.vector("y", shape=(None,))
idx = pt.vector("idx", dtype="int64", shape=(None,))
out = x[idx] + y
fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True)
fn = function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True)
assert_fused(fn)
# Matching: idx=50, y=50 — should work
fn(np.zeros(100), np.zeros(50, dtype=np.int64), np.zeros(50))
Expand All @@ -595,11 +594,39 @@ def test_runtime_broadcast_on_index_dim(self):
y = pt.vector("y", shape=(None,))
idx = pt.vector("idx", dtype="int64", shape=(None,))
out = x[idx] + y
fn = pytensor.function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True)
fn = function([x, idx, y], out, mode=NUMBA_MODE, trust_input=True)
assert_fused(fn)
# Both idx and y have length 1 — should work (both agree on dim 0)
result = fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(1))
assert result.shape == (1,)
# idx=1, y=5 — should error (shape mismatch, no static broadcast info)
with pytest.raises(Exception):
fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5))

def test_loop_shape_regression(self):
    """Regression test for https://github.com/pymc-devs/pytensor/issues/2201

    A stale loop-shape branch only looked at the first dimension of the
    index arrays. For the indices used here it would conclude that the
    loop iterates over [24, 24] instead of [24, 3].
    """
    rng = np.random.default_rng(2201)

    # resp_idx[:, None] is (24, 1) and item_idx is (24, 3); both index
    # `beta`, and the result is combined with the (24, 3) mask.
    resp_idx = rng.integers(0, 6, size=24).astype("int32")
    item_idx = rng.integers(0, 5, size=(24, 3)).astype("int32")
    mask = pt.dmatrix("mask", shape=(24, 3))
    beta = pt.dmatrix("beta", shape=(6, 5))

    u = beta[resp_idx[:, None], item_idx]
    u_masked = pt.where(mask, u, -1e10)
    out = u_masked.sum()

    f = function([beta, mask], out, mode=NUMBA_MODE)
    assert_fused(f)

    # Unoptimized Python-linker compilation of the same graph as reference.
    ref_f = function([beta, mask], out, mode=Mode(linker="py", optimizer=None))

    # Use non-constant data for beta: with an all-zero beta and an all-True
    # mask the sum is 0 regardless of which elements are gathered, so a
    # wrong (but non-crashing) loop could not be told apart from a correct
    # one. The mask is kept all-True so the -1e10 fill value does not swamp
    # the relative tolerance of the comparison.
    test_beta = rng.normal(size=(6, 5))
    test_mask = np.ones((24, 3), dtype="bool")
    res = f(beta=test_beta, mask=test_mask)
    ref_res = ref_f(beta=test_beta, mask=test_mask)
    np.testing.assert_allclose(res, ref_res, strict=True)
Loading