From 6ccda782ef7f5fc5f3f2dc94b2f080b9c03919ab Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jun 2026 22:25:39 +0200 Subject: [PATCH] Generalize local_sum__prod_alloc Applies to non-constant values and other CAReduce types --- pytensor/tensor/rewriting/math.py | 128 +++++++++++-------- tests/tensor/rewriting/test_math.py | 188 ++++++++++++++++++++++++---- 2 files changed, 241 insertions(+), 75 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 1c1b3ed29c..148922f07f 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -46,8 +46,13 @@ from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast from pytensor.tensor.linalg.constructors import BlockDiagonal from pytensor.tensor.math import ( + All, + Any, Dot, + Max, + Min, Prod, + ProdWithoutZeros, Sum, _conj, _dot, @@ -2119,62 +2124,77 @@ def local_reduce_broadcastable(fgraph, node): @register_specialize -@node_rewriter([Sum, Prod]) -def local_opt_alloc(fgraph, node): - """ - sum(alloc(constant,shapes...)) => constant*prod(shapes) - or - prod(alloc(constant,shapes...)) => constant**prod(shapes) +@node_rewriter([Sum, Prod, ProdWithoutZeros, Max, Min, All, Any]) +def local_careduce_of_alloc(fgraph, node): + """ + Push a reduction through the alloc it reduces, factoring out the broadcast + axes:: + + sum(alloc(x, shapes)) => sum(x) * prod(broadcast shapes) + prod(alloc(x, shapes)) => prod(x) ** prod(broadcast shapes) + max(alloc(x, shapes)) => max(x) # and min/all/any + + A reduced axis that the alloc merely broadcasts (a new dimension, or one + where ``x`` is broadcastable) repeats the same value. Reducing such an axis + is a multiplication (``Sum``) or power (``Prod``/``ProdWithoutZeros``) by the + axis size, and a no-op for idempotent reductions (``Max``/``Min``/``All``/ + ``Any``). Reduced axes where ``x`` holds real data are reduced on ``x`` + itself. Kept axes - including broadcast ones that aren't reduced - are + re-materialized by an alloc on the output. + """ + match node.inputs[0].owner_op_and_inputs: + case (Alloc(), value, *shapes): + pass + case _: + return None - """ - (node_inps,) = node.inputs - if node_inps.owner and isinstance(node_inps.owner.op, Alloc): - inp = node_inps.owner.inputs[0] - shapes = node_inps.owner.inputs[1:] - try: - val = get_underlying_scalar_constant_value(inp, only_process_constants=True) - assert val.size == 1 - val = val.reshape(1)[0] - # check which type of op - size = mul(*shapes) - if inp.dtype in ("float16", "float32"): - # shapes are ints and normally int64. - # We don't want to have a float64 upcast - # We don't want to downcast to float16 - # as we fear it could loose too much precision - # that will be amplified by the mul/pow below. + ndim = len(shapes) + axis = node.op.axis + axis = tuple(range(ndim)) if axis is None else axis + + # ``value`` is right-aligned against the alloc output dimensions. + offset = ndim - value.type.ndim + value_bcast = value.type.broadcastable + + # Split the value's reduced dimensions: the ones it broadcasts (size-1) are + # just squeezed out (their repeat count is folded into ``size`` below), the + # rest are genuinely reduced. + squeeze_axes, reduce_axes = [], [] + for a in axis: + if (j := a - offset) >= 0: + if value_bcast[j]: + squeeze_axes.append(j) + else: + reduce_axes.append(j) + if squeeze_axes: + value = value.squeeze(squeeze_axes) + # squeezing shifts each remaining axis left past the dropped ones + reduce_axes = [j - sum(s < j for s in squeeze_axes) for j in reduce_axes] + if reduce_axes: + value = node.op.clone(axis=tuple(reduce_axes))(value) + + # The remaining reduced axes are pure broadcasts of ``value``. ``Sum`` turns + # them into a multiplication by their size, ``Prod``/``ProdWithoutZeros`` + # into a power; idempotent reductions (max/min/all/any) are unaffected, as + # repeating a value doesn't change its maximum, minimum, or truth. + if isinstance(node.op, Sum | Prod | ProdWithoutZeros): + size_shapes = [shapes[a] for a in axis if a < offset or value_bcast[a - offset]] + if size_shapes: + size = variadic_mul(*size_shapes) + if value.dtype in ("float16", "float32"): + # Avoid a float64 upcast from the int64 shapes (or a float16 + # downcast); either would be amplified by the mul/pow below. size = size.astype("float32") - if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)): - if isinstance(node.op, Sum): - val = val * size - else: - val = val**size - # Sum can change the input dtype (upcast or bool - # -> float32) by default or by user request. - # We can ignore the acc_dtype, as there is only 1 - # elemwise we will do and not a sequence, so there is no - # accumulation of errors. - # So mostly, we just need to cast the output to the old - # dtype. - val = val.astype(node.outputs[0].dtype) - return [val] - to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis] - if to_prod: - size = mul(*to_prod) - if isinstance(node.op, Sum): - val *= size - else: - val = val**size - # See comments above. - val = val.astype(node.outputs[0].dtype) - return [ - alloc( - val, - *[shapes[i] for i in range(len(shapes)) if i not in node.op.axis], - ) - ] - except NotScalarConstantError: - pass + value = value * size if isinstance(node.op, Sum) else value**size + + # The reduction may change the dtype; a single elemwise has no accumulation + # error, so ignore acc_dtype and just cast to the reduction's output dtype. + value = value.astype(node.outputs[0].dtype) + + kept_shapes = [shapes[a] for a in range(ndim) if a not in axis] + out = alloc(value, *kept_shapes) if kept_shapes else value + copy_stack_trace(node.outputs[0], out) + return [out] @register_specialize diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 7042ddf598..8708c55211 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -41,6 +41,7 @@ Dot, Max, Prod, + ProdWithoutZeros, Sum, _conj, _matmul, @@ -3139,14 +3140,19 @@ def test_prod_of_non_scalar_mul(self): rewritten_out_fn(*test_vals), ) - def test_local_sum_prod_alloc(self): + def test_local_careduce_of_alloc(self): a = dtensor3() input = np.asarray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dtype="float64") mode = self.mode.including("specialize").excluding("fusion") for t_like, n_like, nb_nodes in [ - (pt.zeros_like, np.zeros_like, (1, 3, 3, 2)), - (pt.ones_like, np.ones_like, (5, 5, 5, 6)), + # node counts for: full reduction, single-axis reduction, two chained + # single-axis reductions + (pt.zeros_like, np.zeros_like, (1, 3, 2)), + # ones_like allocs an all-broadcastable value; partial reductions + # leave a redundant ExpandDims (a generic alloc-value concern, not + # this rewrite's), hence the extra node vs zeros_like. + (pt.ones_like, np.ones_like, (5, 6, 8)), ]: # test sum f = function([a], t_like(a).sum(None), mode=mode) @@ -3164,34 +3170,17 @@ def test_local_sum_prod_alloc(self): topo = f.maker.fgraph.toposort() assert topo[-1].op == pt.alloc assert not any(isinstance(node.op, Sum) for node in topo) - for i in range(3): - f = function([a], t_like(a).sum(i), mode=mode) - utt.assert_allclose(f(input), n_like(input).sum(i)) - assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2] - topo = f.maker.fgraph.toposort() - assert topo[-1].op == pt.alloc - assert not any(isinstance(node.op, Sum) for node in topo) # test prod f = function([a], t_like(a).prod(None), mode=mode) utt.assert_allclose(f(input), n_like(input).prod()) - # assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0] f = function([a], t_like(a).prod([0, 1, 2]), mode=mode) utt.assert_allclose(f(input), n_like(input).prod()) - # assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0] for d in range(3): f = function([a], t_like(a).prod(d), mode=mode) utt.assert_allclose(f(input), n_like(input).prod(d)) - # assert len(f.maker.fgraph.apply_nodes) == nb_nodes[1] - topo = f.maker.fgraph.toposort() - assert topo[-1].op == pt.alloc - assert not any(isinstance(node.op, Prod) for node in topo) - for i in range(3): - f = function([a], t_like(a).prod(i), mode=mode) - utt.assert_allclose(f(input), n_like(input).prod(i)) - # assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2] topo = f.maker.fgraph.toposort() assert topo[-1].op == pt.alloc assert not any(isinstance(node.op, Prod) for node in topo) @@ -3199,11 +3188,168 @@ def test_local_sum_prod_alloc(self): for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]: f = function([a], t_like(a).sum(d).sum(dd), mode=mode) utt.assert_allclose(f(input), n_like(input).sum(d).sum(dd)) - assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3] + assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2] topo = f.maker.fgraph.toposort() assert topo[-1].op == pt.alloc assert not any(isinstance(node.op, Sum) for node in topo) + def test_local_careduce_of_alloc_symbolic_scalar(self): + # `local_careduce_of_alloc` should also fire when the alloc'd value is a + # symbolic (non-constant) scalar, since the alloc only broadcasts it. + x = scalar("x") + mode = self.mode.excluding("fusion") + x_val = np.float64(1.5) + + # Full reduction over 2 broadcast dims -> x * (3 * 4) + f = function([x], pt.broadcast_to(x, (3, 4)).sum(), mode=mode) + utt.assert_allclose(f(x_val), np.full((3, 4), x_val).sum()) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(node.op, Alloc | Sum) for node in topo) + + # Partial reduction -> an Alloc of (x * 3) remains, but no Sum + f = function([x], pt.broadcast_to(x, (3, 4)).sum(axis=0), mode=mode) + utt.assert_allclose(f(x_val), np.full((3, 4), x_val).sum(axis=0)) + topo = f.maker.fgraph.toposort() + assert topo[-1].op == pt.alloc + assert not any(isinstance(node.op, Sum) for node in topo) + + # Prod -> x ** (2 * 3) + f = function([x], pt.alloc(x, 2, 3).prod(), mode=mode) + utt.assert_allclose(f(x_val), np.full((2, 3), x_val).prod()) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(node.op, Alloc | Prod) for node in topo) + + def test_local_careduce_of_alloc_broadcast_reduced_axes(self): + # `local_careduce_of_alloc` only needs the alloc to broadcast its value along the + # *reduced* axes; the value may carry real data on the kept axes (e.g. a + # batch dimension introduced by vectorizing ``alloc(x, n).sum()``). + mode = self.mode.excluding("fusion") + bx = vector("bx") + bx_val = np.array([1.0, 2.0, 3.0, 4.0]) + batched = bx[:, None] # (B, 1): broadcast along the to-be-reduced axis + + # Sum over the broadcast core axis -> bx * n, keeping the batch axis + f = function([bx], pt.alloc(batched, bx.shape[0], 5).sum(axis=1), mode=mode) + utt.assert_allclose(f(bx_val), bx_val * 5) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(node.op, Alloc | Sum) for node in topo) + + # Prod likewise -> bx ** n + f = function([bx], pt.alloc(batched, bx.shape[0], 3).prod(axis=1), mode=mode) + utt.assert_allclose(f(bx_val), bx_val**3) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(node.op, Alloc | Prod) for node in topo) + + # Reducing only a real-data axis (no broadcast axis reduced) pushes the + # reduction *through* the alloc: reduce the value first, then broadcast. + m = matrix("m") + m_val = np.arange(6.0).reshape(2, 3) + f = function([m], pt.alloc(m, 2, m.shape[0], m.shape[1]).sum(axis=1), mode=mode) + utt.assert_allclose(f(m_val), np.broadcast_to(m_val, (2, 2, 3)).sum(axis=1)) + topo = f.maker.fgraph.toposort() + assert topo[-1].op == pt.alloc + # The Sum now reduces ``m`` directly, not the larger alloc output. + assert not any( + isinstance(node.op, Sum) + and any( + inp.owner and isinstance(inp.owner.op, Alloc) for inp in node.inputs + ) + for node in topo + ) + + def test_local_careduce_of_alloc_mixed_axes(self): + # When a reduction spans both broadcast and real-data axes of the alloc, + # the broadcast axes are factored out as a multiplier while the real ones + # are reduced on the (smaller) value, e.g. + # ``sum(ones((5, 3)) * x)`` -> ``sum(x) * 5``. + mode = self.mode.excluding("fusion") + + x = vector("x") + x_val = np.array([1.0, 2.0, 3.0]) + f = function([x], pt.sum(pt.ones((5, 3)) * x), mode=mode) + utt.assert_allclose(f(x_val), np.broadcast_to(x_val, (5, 3)).sum()) + topo = f.maker.fgraph.toposort() + # The alloc is gone; the inner sum over ``x``'s real axis remains. + assert not any(isinstance(node.op, Alloc) for node in topo) + assert any(isinstance(node.op, Sum) for node in topo) + + # Reduce a new (broadcast) axis and a real axis, keeping another real one. + m = matrix("m") + m_val = np.arange(6.0).reshape(2, 3) + f = function([m], pt.alloc(m, 4, 2, 3).sum(axis=(0, 1)), mode=mode) + utt.assert_allclose( + f(m_val), np.broadcast_to(m_val, (4, 2, 3)).sum(axis=(0, 1)) + ) + assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) + + # Prod variant: broadcast axes become a power of the reduced value. + f = function([m], pt.alloc(m + 1, 4, 2, 3).prod(axis=(0, 1)), mode=mode) + utt.assert_allclose( + f(m_val), np.broadcast_to(m_val + 1, (4, 2, 3)).prod(axis=(0, 1)) + ) + assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) + + # A broadcast dimension that is *kept* (not in the value, not reduced) + # must still be materialized by an alloc on the rewrite's output. + v = vector("v") + v_val = np.array([1.0, 2.0, 3.0]) + f = function([v], pt.alloc(v, 4, 6, 3).sum(axis=1), mode=mode) + utt.assert_allclose(f(v_val), np.broadcast_to(v_val, (4, 6, 3)).sum(axis=1)) + topo = f.maker.fgraph.toposort() + assert topo[-1].op == pt.alloc + assert not any(isinstance(node.op, Sum) for node in topo) + + @pytest.mark.parametrize( + "pt_reduce, np_reduce", + [ + (pt_max, np.max), + (pt_min, np.min), + (pt_all, np.all), + (pt_any, np.any), + ], + ) + def test_local_careduce_of_alloc_idempotent(self, pt_reduce, np_reduce): + # Idempotent reductions (max/min/all/any) over the alloc's broadcast axes + # just drop those axes, with no size factor: repeating a value doesn't + # change its max, min, or truth. + mode = self.mode.excluding("fusion") + m = matrix("m") + m_val = np.arange(6.0).reshape(2, 3) + + # Full reduction, and a mixed reduction over a broadcast + a real axis. + for axis in (None, (0, 1)): + f = function([m], pt_reduce(pt.alloc(m, 4, 2, 3), axis=axis), mode=mode) + utt.assert_allclose( + f(m_val), np_reduce(np.broadcast_to(m_val, (4, 2, 3)), axis=axis) + ) + assert not any( + isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort() + ) + + def test_local_careduce_of_alloc_prod_without_zeros(self): + # ProdWithoutZeros follows Prod's value**size rule for a broadcast axis + # (a repeated zero stays zero). The input has a zero so its zero handling + # is exercised too. + mode = self.mode.excluding("fusion") + m = matrix("m") + m_val = np.array([[2.0, 0.0, 4.0], [1.0, 3.0, 5.0]]) + bcast = np.broadcast_to(m_val, (4, 2, 3)) + + def np_prod_without_zeros(a, axis): + nonzeros = np.where(a == 0, 1.0, a).prod(axis=axis) + return np.where((a == 0).all(axis=axis), 0.0, nonzeros) + + # Full reduction and a mixed reduction over a broadcast + a real axis. + for axis in (None, (0, 1)): + f = function( + [m], ProdWithoutZeros(axis=axis)(pt.alloc(m, 4, 2, 3)), mode=mode + ) + np_axis = (0, 1, 2) if axis is None else axis + utt.assert_allclose(f(m_val), np_prod_without_zeros(bcast, np_axis)) + assert not any( + isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort() + ) + def test_local_sum_prod_mul_by_scalar_stack_trace(self): """Test that stack trace is copied over correctly for `local_sum_prod_mul_by_scalar`.""" m0 = (