Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/assumptions/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NotScalarConstantError,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.subtensor import _is_provably_positive
from pytensor.tensor.subtensor import is_provably_positive


def elemwise_preserves_zero_pattern(
Expand Down Expand Up @@ -76,7 +76,7 @@ def _not_singleton_matrix(var) -> bool:
return [FactState.UNKNOWN]
# 0 ** p == 0 for p > 0, so a provably-positive exponent (scalar or
# elementwise matrix) preserves the base's zero pattern.
return true_if(_is_provably_positive(node.inputs[1]))
return true_if(is_provably_positive(node.inputs[1]))

if isinstance(scalar_op, UnaryScalarOp) and scalar_op.preserves_zero:
return true_if(input_states[0] is FactState.TRUE)
Expand Down
6 changes: 3 additions & 3 deletions pytensor/assumptions/positive_definite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SolveBilinearDiscreteLyapunov,
)
from pytensor.tensor.math import Dot
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.subtensor import Subtensor, is_provably_positive


register_assumption(POSITIVE_DEFINITE, Eye)(eye_identity_rule)
Expand All @@ -34,7 +34,7 @@ def _alloc_diag(key, op, feature, fgraph, node, input_states):
return [FactState.FALSE]

[diag_values] = node.inputs
return true_if(_is_provably_positive(diag_values))
return true_if(is_provably_positive(diag_values))


register_assumption(POSITIVE_DEFINITE, BlockDiagonal)(all_inputs_have_key)
Expand Down Expand Up @@ -74,7 +74,7 @@ def _elemwise(key, op, feature, fgraph, node, input_states):
# Scaling a PD matrix by a positive scalar keeps it PD. The factor
# must be constant across the matrix axes (both broadcastable);
# per-batch variation is fine since every batch slice stays PD.
if not (all(inp.type.broadcastable[-2:]) and _is_provably_positive(inp)):
if not (all(inp.type.broadcastable[-2:]) and is_provably_positive(inp)):
continue
other_inputs = [node.inputs[j] for j in range(len(node.inputs)) if j != i]
if other_inputs and all(
Expand Down
77 changes: 77 additions & 0 deletions pytensor/tensor/constant_props.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Cached static predicates over the *data* of constant variables.

Each helper memoizes its result on the variable's ``tag`` so repeated rewrite
passes do not re-scan the same constant. These only inspect constant data, so
this stays a dependency-free leaf module (graph-walking sign analysis such as
``is_provably_positive`` lives in ``subtensor`` with the ops it recurses over).
"""

import numpy as np

from pytensor.graph.basic import Constant


def constant_is_all_negative(var) -> bool:
"""Whether ``var`` is a constant whose entries are all negative, cached on its tag."""
if not isinstance(var, Constant):
return False
cached: bool | None = getattr(var.tag, "all_negative", None)
if cached is not None:
return cached
result = bool(np.all(np.asarray(var.data) < 0))
var.tag.all_negative = result
return result


def constant_indices_are_unique(idx) -> bool:
"""Check whether a constant index has no duplicate entries.

Boolean indices, scalars, and single-element arrays are trivially unique.
For larger integer arrays, indices that mix positive and negative values
may alias, so those are treated as potentially duplicated. The result
is cached on ``idx.tag``.
"""
if not isinstance(idx, Constant):
return False
cached = getattr(idx.tag, "unique_indices", None)
if cached is not None:
return bool(cached)
idx_val = np.asarray(idx.data)
if idx_val.dtype == bool:
result = True
elif idx_val.size <= 1:
result = True
else:
has_pos = (idx_val >= 0).any()
has_neg = (idx_val < 0).any()
result = not (has_pos and has_neg) and np.unique(idx_val).size == idx_val.size
idx.tag.unique_indices = result
return result


def constant_is_arange(idx) -> tuple[int, int, int] | None:
"""Match ``idx`` to ``np.arange(offset, offset + d * step, step)``
and return ``(d, offset, step)``, else ``None``.

Single-element constants return ``(1, value, 1)``. The result is cached
on ``idx.tag.is_arange`` (``False`` sentinels a no-match).
"""
if not isinstance(idx, Constant):
return None
cached = getattr(idx.tag, "is_arange", None)
if cached is not None:
return cached if cached is not False else None
idx_val = np.asarray(idx.data)
if idx_val.ndim != 1 or idx_val.size == 0 or idx_val.dtype.kind not in "iu":
result: tuple[int, int, int] | None = None
elif idx_val.size == 1:
result = (1, int(idx_val[0]), 1)
else:
diffs = np.diff(idx_val)
step = int(diffs[0])
if step != 0 and np.all(diffs == step):
result = (int(idx_val.size), int(idx_val[0]), step)
else:
result = None
idx.tag.is_arange = result if result is not None else False
return result
64 changes: 58 additions & 6 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
zeros,
zeros_like,
)
from pytensor.tensor.constant_props import constant_is_all_negative
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
Expand Down Expand Up @@ -112,7 +113,7 @@
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.subtensor import Subtensor, is_provably_positive
from pytensor.tensor.type import (
complex_dtypes,
uint_dtypes,
Expand Down Expand Up @@ -709,8 +710,8 @@ def local_log_div(fgraph, node):

if isinstance(scalar_op, ps.TrueDiv):
num, den = inp.owner.inputs
if (isinstance(num, Constant) and _is_provably_positive(num, strict=True)) or (
isinstance(den, Constant) and _is_provably_positive(den, strict=True)
if (isinstance(num, Constant) and is_provably_positive(num, strict=True)) or (
isinstance(den, Constant) and is_provably_positive(den, strict=True)
):
return [log(num) - log(den)]

Expand Down Expand Up @@ -739,13 +740,13 @@ def local_sign_div(fgraph, node):

num, den = inp.owner.inputs

if _is_provably_positive(num, strict=True):
if is_provably_positive(num, strict=True):
return [sign(den)]
if _is_provably_positive(den, strict=True):
if is_provably_positive(den, strict=True):
return [sign(num)]

for side, other in ((num, den), (den, num)):
if isinstance(side, Constant) and np.all(np.asarray(side.data) < 0):
if constant_is_all_negative(side):
return [neg(sign(other))]


Expand Down Expand Up @@ -859,6 +860,57 @@ def local_div_exp_to_mul_exp(fgraph, node):
return [new_out]


@register_specialize
@node_rewriter([true_div])
def local_div_reciprocal_to_mul(fgraph, node):
"""Replace ``A / reciprocal(B)`` with ``A * B`` and ``A / y ** (-p)`` with ``A * y ** p``."""
num, denom = node.inputs

match denom.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Reciprocal()), b):
inverted = b
case (Elemwise(scalar_op=ps.Pow()), base, exponent):
match exponent.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Neg()), pos_exponent):
inverted = base**pos_exponent
case _ if constant_is_all_negative(exponent):
inverted = base**-exponent.data
case _:
return None
case _:
return None

new_out = num * inverted
if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_specialize
@node_rewriter([reciprocal])
def local_reciprocal_neg_pow_to_pow(fgraph, node):
"""Replace ``reciprocal(y ** (-p))`` with ``y ** p`` (the ``1 / y ** (-p)`` form)."""
[arg] = node.inputs

match arg.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Pow()), base, exponent):
match exponent.owner_op_and_inputs:
case (Elemwise(scalar_op=ps.Neg()), pos_exponent):
new_out = base**pos_exponent
case _ if constant_is_all_negative(exponent):
new_out = base**-exponent.data
case _:
return None
case _:
return None

if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_specialize
@node_rewriter([mul, true_div])
def local_mul_pow_to_pow_add(fgraph, node):
Expand Down
Loading
Loading