From bc54120fec4ebe17f262baba53ebfc7919c2964b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jun 2026 17:29:51 +0200 Subject: [PATCH 1/2] Centralize cached constant predicates in constant_props module Move the tag-cached, constant-data predicates into a new dependency-free leaf module `pytensor.tensor.constant_props`: - `constant_indices_are_unique` and `constant_is_arange` (from `rewriting/subtensor.py`) - `constant_is_all_negative` (new; replaces the inline all-negative scan in `local_sign_div`) These only inspect constant data, so the module imports nothing beyond numpy and `Constant`. The recursive, graph-walking sign analysis `is_provably_positive` / `is_provably_non_negative` stays in `subtensor` with the ops it recurses over, and is renamed to drop the leading underscore now that it is imported across modules. --- pytensor/assumptions/elemwise.py | 4 +- pytensor/assumptions/positive_definite.py | 6 +- pytensor/tensor/constant_props.py | 77 +++++++++++++++++++ pytensor/tensor/rewriting/math.py | 13 ++-- pytensor/tensor/rewriting/subtensor.py | 82 ++++----------------- pytensor/tensor/rewriting/subtensor_lift.py | 4 +- pytensor/tensor/subtensor.py | 25 +++---- tests/tensor/test_subtensor.py | 14 ++-- 8 files changed, 124 insertions(+), 101 deletions(-) create mode 100644 pytensor/tensor/constant_props.py diff --git a/pytensor/assumptions/elemwise.py b/pytensor/assumptions/elemwise.py index b2851223fc..a1a3f696e3 100644 --- a/pytensor/assumptions/elemwise.py +++ b/pytensor/assumptions/elemwise.py @@ -11,7 +11,7 @@ NotScalarConstantError, get_underlying_scalar_constant_value, ) -from pytensor.tensor.subtensor import _is_provably_positive +from pytensor.tensor.subtensor import is_provably_positive def elemwise_preserves_zero_pattern( @@ -76,7 +76,7 @@ def _not_singleton_matrix(var) -> bool: return [FactState.UNKNOWN] # 0 ** p == 0 for p > 0, so a provably-positive exponent (scalar or # elementwise matrix) preserves the base's zero pattern. - return true_if(_is_provably_positive(node.inputs[1])) + return true_if(is_provably_positive(node.inputs[1])) if isinstance(scalar_op, UnaryScalarOp) and scalar_op.preserves_zero: return true_if(input_states[0] is FactState.TRUE) diff --git a/pytensor/assumptions/positive_definite.py b/pytensor/assumptions/positive_definite.py index b021883e19..e840c658bf 100644 --- a/pytensor/assumptions/positive_definite.py +++ b/pytensor/assumptions/positive_definite.py @@ -20,7 +20,7 @@ SolveBilinearDiscreteLyapunov, ) from pytensor.tensor.math import Dot -from pytensor.tensor.subtensor import Subtensor, _is_provably_positive +from pytensor.tensor.subtensor import Subtensor, is_provably_positive register_assumption(POSITIVE_DEFINITE, Eye)(eye_identity_rule) @@ -34,7 +34,7 @@ def _alloc_diag(key, op, feature, fgraph, node, input_states): return [FactState.FALSE] [diag_values] = node.inputs - return true_if(_is_provably_positive(diag_values)) + return true_if(is_provably_positive(diag_values)) register_assumption(POSITIVE_DEFINITE, BlockDiagonal)(all_inputs_have_key) @@ -74,7 +74,7 @@ def _elemwise(key, op, feature, fgraph, node, input_states): # Scaling a PD matrix by a positive scalar keeps it PD. The factor # must be constant across the matrix axes (both broadcastable); # per-batch variation is fine since every batch slice stays PD. - if not (all(inp.type.broadcastable[-2:]) and _is_provably_positive(inp)): + if not (all(inp.type.broadcastable[-2:]) and is_provably_positive(inp)): continue other_inputs = [node.inputs[j] for j in range(len(node.inputs)) if j != i] if other_inputs and all( diff --git a/pytensor/tensor/constant_props.py b/pytensor/tensor/constant_props.py new file mode 100644 index 0000000000..adbbffa7b4 --- /dev/null +++ b/pytensor/tensor/constant_props.py @@ -0,0 +1,77 @@ +"""Cached static predicates over the *data* of constant variables. + +Each helper memoizes its result on the variable's ``tag`` so repeated rewrite +passes do not re-scan the same constant. These only inspect constant data, so +this stays a dependency-free leaf module (graph-walking sign analysis such as +``is_provably_positive`` lives in ``subtensor`` with the ops it recurses over). +""" + +import numpy as np + +from pytensor.graph.basic import Constant + + +def constant_is_all_negative(var) -> bool: + """Whether ``var`` is a constant whose entries are all negative, cached on its tag.""" + if not isinstance(var, Constant): + return False + cached: bool | None = getattr(var.tag, "all_negative", None) + if cached is not None: + return cached + result = bool(np.all(np.asarray(var.data) < 0)) + var.tag.all_negative = result + return result + + +def constant_indices_are_unique(idx) -> bool: + """Check whether a constant index has no duplicate entries. + + Boolean indices, scalars, and single-element arrays are trivially unique. + For larger integer arrays, indices that mix positive and negative values + may alias, so those are treated as potentially duplicated. The result + is cached on ``idx.tag``. + """ + if not isinstance(idx, Constant): + return False + cached = getattr(idx.tag, "unique_indices", None) + if cached is not None: + return bool(cached) + idx_val = np.asarray(idx.data) + if idx_val.dtype == bool: + result = True + elif idx_val.size <= 1: + result = True + else: + has_pos = (idx_val >= 0).any() + has_neg = (idx_val < 0).any() + result = not (has_pos and has_neg) and np.unique(idx_val).size == idx_val.size + idx.tag.unique_indices = result + return result + + +def constant_is_arange(idx) -> tuple[int, int, int] | None: + """Match ``idx`` to ``np.arange(offset, offset + d * step, step)`` + and return ``(d, offset, step)``, else ``None``. + + Single-element constants return ``(1, value, 1)``. The result is cached + on ``idx.tag.is_arange`` (``False`` sentinels a no-match). + """ + if not isinstance(idx, Constant): + return None + cached = getattr(idx.tag, "is_arange", None) + if cached is not None: + return cached if cached is not False else None + idx_val = np.asarray(idx.data) + if idx_val.ndim != 1 or idx_val.size == 0 or idx_val.dtype.kind not in "iu": + result: tuple[int, int, int] | None = None + elif idx_val.size == 1: + result = (1, int(idx_val[0]), 1) + else: + diffs = np.diff(idx_val) + step = int(diffs[0]) + if step != 0 and np.all(diffs == step): + result = (int(idx_val.size), int(idx_val[0]), step) + else: + result = None + idx.tag.is_arange = result if result is not None else False + return result diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 1c1b3ed29c..4982ed1490 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -41,6 +41,7 @@ zeros, zeros_like, ) +from pytensor.tensor.constant_props import constant_is_all_negative from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast @@ -112,7 +113,7 @@ from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import Shape, Shape_i, specify_shape -from pytensor.tensor.subtensor import Subtensor, _is_provably_positive +from pytensor.tensor.subtensor import Subtensor, is_provably_positive from pytensor.tensor.type import ( complex_dtypes, uint_dtypes, @@ -709,8 +710,8 @@ def local_log_div(fgraph, node): if isinstance(scalar_op, ps.TrueDiv): num, den = inp.owner.inputs - if (isinstance(num, Constant) and _is_provably_positive(num, strict=True)) or ( - isinstance(den, Constant) and _is_provably_positive(den, strict=True) + if (isinstance(num, Constant) and is_provably_positive(num, strict=True)) or ( + isinstance(den, Constant) and is_provably_positive(den, strict=True) ): return [log(num) - log(den)] @@ -739,13 +740,13 @@ def local_sign_div(fgraph, node): num, den = inp.owner.inputs - if _is_provably_positive(num, strict=True): + if is_provably_positive(num, strict=True): return [sign(den)] - if _is_provably_positive(den, strict=True): + if is_provably_positive(den, strict=True): return [sign(num)] for side, other in ((num, den), (den, num)): - if isinstance(side, Constant) and np.all(np.asarray(side.data) < 0): + if constant_is_all_negative(side): return [neg(sign(other))] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 6c07ddb807..a0dbdc9f77 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -37,6 +37,10 @@ ) from pytensor.tensor.basic import constant as tensor_constant from pytensor.tensor.blockwise import _squeeze_left +from pytensor.tensor.constant_props import ( + constant_indices_are_unique, + constant_is_arange, +) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_to, squeeze @@ -74,7 +78,6 @@ AdvancedSubtensor1, IncSubtensor, Subtensor, - _is_provably_non_negative, _non_consecutive_adv_indexing, advanced_inc_subtensor1, advanced_subtensor1, @@ -85,6 +88,7 @@ get_slice_elements, inc_subtensor, indices_from_subtensor, + is_provably_non_negative, unflatten_index_variables, ) from pytensor.tensor.type import TensorType @@ -205,60 +209,6 @@ def get_advsubtensor_axis(indices): return axis -def _constant_has_unique_indices(idx) -> bool: - """Check whether a constant index has no duplicate entries. - - Boolean indices, scalars, and single-element arrays are trivially unique. - For larger integer arrays, indices that mix positive and negative values - may alias, so those are treated as potentially duplicated. The result - is cached on ``idx.tag``. - """ - if not isinstance(idx, Constant): - return False - cached = getattr(idx.tag, "unique_indices", None) - if cached is not None: - return bool(cached) - idx_val = np.asarray(idx.data) - if idx_val.dtype == bool: - result = True - elif idx_val.size <= 1: - result = True - else: - has_pos = (idx_val >= 0).any() - has_neg = (idx_val < 0).any() - result = not (has_pos and has_neg) and np.unique(idx_val).size == idx_val.size - idx.tag.unique_indices = result - return result - - -def _constant_is_arange(idx) -> tuple[int, int, int] | None: - """Match ``idx`` to ``np.arange(offset, offset + d * step, step)`` - and return ``(d, offset, step)``, else ``None``. - - Single-element constants return ``(1, value, 1)``. The result is cached - on ``idx.tag.is_arange`` (``False`` sentinels a no-match). - """ - if not isinstance(idx, Constant): - return None - cached = getattr(idx.tag, "is_arange", None) - if cached is not None: - return cached if cached is not False else None - idx_val = np.asarray(idx.data) - if idx_val.ndim != 1 or idx_val.size == 0 or idx_val.dtype.kind not in "iu": - result: tuple[int, int, int] | None = None - elif idx_val.size == 1: - result = (1, int(idx_val[0]), 1) - else: - diffs = np.diff(idx_val) - step = int(diffs[0]) - if step != 0 and np.all(diffs == step): - result = (int(idx_val.size), int(idx_val[0]), step) - else: - result = None - idx.tag.is_arange = result if result is not None else False - return result - - def _match_arange_0_to_d_plus_offset(idx): """Match ``arange(0, d, 1) + offset`` and return ``(arange_node, offset)`` where ``arange_node`` is the ``arange(0, d, 1)`` output and ``offset`` is @@ -774,7 +724,7 @@ def _merge_scalar_into_slice_unsafe(inner_slice, scalar_index, dim, xshape): def _eager_lt_0(x): """Return ``True``/``False`` (Python bool) when the sign of *x* is known, otherwise return the ``lt(x, 0)`` graph node.""" - if _is_provably_non_negative(x): + if is_provably_non_negative(x): return False if isinstance(x, Constant): return int(x.data) < 0 @@ -792,9 +742,9 @@ def _eager_switch(cond, a, b): def _eager_minimum(a, b): if a is b: return a - if _eager_lt_0(a) is True and _is_provably_non_negative(b): + if _eager_lt_0(a) is True and is_provably_non_negative(b): return a - if _eager_lt_0(b) is True and _is_provably_non_negative(a): + if _eager_lt_0(b) is True and is_provably_non_negative(a): return b return minimum(a, b) @@ -1169,7 +1119,7 @@ def local_add_of_sparse_write(fgraph, node): # duplicate-free. Basic (slice/scalar) indexing is always unique; # advanced integer-array indices must be checked. if not inner_op.set_instead_of_inc and not isinstance(inner_op, IncSubtensor): - if not all(_constant_has_unique_indices(idx) for idx in idx_vars): + if not all(constant_indices_are_unique(idx) for idx in idx_vars): continue others = [node.inputs[j] for j in range(len(node.inputs)) if j != i] @@ -1368,7 +1318,7 @@ def _arange_index_to_slice(idx): if not isinstance(idx, TensorVariable) or idx.type.ndim != 1: return None - const_match = _constant_is_arange(idx) + const_match = constant_is_arange(idx) if const_match is not None: d, offset, step = const_match if offset < 0 or offset + (d - 1) * step < 0: @@ -1391,9 +1341,9 @@ def _arange_index_to_slice(idx): if isinstance(arange_stop, TensorVariable) and arange_stop.type.dtype != "int64": arange_stop = arange_stop.astype("int64") offset = _eager_scalar(offset) - if not _is_provably_non_negative(offset): + if not is_provably_non_negative(offset): return None - if not _is_provably_non_negative(arange_stop): + if not is_provably_non_negative(arange_stop): return None stop = eager_add_zero(arange_stop, offset) return slice(offset, stop) @@ -1426,7 +1376,7 @@ def local_adv_idx_to_diagonal(fgraph, node): # Match both indices as arange(d) + offset (const or symbolic). # Both must be the same kind (both const or both symbolic). def _match_arange(idx): - const = _constant_is_arange(idx) + const = constant_is_arange(idx) if const is not None and const[2] == 1: return "const", const[0], const[1] sym = _match_arange_0_to_d_plus_offset(idx) @@ -2001,7 +1951,7 @@ def local_read_of_write_same_indices(fgraph, node): indices = indices_from_subtensor(outer_idx_vars, node.op.idx_list) for idx in indices: if isinstance(idx, TensorVariable) and idx.type.ndim > 0: - if not _constant_has_unique_indices(idx): + if not constant_indices_are_unique(idx): return None x_at_idx = x[tuple(indices)] @@ -2043,7 +1993,7 @@ def _slice_to_arange(sl, dim_length): return None if sl.stop is None: return arange(dim_length) - if not _is_provably_non_negative(sl.stop): + if not is_provably_non_negative(sl.stop): return None return arange(minimum(sl.stop, dim_length)) @@ -2363,7 +2313,7 @@ def local_write_of_write_same_indices(fgraph, node): # sufficient: it guarantees no duplicates in the joint cross-product # after broadcasting. if not isinstance(node.op, IncSubtensor): - if not all(_constant_has_unique_indices(v) for v in outer_idx_vars): + if not all(constant_indices_are_unique(v) for v in outer_idx_vars): return new_val = a + b if ( diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 98920e415d..d99ff2c575 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -35,6 +35,7 @@ register_infer_shape, ) from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.constant_props import constant_indices_are_unique from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import squeeze @@ -48,7 +49,6 @@ ) from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.subtensor import ( - _constant_has_unique_indices, local_adv_idx_to_diagonal, local_adv_idx_to_slice, local_advanced_read_of_write_constant_indices, @@ -223,7 +223,7 @@ def _index_provably_smaller(idx, val_static_dim) -> bool: return True if idx.type.dtype == "bool": return True - if _constant_has_unique_indices(idx): + if constant_indices_are_unique(idx): return True if isinstance(idx.owner_op, ARange): return True diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 403331cef6..3e9e1ccc2f 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -239,7 +239,7 @@ def as_index_literal( raise NotScalarConstantError() -def _is_provably_positive(var, strict: bool = True) -> bool: +def is_provably_positive(var, strict: bool = True) -> bool: """``True`` when ``var`` can be statically shown to be positive. With ``strict=True`` this proves :math:`var > 0`; with ``strict=False`` it @@ -289,26 +289,23 @@ def _is_provably_positive(var, strict: bool = True) -> bool: if not strict and isinstance(op, Shape | Shape_i): return True if isinstance(op, MakeVector): - return all(_is_provably_positive(i, strict) for i in var.owner.inputs) + return all(is_provably_positive(i, strict) for i in var.owner.inputs) if isinstance(op, Subtensor | ScalarFromTensor | TensorFromScalar | DimShuffle): - return _is_provably_positive(var.owner.inputs[0], strict) + return is_provably_positive(var.owner.inputs[0], strict) if isinstance(op, Elemwise): scalar_op = op.scalar_op if not strict and isinstance(scalar_op, Cast): - return _is_provably_positive(var.owner.inputs[0], strict) + return is_provably_positive(var.owner.inputs[0], strict) if isinstance(scalar_op, ScalarMinimum): - return all(_is_provably_positive(i, strict) for i in var.owner.inputs) + return all(is_provably_positive(i, strict) for i in var.owner.inputs) if isinstance(scalar_op, ScalarMaximum): - return any(_is_provably_positive(i, strict) for i in var.owner.inputs) + return any(is_provably_positive(i, strict) for i in var.owner.inputs) return False -def _is_provably_non_negative(var) -> bool: - """``True`` when ``var`` can be statically shown to be non-negative (:math:`\\geq 0`). - - Thin wrapper over :func:`_is_provably_positive` with ``strict=False``. - """ - return _is_provably_positive(var, strict=False) +def is_provably_non_negative(var) -> bool: + """Whether ``var`` is provably ``>= 0``; thin wrapper over `is_provably_positive`.""" + return is_provably_positive(var, strict=False) def get_idx_list(inputs, idx_list): @@ -448,7 +445,7 @@ def analyze(x): if is_stop_length: # Full slice. return slice(0, length, 1), 1 - if _is_provably_non_negative(stop): + if is_provably_non_negative(stop): return (slice(0, minimum(stop, length), 1), 1) stop_plus_len = stop + length stop = switch( @@ -571,7 +568,7 @@ def slice_len(slc, n): and canon_slc.step == 1 and isinstance(canon_slc.start, int) and canon_slc.start == 0 - and _is_provably_non_negative(stop) + and is_provably_non_negative(stop) ): return stop return switch( diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 31a4eb9d90..7baf21fd36 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -35,8 +35,6 @@ AdvancedSubtensor1, IncSubtensor, Subtensor, - _is_provably_non_negative, - _is_provably_positive, advanced_inc_subtensor, advanced_inc_subtensor1, advanced_set_subtensor, @@ -49,6 +47,8 @@ get_canonical_form_slice, inc_subtensor, indexed_result_shape, + is_provably_non_negative, + is_provably_positive, set_subtensor, slice_at_axis, take, @@ -127,9 +127,7 @@ class TestProvablyPositive: ], ) def test_constant_data_respects_strictness(self, data, strict, expected): - assert ( - _is_provably_positive(constant(np.array(data)), strict=strict) is expected - ) + assert is_provably_positive(constant(np.array(data)), strict=strict) is expected @pytest.mark.parametrize( "make_var", @@ -143,8 +141,8 @@ def test_proves_non_negative_but_not_strict_positive(self, make_var): """A uint value, a shape dimension, and a cast can each equal zero, so they establish ``>= 0`` but never strict ``> 0``.""" var = make_var() - assert _is_provably_non_negative(var) is True - assert _is_provably_positive(var, strict=True) is False + assert is_provably_non_negative(var) is True + assert is_provably_positive(var, strict=True) is False @pytest.mark.parametrize( "expr, strict, expected", @@ -164,7 +162,7 @@ def test_proves_non_negative_but_not_strict_positive(self, make_var): ], ) def test_recurses_through_min_and_max(self, expr, strict, expected): - assert _is_provably_positive(expr, strict=strict) is expected + assert is_provably_positive(expr, strict=strict) is expected class TestGetCanonicalFormSlice: From 84e68c5bfa918b9b652ac922f165ea799c52ea72 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 8 Jun 2026 17:30:10 +0200 Subject: [PATCH 2/2] Rewrite division by a reciprocal or negative power as multiplication Add two specialize rewrites: - `local_div_reciprocal_to_mul`: `A / reciprocal(B) -> A * B` and `A / y ** (-p) -> A * y ** p`. The first form is left opaque by the AlgebraicCanonizer and is what `local_pow_specialize` emits for negative constant exponents (e.g. `y ** -2 -> reciprocal(sqr(y))`), which previously left a redundant reciprocal and division. - `local_reciprocal_neg_pow_to_pow`: `reciprocal(y ** (-p)) -> y ** p`, covering `1 / y ** (-p)`, which `local_div_to_reciprocal` canonicalizes to a reciprocal before it can reach the division rewrite. Both only fire when the exponent is recognisably negative (a `neg` node or an all-negative constant), so a positive exponent is never flipped. --- pytensor/tensor/rewriting/math.py | 51 +++++++++++++++++++++++++++ tests/tensor/rewriting/test_math.py | 54 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 4982ed1490..7cdb280cc2 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -860,6 +860,57 @@ def local_div_exp_to_mul_exp(fgraph, node): return [new_out] +@register_specialize +@node_rewriter([true_div]) +def local_div_reciprocal_to_mul(fgraph, node): + """Replace ``A / reciprocal(B)`` with ``A * B`` and ``A / y ** (-p)`` with ``A * y ** p``.""" + num, denom = node.inputs + + match denom.owner_op_and_inputs: + case (Elemwise(scalar_op=ps.Reciprocal()), b): + inverted = b + case (Elemwise(scalar_op=ps.Pow()), base, exponent): + match exponent.owner_op_and_inputs: + case (Elemwise(scalar_op=ps.Neg()), pos_exponent): + inverted = base**pos_exponent + case _ if constant_is_all_negative(exponent): + inverted = base**-exponent.data + case _: + return None + case _: + return None + + new_out = num * inverted + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +@register_specialize +@node_rewriter([reciprocal]) +def local_reciprocal_neg_pow_to_pow(fgraph, node): + """Replace ``reciprocal(y ** (-p))`` with ``y ** p`` (the ``1 / y ** (-p)`` form).""" + [arg] = node.inputs + + match arg.owner_op_and_inputs: + case (Elemwise(scalar_op=ps.Pow()), base, exponent): + match exponent.owner_op_and_inputs: + case (Elemwise(scalar_op=ps.Neg()), pos_exponent): + new_out = base**pos_exponent + case _ if constant_is_all_negative(exponent): + new_out = base**-exponent.data + case _: + return None + case _: + return None + + if new_out.dtype != node.outputs[0].dtype: + new_out = cast(new_out, dtype=node.outputs[0].dtype) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + @register_specialize @node_rewriter([mul, true_div]) def local_mul_pow_to_pow_add(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 83ec5a5e10..360cec0a3f 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -3843,6 +3843,60 @@ def test_local_div_exp_to_mul_exp(): assert_equal_computations([rewritten], [expected]) +def test_local_div_reciprocal_to_mul(): + x = scalar("x") + y = scalar("y") + p = scalar("p") + + # A / reciprocal(B) -> A * B + out = x / reciprocal(y) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [x * y]) + + # local_pow_specialize turns y**-2 into reciprocal(sqr(y)); the outer division + # must then cancel against the reciprocal: x / y**-2 -> x * sqr(y) + out = x / y ** (-2) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [x * sqr(y)]) + + # symbolic neg exponent: x / y**(-p) -> x * y**p (the neg is dropped too) + out = x / y ** (-p) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [x * y**p]) + + # negative constant exponent (here non-integer, so no nested-squaring path): + # x / y**(-2.5) -> x * y**2.5 + out = x / y ** (-2.5) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [x * y**2.5]) + + # guard: a positive exponent must never be flipped into a negative one + out = x / y**p + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [out]) + + +def test_local_reciprocal_neg_pow_to_pow(): + y = scalar("y") + p = scalar("p") + + # reciprocal(y**(-p)) -> y**p + out = reciprocal(y ** (-p)) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [y**p]) + + # 1 / y**(-p) is canonicalised to reciprocal(y**(-p)) before reaching a + # true_div, so it is handled here: -> y**p + out = true_div(np.float64(1.0), y ** (-p)) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [y**p]) + + # guard: positive exponent untouched + out = reciprocal(y**p) + rewritten = rewrite_graph(out, include=("specialize",)) + assert_equal_computations([rewritten], [out]) + + def test_local_mul_pow_to_pow_add(): # Default and FAST_RUN modes put a Composite op into the final graph, # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,