From 13f5ff88ac36d4870b662e284dd32197fea947b3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 9 Jun 2026 12:06:07 +0200
Subject: [PATCH] Fix stale output type bugs in rewrites

Rewrites must not read the output type of the node they are rewriting:
after an upstream rewrite swaps an input for a more broadcastable one,
that cached type can be stale.

* local_blockwise_inc_subtensor: derive the batch shape and the batch
  broadcast pattern from the node inputs, and drop the assert that
  compared the buffer against `out.type`.
* local_subtensor_of_expand_dims: compute the expand_dims axes from the
  indices, tracking the output position of every kept dimension
  (including implicit trailing full slices), instead of diffing the new
  broadcastable pattern against the old output's.

Each rewrite gets a regression test that forges the stale state with
`FunctionGraph.replace` and compares the rewritten graph numerically
against the untouched one.
---
 pytensor/tensor/rewriting/subtensor.py        | 34 +++++++-------
 pytensor/tensor/rewriting/subtensor_lift.py   | 37 +++++++++-------
 tests/tensor/rewriting/test_subtensor.py      | 44 ++++++++++++++++++-
 tests/tensor/rewriting/test_subtensor_lift.py | 36 +++++++++++++++
 4 files changed, 118 insertions(+), 33 deletions(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 6c07ddb807..ef51cae306 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -2738,22 +2738,26 @@ def local_blockwise_inc_subtensor(fgraph, node):
     idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
     max_idx_core_ndim = max(idxs_core_ndim, default=0)
 
-    # Broadcast buffer to batch_shape
-    if x.type.broadcastable != out.type.broadcastable:
-        batch_shape = [1] * batch_ndim
-        for inp in node.inputs:
-            for i, (broadcastable, batch_dim) in enumerate(
-                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
-            ):
-                if broadcastable:
-                    # This dimension is broadcastable, it doesn't provide shape information
-                    continue
-                if batch_shape[i] != 1:
-                    # We already found a source of shape for this batch dimension
-                    continue
-                batch_shape[i] = batch_dim
+    # Broadcast buffer to batch_shape. The output batch shape and broadcast
+    # pattern are derived from the inputs, never from `out.type`, which can be
+    # stale after an upstream rewrite swaps an input.
+    batch_shape = [1] * batch_ndim
+    out_batch_bcast = [True] * batch_ndim
+    for inp in node.inputs:
+        for i, (broadcastable, batch_dim) in enumerate(
+            zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+        ):
+            if broadcastable:
+                # This dimension is broadcastable, it doesn't provide shape information
+                continue
+            out_batch_bcast[i] = False
+            if batch_shape[i] != 1:
+                # We already found a source of shape for this batch dimension
+                continue
+            batch_shape[i] = batch_dim
+
+    if list(x.type.broadcastable[:batch_ndim]) != out_batch_bcast:
         x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
-        assert x.type.broadcastable == out.type.broadcastable
 
     # Massage indices so they respect blockwise semantics while using regular indexing
     core_idxs = []
diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py
index 98920e415d..e170c2c61b 100644
--- a/pytensor/tensor/rewriting/subtensor_lift.py
+++ b/pytensor/tensor/rewriting/subtensor_lift.py
@@ -567,8 +567,13 @@ def local_subtensor_of_expand_dims(fgraph, node):
 
     idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
-    # Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe
+    # Keep indexes for the original dimensions, and drop indexes for the expanded
+    # dimensions when safe. We also track where each kept expanded dimension lands
+    # in the output (`out_pos`), so it can be re-introduced with expand_dims below.
+    # These axes are derived from the indices, not the (possibly stale) output type.
     new_idxs = []
+    expand_axes = []
+    out_pos = 0
     for i, idx_item in enumerate(idx_tuple):
         if i in expanded_axes:
             if isinstance(idx_item, slice):
@@ -576,6 +581,8 @@ def local_subtensor_of_expand_dims(fgraph, node):
                 if idx_item == slice(None):
                     # A None slice, always keeps the dimension.
                     # We skip the index, and later introduce the needed expand_dim
+                    expand_axes.append(out_pos)
+                    out_pos += 1
                     continue
                 else:
                     # Other slices could keep or drop the dimension.
@@ -589,27 +596,23 @@ def local_subtensor_of_expand_dims(fgraph, node):
         else:
             # Keep indexes for non-expanded dimensions
             new_idxs.append(idx_item)
+            # An integer index drops the dimension; any slice keeps one.
+            if isinstance(idx_item, slice):
+                out_pos += 1
+
+    # Trailing dimensions beyond the explicit indices are implicit full slices;
+    # the expanded ones among them must also be re-introduced.
+    for axis in range(len(idx_tuple), ds.type.ndim):
+        if axis in expanded_axes:
+            expand_axes.append(out_pos)
+        out_pos += 1
 
     [old_out] = node.outputs
     out = x[tuple(new_idxs)]
+    if expand_axes:
+        out = expand_dims(out, axis=expand_axes)
     copy_stack_trace(old_out, out)
 
-    if out.type.broadcastable != old_out.type.broadcastable:
-        # Re-introduce needed new dimensions (corresponding to full slices on the original expanded dimensions)
-        # If out.type.broadcastable == (False) and old_out.type.broadcastable == (True, False, True)
-        # then axis = (0, 2)
-        old_bcast = list(old_out.type.broadcastable)
-        expanded_bcast = list(out.type.broadcastable)
-        axis = []
-        i = 0
-        while i < len(old_bcast):
-            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
-                expanded_bcast.insert(i, True)
-                axis.append(i)
-            i += 1
-        out = expand_dims(out, axis=axis)
-        copy_stack_trace(old_out, out)
-
     return [out]
 
 
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index deb9ead5ac..ecb671e3ce 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -9,7 +9,7 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.configdefaults import config
-from pytensor.graph import rewrite_graph, vectorize_graph
+from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
 from pytensor.graph.basic import Constant, Variable, equal_computations
 from pytensor.graph.rewriting.basic import check_stack_trace, in2out, out2in
 from pytensor.graph.traversal import ancestors
@@ -22,6 +22,7 @@
     _slice_to_arange,
     local_add_of_sparse_write,
     local_adv_idx_to_slice,
+    local_blockwise_inc_subtensor,
     local_replace_AdvancedSubtensor,
     local_useless_slice,
 )
@@ -2508,6 +2509,47 @@ def test_non_consecutive_integer_indices(self, vectorize_idx, basic_idx):
         )
         np.testing.assert_allclose(fn(test_vec_a), ref_fn(test_vec_a))
 
+    def test_stale_output_type(self):
+        """The rewrite must not depend on the Blockwise output type, which can be
+        stale after an upstream rewrite swaps an input for a more broadcastable one.
+
+        Here the batched buffer dim becomes length 1; the (stale) Blockwise output
+        type still claims it is non-broadcastable. The batch shape / broadcast
+        decision is derived from the inputs, so the lift still succeeds.
+        """
+        core_x = tensor("core_x", shape=(6, 6))
+        core_y = tensor("core_y", shape=(3,), dtype=int)
+        core_graph = core_x[-1, :3].set(core_y)
+
+        x = tensor("x", shape=(None, 6, 6))
+        x_new = tensor("x_new", shape=(1, 6, 6))
+        out = vectorize_graph(core_graph, replace={core_x: x})
+        assert isinstance(out.owner.op, Blockwise)
+
+        fgraph = FunctionGraph([x, core_y], [out], clone=True)
+        [cloned_out] = fgraph.outputs
+        cloned_x, cloned_y = fgraph.inputs
+        node = cloned_out.owner
+
+        # Forge a stale state: the batched buffer dim becomes length 1, but the
+        # Blockwise output type still claims it is non-broadcastable.
+        fgraph.replace(cloned_x, x_new, import_missing=True)
+        assert not cloned_out.type.broadcastable[0]
+        assert node.inputs[0].type.broadcastable[0]
+
+        # Before the fix this raised an AssertionError (asserting against the
+        # stale output type); now it succeeds and matches the untouched `out`.
+        [new_out] = local_blockwise_inc_subtensor.transform(fgraph, node)
+
+        rng = np.random.default_rng(2167)
+        test_x = rng.normal(size=(1, 6, 6)).astype(config.floatX)
+        test_y = rng.integers(1, 10, size=(3,))
+        ref_mode = Mode(linker="py", optimizer=None)
+        np.testing.assert_allclose(
+            new_out.eval({x_new: test_x, cloned_y: test_y}, mode=ref_mode),
+            out.eval({x: test_x, core_y: test_y}, mode=ref_mode),
+        )
+
 
 class TestUselessSlice:
     def test_positive_step(self):
diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py
index 85344f3d8b..56dd04a2b9 100644
--- a/tests/tensor/rewriting/test_subtensor_lift.py
+++ b/tests/tensor/rewriting/test_subtensor_lift.py
@@ -51,6 +51,7 @@
     _diag_indices,
     local_subtensor_make_vector,
     local_subtensor_of_batch_dims,
+    local_subtensor_of_expand_dims,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import Shape_i, SpecifyShape, _shape
@@ -488,6 +489,41 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
     )
 
 
+def test_local_subtensor_of_expand_dims_stale_output_type():
+    """The re-expansion axes must be derived from the indices, not from the
+    Subtensor's (possibly stale) output broadcastable pattern.
+
+    When an upstream rewrite turns an indexed dim from non-broadcastable to
+    length 1, the cached Subtensor output type still claims non-broadcastable.
+    Comparing the freshly indexed result against that stale type would misplace
+    the expand_dims and yield a mis-shaped (here: wrong-ndim) graph.
+    """
+    x = pt.tensor("x", shape=(None, None, None), dtype="float64")
+    x_new = pt.tensor("x_new", shape=(None, 1, None), dtype="float64")
+    out = expand_dims(x, 3)[7]
+
+    fgraph = FunctionGraph([x], [out], clone=True)
+    [cloned_out] = fgraph.outputs
+    [cloned_x] = fgraph.inputs
+    node = cloned_out.owner
+
+    # Forge a stale state: dim 1 of `x` becomes length 1, but the Subtensor's
+    # cached output type still claims it (now output dim 0) is non-broadcastable.
+    fgraph.replace(cloned_x, x_new, import_missing=True)
+    assert not cloned_out.type.broadcastable[0]
+    assert node.inputs[0].owner.inputs[0].type.broadcastable[1]
+
+    [new_out] = local_subtensor_of_expand_dims.transform(fgraph, node)
+    # `new_out` (on x_new) must match the untouched `out` (on x) for the same data,
+    # in particular keep the same ndim (the bug produced ndim 4).
+    assert new_out.type.ndim == out.type.ndim
+    x_test = np.random.default_rng(232).normal(size=(10, 1, 3))
+    np.testing.assert_allclose(
+        new_out.eval({x_new: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 @pytest.mark.parametrize(
     "original_fn, expected_fn",
     [