Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2738,22 +2738,26 @@ def local_blockwise_inc_subtensor(fgraph, node):
idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
max_idx_core_ndim = max(idxs_core_ndim, default=0)

# Broadcast buffer to batch_shape
if x.type.broadcastable != out.type.broadcastable:
batch_shape = [1] * batch_ndim
for inp in node.inputs:
for i, (broadcastable, batch_dim) in enumerate(
zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
):
if broadcastable:
# This dimension is broadcastable, it doesn't provide shape information
continue
if batch_shape[i] != 1:
# We already found a source of shape for this batch dimension
continue
batch_shape[i] = batch_dim
# Broadcast buffer to batch_shape. The output batch shape and broadcast
# pattern are derived from the inputs, never from `out.type`, which can be
# stale after an upstream rewrite swaps an input.
batch_shape = [1] * batch_ndim
out_batch_bcast = [True] * batch_ndim
for inp in node.inputs:
for i, (broadcastable, batch_dim) in enumerate(
zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
):
if broadcastable:
# This dimension is broadcastable, it doesn't provide shape information
continue
out_batch_bcast[i] = False
if batch_shape[i] != 1:
# We already found a source of shape for this batch dimension
continue
batch_shape[i] = batch_dim

if list(x.type.broadcastable[:batch_ndim]) != out_batch_bcast:
x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
assert x.type.broadcastable == out.type.broadcastable

# Massage indices so they respect blockwise semantics while using regular indexing
core_idxs = []
Expand Down
37 changes: 20 additions & 17 deletions pytensor/tensor/rewriting/subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,15 +567,22 @@ def local_subtensor_of_expand_dims(fgraph, node):

idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

# Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe
# Keep indexes for the original dimensions, and drop indexes for the expanded
# dimensions when safe. We also track where each kept expanded dimension lands
# in the output (`out_pos`), so it can be re-introduced with expand_dims below.
# These axes are derived from the indices, not the (possibly stale) output type.
new_idxs = []
expand_axes = []
out_pos = 0
for i, idx_item in enumerate(idx_tuple):
if i in expanded_axes:
if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension
if idx_item == slice(None):
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
expand_axes.append(out_pos)
out_pos += 1
continue
else:
# Other slices could keep or drop the dimension.
Expand All @@ -589,27 +596,23 @@ def local_subtensor_of_expand_dims(fgraph, node):
else:
# Keep indexes for non-expanded dimensions
new_idxs.append(idx_item)
# An integer index drops the dimension; any slice keeps one.
if isinstance(idx_item, slice):
out_pos += 1

# Trailing dimensions beyond the explicit indices are implicit full slices;
# the expanded ones among them must also be re-introduced.
for axis in range(len(idx_tuple), ds.type.ndim):
if axis in expanded_axes:
expand_axes.append(out_pos)
out_pos += 1

[old_out] = node.outputs
out = x[tuple(new_idxs)]
if expand_axes:
out = expand_dims(out, axis=expand_axes)
copy_stack_trace(old_out, out)

if out.type.broadcastable != old_out.type.broadcastable:
# Re-introduce needed new dimensions (corresponding to full slices on the original expanded dimensions)
# If out.type.broadcastable == (False) and old_out.type.broadcastable == (True, False, True)
# then axis = (0, 2)
old_bcast = list(old_out.type.broadcastable)
expanded_bcast = list(out.type.broadcastable)
axis = []
i = 0
while i < len(old_bcast):
if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
expanded_bcast.insert(i, True)
axis.append(i)
i += 1
out = expand_dims(out, axis=axis)
copy_stack_trace(old_out, out)

return [out]


Expand Down
44 changes: 43 additions & 1 deletion tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace, in2out, out2in
from pytensor.graph.traversal import ancestors
Expand All @@ -22,6 +22,7 @@
_slice_to_arange,
local_add_of_sparse_write,
local_adv_idx_to_slice,
local_blockwise_inc_subtensor,
local_replace_AdvancedSubtensor,
local_useless_slice,
)
Expand Down Expand Up @@ -2508,6 +2509,47 @@ def test_non_consecutive_integer_indices(self, vectorize_idx, basic_idx):
)
np.testing.assert_allclose(fn(test_vec_a), ref_fn(test_vec_a))

def test_stale_output_type(self):
"""The rewrite must not depend on the Blockwise output type, which can be

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be marked as regression tests for the related issue?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They didn't really have a related issue; they were just inspired by the issue we already fixed.

stale after an upstream rewrite swaps an input for a more broadcastable one.

Here the batched buffer dim becomes length 1; the (stale) Blockwise output
type still claims it is non-broadcastable. The batch shape / broadcast
decision is derived from the inputs, so the lift still succeeds.
"""
core_x = tensor("core_x", shape=(6, 6))
core_y = tensor("core_y", shape=(3,), dtype=int)
core_graph = core_x[-1, :3].set(core_y)

x = tensor("x", shape=(None, 6, 6))
x_new = tensor("x_new", shape=(1, 6, 6))
out = vectorize_graph(core_graph, replace={core_x: x})
assert isinstance(out.owner.op, Blockwise)

fgraph = FunctionGraph([x, core_y], [out], clone=True)
[cloned_out] = fgraph.outputs
cloned_x, cloned_y = fgraph.inputs
node = cloned_out.owner

# Forge a stale state: the batched buffer dim becomes length 1, but the
# Blockwise output type still claims it is non-broadcastable.
fgraph.replace(cloned_x, x_new, import_missing=True)
assert not cloned_out.type.broadcastable[0]
assert node.inputs[0].type.broadcastable[0]

# Before the fix this raised an AssertionError (asserting against the
# stale output type); now it succeeds and matches the untouched `out`.
[new_out] = local_blockwise_inc_subtensor.transform(fgraph, node)

rng = np.random.default_rng(2167)
test_x = rng.normal(size=(1, 6, 6))
test_y = rng.integers(1, 10, size=(3,))
ref_mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
new_out.eval({x_new: test_x, cloned_y: test_y}, mode=ref_mode),
out.eval({x: test_x, core_y: test_y}, mode=ref_mode),
)


class TestUselessSlice:
def test_positive_step(self):
Expand Down
36 changes: 36 additions & 0 deletions tests/tensor/rewriting/test_subtensor_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_diag_indices,
local_subtensor_make_vector,
local_subtensor_of_batch_dims,
local_subtensor_of_expand_dims,
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import Shape_i, SpecifyShape, _shape
Expand Down Expand Up @@ -488,6 +489,41 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
)


def test_local_subtensor_of_expand_dims_stale_output_type():
    """Check that re-expansion axes come from the indices, not a stale output type.

    When an upstream rewrite turns an indexed dim from non-broadcastable to
    length 1, the cached Subtensor output type still claims non-broadcastable.
    Comparing the freshly indexed result against that stale type would misplace
    the expand_dims and yield a mis-shaped (here: wrong-ndim) graph.
    """
    x = pt.tensor("x", shape=(None, None, None), dtype="float64")
    x_new = pt.tensor("x_new", shape=(None, 1, None), dtype="float64")
    # The integer index drops input dim 0, so input dims (1, 2) and the expanded
    # dim become output dims (0, 1, 2) respectively.
    out = expand_dims(x, 3)[7]

    fgraph = FunctionGraph([x], [out], clone=True)
    [cloned_out] = fgraph.outputs
    [cloned_x] = fgraph.inputs
    node = cloned_out.owner

    # Forge a stale state: input dim 1 becomes length 1, but the Subtensor's
    # cached output type still claims the matching output dim is non-broadcastable.
    fgraph.replace(cloned_x, x_new, import_missing=True)
    # Input dim 1 is output dim 0 (see above); checking output dim 1 would look
    # at input dim 2, which is unknown-length in both `x` and `x_new` and so
    # could never be stale.
    assert not cloned_out.type.broadcastable[0]
    assert node.inputs[0].owner.inputs[0].type.broadcastable[1]

    [new_out] = local_subtensor_of_expand_dims.transform(fgraph, node)
    # `new_out` (on x_new) must match the untouched `out` (on x) for the same data,
    # in particular keep the same ndim (the bug produced ndim 4).
    assert new_out.type.ndim == out.type.ndim
    x_test = np.random.default_rng(232).normal(size=(10, 1, 3))
    np.testing.assert_allclose(
        new_out.eval({x_new: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )


@pytest.mark.parametrize(
"original_fn, expected_fn",
[
Expand Down
Loading