diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py
index a84a183cda..a49e98acb4 100644
--- a/pytensor/tensor/rewriting/linalg/solvers.py
+++ b/pytensor/tensor/rewriting/linalg/solvers.py
@@ -18,6 +18,7 @@
 from pytensor.tensor.linalg.constructors import BlockDiagonal
 from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky
 from pytensor.tensor.linalg.decomposition.lu import lu_factor
+from pytensor.tensor.linalg.inverse import MatrixInverse
 from pytensor.tensor.linalg.solvers.core import SolveBase
 from pytensor.tensor.linalg.solvers.general import Solve, lu_solve
 from pytensor.tensor.linalg.solvers.linear_control import (
@@ -163,6 +164,40 @@ def scalar_solve_to_division(fgraph, node):
     return [new_out]
 
 
+@register_stabilize
+@node_rewriter([blockwise_of(SolveBase)])
+def solve_of_inv_to_matmul(fgraph, node):
+    """Replace ``solve(matrix_inverse(X), b)`` with ``X @ b``.
+
+    If ``A = inv(X)``, then ``solve(A, b)`` finds ``x`` such that ``A @ x = b``,
+    i.e., ``inv(X) @ x = b``, so ``x = X @ b``.
+
+    Only the general ``Solve`` op is rewritten. Other ``SolveBase`` subclasses
+    may interpret their first input differently (e.g. as a Cholesky factor, or
+    as a triangular matrix whose other half is ignored), in which case the
+    identity above need not hold.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, Solve):
+        return None
+
+    A, b = node.inputs
+
+    match A.owner_op_and_inputs:
+        case (Blockwise(MatrixInverse()), X):
+            if core_op.b_ndim == 1 and b.type.ndim > 1:
+                # `b` is a batch of vectors, but `@` would read its last two
+                # dimensions as one matrix. Multiply against explicit column
+                # vectors and drop the added axis afterwards.
+                new_out = (X @ b[..., None]).squeeze(-1)
+            else:
+                new_out = X @ b
+            copy_stack_trace(node.outputs[0], new_out)
+            return [new_out]
+
+    return None
+
+
 @register_canonicalize
 @register_stabilize
 @node_rewriter([blockwise_of(SolveBase)])
diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py
index 8aa8f161ba..de04108131 100644
--- a/tests/tensor/rewriting/linalg/test_solvers.py
+++ b/tests/tensor/rewriting/linalg/test_solvers.py
@@ -10,6 +10,7 @@
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
 from pytensor.graph import ancestors
+from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.scan.op import Scan
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 from pytensor.tensor.linalg.constructors import BlockDiagonal
@@ -28,6 +29,7 @@
     scan_split_non_sequence_decomposition_and_solve,
 )
 from pytensor.tensor.type import matrix, tensor
+from tests.unittest_tools import assert_equal_computations
 
 
 def test_generic_solve_to_solve_triangular():
@@ -443,6 +445,35 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
     np.testing.assert_allclose(resx0, resx1, rtol=rtol)
 
 
+@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
+def test_solve_of_inv_to_matmul(b_ndim):
+    X = pt.dmatrix("X")
+    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
+    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)
+
+    # solve_of_inv_to_matmul is registered under 'stabilize', so that tag must
+    # be requested explicitly (an explicit `include` replaces the default).
+    rewritten_out = rewrite_graph(out, include=["stabilize"])
+
+    # Apply the same rewrites to a plain `X @ b` so that both graphs are in the
+    # same form before they are compared structurally.
+    expected = rewrite_graph(X @ b, include=["stabilize"])
+    assert_equal_computations([rewritten_out], [expected])
+
+
+def test_solve_of_inv_to_matmul_batched_vector_b():
+    # With b_ndim=1 and a batch dimension, every row of b is its own right-hand
+    # side, so the rewrite must not hand b to `@` as a single matrix.
+    X = pt.dtensor3("X")
+    b = pt.dmatrix("b")
+    out = solve(pt.linalg.inv(X), b, b_ndim=1)
+
+    rewritten_out = rewrite_graph(out, include=["stabilize"])
+
+    expected = rewrite_graph((X @ b[..., None]).squeeze(-1), include=["stabilize"])
+    assert_equal_computations([rewritten_out], [expected])
+
+
 @pytest.mark.parametrize(
     "b_ndim, solve_fn, expected_op, batch",
     [