Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 74 additions & 54 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
from pytensor.tensor.linalg.constructors import BlockDiagonal
from pytensor.tensor.math import (
All,
Any,
Dot,
Max,
Min,
Prod,
ProdWithoutZeros,
Sum,
_conj,
_dot,
Expand Down Expand Up @@ -2119,62 +2124,77 @@ def local_reduce_broadcastable(fgraph, node):


@register_specialize
@node_rewriter([Sum, Prod])
def local_opt_alloc(fgraph, node):
"""
sum(alloc(constant,shapes...)) => constant*prod(shapes)
or
prod(alloc(constant,shapes...)) => constant**prod(shapes)
@node_rewriter([Sum, Prod, ProdWithoutZeros, Max, Min, All, Any])
def local_careduce_of_alloc(fgraph, node):
"""
Push a reduction through the alloc it reduces, factoring out the broadcast
axes::

sum(alloc(x, shapes)) => sum(x) * prod(broadcast shapes)
prod(alloc(x, shapes)) => prod(x) ** prod(broadcast shapes)
max(alloc(x, shapes)) => max(x) # and min/all/any

A reduced axis that the alloc merely broadcasts (a new dimension, or one
where ``x`` is broadcastable) repeats the same value. Reducing such an axis
is a multiplication (``Sum``) or power (``Prod``/``ProdWithoutZeros``) by the
axis size, and a no-op for idempotent reductions (``Max``/``Min``/``All``/
``Any``). Reduced axes where ``x`` holds real data are reduced on ``x``
itself. Kept axes - including broadcast ones that aren't reduced - are
re-materialized by an alloc on the output.
"""
match node.inputs[0].owner_op_and_inputs:
case (Alloc(), value, *shapes):
pass
case _:
return None

"""
(node_inps,) = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, Alloc):
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_underlying_scalar_constant_value(inp, only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
size = mul(*shapes)
if inp.dtype in ("float16", "float32"):
# shapes are ints and normally int64.
# We don't want to have a float64 upcast
# We don't want to downcast to float16
# as we fear it could loose too much precision
# that will be amplified by the mul/pow below.
ndim = len(shapes)
axis = node.op.axis
axis = tuple(range(ndim)) if axis is None else axis

# ``value`` is right-aligned against the alloc output dimensions.
offset = ndim - value.type.ndim
value_bcast = value.type.broadcastable

# Split the value's reduced dimensions: the ones it broadcasts (size-1) are
# just squeezed out (their repeat count is folded into ``size`` below), the
# rest are genuinely reduced.
squeeze_axes, reduce_axes = [], []
for a in axis:
if (j := a - offset) >= 0:
if value_bcast[j]:
squeeze_axes.append(j)
else:
reduce_axes.append(j)
if squeeze_axes:
value = value.squeeze(squeeze_axes)
# squeezing shifts each remaining axis left past the dropped ones
reduce_axes = [j - sum(s < j for s in squeeze_axes) for j in reduce_axes]
if reduce_axes:
value = node.op.clone(axis=tuple(reduce_axes))(value)

# The remaining reduced axes are pure broadcasts of ``value``. ``Sum`` turns
# them into a multiplication by their size, ``Prod``/``ProdWithoutZeros``
# into a power; idempotent reductions (max/min/all/any) are unaffected, as
# repeating a value doesn't change its maximum, minimum, or truth.
if isinstance(node.op, Sum | Prod | ProdWithoutZeros):
size_shapes = [shapes[a] for a in axis if a < offset or value_bcast[a - offset]]
if size_shapes:
size = variadic_mul(*size_shapes)
if value.dtype in ("float16", "float32"):
# Avoid a float64 upcast from the int64 shapes (or a float16
# downcast); either would be amplified by the mul/pow below.
size = size.astype("float32")
if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)):
if isinstance(node.op, Sum):
val = val * size
else:
val = val**size
# Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1
# elemwise we will do and not a sequence, so there is no
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val = val.astype(node.outputs[0].dtype)
return [val]
to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
if to_prod:
size = mul(*to_prod)
if isinstance(node.op, Sum):
val *= size
else:
val = val**size
# See comments above.
val = val.astype(node.outputs[0].dtype)
return [
alloc(
val,
*[shapes[i] for i in range(len(shapes)) if i not in node.op.axis],
)
]
except NotScalarConstantError:
pass
value = value * size if isinstance(node.op, Sum) else value**size

# The reduction may change the dtype; a single elemwise has no accumulation
# error, so ignore acc_dtype and just cast to the reduction's output dtype.
value = value.astype(node.outputs[0].dtype)

kept_shapes = [shapes[a] for a in range(ndim) if a not in axis]
out = alloc(value, *kept_shapes) if kept_shapes else value
copy_stack_trace(node.outputs[0], out)
return [out]


@register_specialize
Expand Down
188 changes: 167 additions & 21 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Dot,
Max,
Prod,
ProdWithoutZeros,
Sum,
_conj,
_matmul,
Expand Down Expand Up @@ -3139,14 +3140,19 @@ def test_prod_of_non_scalar_mul(self):
rewritten_out_fn(*test_vals),
)

def test_local_sum_prod_alloc(self):
def test_local_careduce_of_alloc(self):
a = dtensor3()
input = np.asarray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dtype="float64")
mode = self.mode.including("specialize").excluding("fusion")

for t_like, n_like, nb_nodes in [
(pt.zeros_like, np.zeros_like, (1, 3, 3, 2)),
(pt.ones_like, np.ones_like, (5, 5, 5, 6)),
# node counts for: full reduction, single-axis reduction, two chained
# single-axis reductions
(pt.zeros_like, np.zeros_like, (1, 3, 2)),
# ones_like allocs an all-broadcastable value; partial reductions
# leave a redundant ExpandDims (a generic alloc-value concern, not
# this rewrite's), hence the extra node vs zeros_like.
(pt.ones_like, np.ones_like, (5, 6, 8)),
]:
# test sum
f = function([a], t_like(a).sum(None), mode=mode)
Expand All @@ -3164,46 +3170,186 @@ def test_local_sum_prod_alloc(self):
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)
for i in range(3):
f = function([a], t_like(a).sum(i), mode=mode)
utt.assert_allclose(f(input), n_like(input).sum(i))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)

# test prod
f = function([a], t_like(a).prod(None), mode=mode)
utt.assert_allclose(f(input), n_like(input).prod())
# assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]

f = function([a], t_like(a).prod([0, 1, 2]), mode=mode)
utt.assert_allclose(f(input), n_like(input).prod())
# assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]

for d in range(3):
f = function([a], t_like(a).prod(d), mode=mode)
utt.assert_allclose(f(input), n_like(input).prod(d))
# assert len(f.maker.fgraph.apply_nodes) == nb_nodes[1]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Prod) for node in topo)
for i in range(3):
f = function([a], t_like(a).prod(i), mode=mode)
utt.assert_allclose(f(input), n_like(input).prod(i))
# assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Prod) for node in topo)

for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = function([a], t_like(a).sum(d).sum(dd), mode=mode)
utt.assert_allclose(f(input), n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)

def test_local_careduce_of_alloc_symbolic_scalar(self):
# `local_careduce_of_alloc` should also fire when the alloc'd value is a
# symbolic (non-constant) scalar, since the alloc only broadcasts it.
x = scalar("x")
mode = self.mode.excluding("fusion")
x_val = np.float64(1.5)

# Full reduction over 2 broadcast dims -> x * (3 * 4)
f = function([x], pt.broadcast_to(x, (3, 4)).sum(), mode=mode)
utt.assert_allclose(f(x_val), np.full((3, 4), x_val).sum())
topo = f.maker.fgraph.toposort()
assert not any(isinstance(node.op, Alloc | Sum) for node in topo)

# Partial reduction -> an Alloc of (x * 3) remains, but no Sum
f = function([x], pt.broadcast_to(x, (3, 4)).sum(axis=0), mode=mode)
utt.assert_allclose(f(x_val), np.full((3, 4), x_val).sum(axis=0))
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)

# Prod -> x ** (2 * 3)
f = function([x], pt.alloc(x, 2, 3).prod(), mode=mode)
utt.assert_allclose(f(x_val), np.full((2, 3), x_val).prod())
topo = f.maker.fgraph.toposort()
assert not any(isinstance(node.op, Alloc | Prod) for node in topo)

def test_local_careduce_of_alloc_broadcast_reduced_axes(self):
# `local_careduce_of_alloc` only needs the alloc to broadcast its value along the
# *reduced* axes; the value may carry real data on the kept axes (e.g. a
# batch dimension introduced by vectorizing ``alloc(x, n).sum()``).
mode = self.mode.excluding("fusion")
bx = vector("bx")
bx_val = np.array([1.0, 2.0, 3.0, 4.0])
batched = bx[:, None] # (B, 1): broadcast along the to-be-reduced axis

# Sum over the broadcast core axis -> bx * n, keeping the batch axis
f = function([bx], pt.alloc(batched, bx.shape[0], 5).sum(axis=1), mode=mode)
utt.assert_allclose(f(bx_val), bx_val * 5)
topo = f.maker.fgraph.toposort()
assert not any(isinstance(node.op, Alloc | Sum) for node in topo)

# Prod likewise -> bx ** n
f = function([bx], pt.alloc(batched, bx.shape[0], 3).prod(axis=1), mode=mode)
utt.assert_allclose(f(bx_val), bx_val**3)
topo = f.maker.fgraph.toposort()
assert not any(isinstance(node.op, Alloc | Prod) for node in topo)

# Reducing only a real-data axis (no broadcast axis reduced) pushes the
# reduction *through* the alloc: reduce the value first, then broadcast.
m = matrix("m")
m_val = np.arange(6.0).reshape(2, 3)
f = function([m], pt.alloc(m, 2, m.shape[0], m.shape[1]).sum(axis=1), mode=mode)
utt.assert_allclose(f(m_val), np.broadcast_to(m_val, (2, 2, 3)).sum(axis=1))
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
# The Sum now reduces ``m`` directly, not the larger alloc output.
assert not any(
isinstance(node.op, Sum)
and any(
inp.owner and isinstance(inp.owner.op, Alloc) for inp in node.inputs
)
for node in topo
)

def test_local_careduce_of_alloc_mixed_axes(self):
# When a reduction spans both broadcast and real-data axes of the alloc,
# the broadcast axes are factored out as a multiplier while the real ones
# are reduced on the (smaller) value, e.g.
# ``sum(ones((5, 3)) * x)`` -> ``sum(x) * 5``.
mode = self.mode.excluding("fusion")

x = vector("x")
x_val = np.array([1.0, 2.0, 3.0])
f = function([x], pt.sum(pt.ones((5, 3)) * x), mode=mode)
utt.assert_allclose(f(x_val), np.broadcast_to(x_val, (5, 3)).sum())
topo = f.maker.fgraph.toposort()
# The alloc is gone; the inner sum over ``x``'s real axis remains.
assert not any(isinstance(node.op, Alloc) for node in topo)
assert any(isinstance(node.op, Sum) for node in topo)

# Reduce a new (broadcast) axis and a real axis, keeping another real one.
m = matrix("m")
m_val = np.arange(6.0).reshape(2, 3)
f = function([m], pt.alloc(m, 4, 2, 3).sum(axis=(0, 1)), mode=mode)
utt.assert_allclose(
f(m_val), np.broadcast_to(m_val, (4, 2, 3)).sum(axis=(0, 1))
)
assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort())

# Prod variant: broadcast axes become a power of the reduced value.
f = function([m], pt.alloc(m + 1, 4, 2, 3).prod(axis=(0, 1)), mode=mode)
utt.assert_allclose(
f(m_val), np.broadcast_to(m_val + 1, (4, 2, 3)).prod(axis=(0, 1))
)
assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort())

# A broadcast dimension that is *kept* (not in the value, not reduced)
# must still be materialized by an alloc on the rewrite's output.
v = vector("v")
v_val = np.array([1.0, 2.0, 3.0])
f = function([v], pt.alloc(v, 4, 6, 3).sum(axis=1), mode=mode)
utt.assert_allclose(f(v_val), np.broadcast_to(v_val, (4, 6, 3)).sum(axis=1))
topo = f.maker.fgraph.toposort()
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)

@pytest.mark.parametrize(
"pt_reduce, np_reduce",
[
(pt_max, np.max),
(pt_min, np.min),
(pt_all, np.all),
(pt_any, np.any),
],
)
def test_local_careduce_of_alloc_idempotent(self, pt_reduce, np_reduce):
# Idempotent reductions (max/min/all/any) over the alloc's broadcast axes
# just drop those axes, with no size factor: repeating a value doesn't
# change its max, min, or truth.
mode = self.mode.excluding("fusion")
m = matrix("m")
m_val = np.arange(6.0).reshape(2, 3)

# Full reduction, and a mixed reduction over a broadcast + a real axis.
for axis in (None, (0, 1)):
f = function([m], pt_reduce(pt.alloc(m, 4, 2, 3), axis=axis), mode=mode)
utt.assert_allclose(
f(m_val), np_reduce(np.broadcast_to(m_val, (4, 2, 3)), axis=axis)
)
assert not any(
isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()
)

def test_local_careduce_of_alloc_prod_without_zeros(self):
# ProdWithoutZeros follows Prod's value**size rule for a broadcast axis
# (a repeated zero stays zero). The input has a zero so its zero handling
# is exercised too.
mode = self.mode.excluding("fusion")
m = matrix("m")
m_val = np.array([[2.0, 0.0, 4.0], [1.0, 3.0, 5.0]])
bcast = np.broadcast_to(m_val, (4, 2, 3))

def np_prod_without_zeros(a, axis):
nonzeros = np.where(a == 0, 1.0, a).prod(axis=axis)
return np.where((a == 0).all(axis=axis), 0.0, nonzeros)

# Full reduction and a mixed reduction over a broadcast + a real axis.
for axis in (None, (0, 1)):
f = function(
[m], ProdWithoutZeros(axis=axis)(pt.alloc(m, 4, 2, 3)), mode=mode
)
np_axis = (0, 1, 2) if axis is None else axis
utt.assert_allclose(f(m_val), np_prod_without_zeros(bcast, np_axis))
assert not any(
isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()
)

def test_local_sum_prod_mul_by_scalar_stack_trace(self):
"""Test that stack trace is copied over correctly for `local_sum_prod_mul_by_scalar`."""
m0 = (
Expand Down
Loading