From 23749303b194d4980ac2dc5e9465d1eb3609d1a7 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Wed, 29 Apr 2026 20:17:10 +0200 Subject: [PATCH 1/3] rewrite solve matrix inverse as matmul --- pytensor/tensor/rewriting/linalg/solvers.py | 18 +++++++++++++ tests/tensor/rewriting/linalg/test_solvers.py | 25 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py index 01ad379e1f..658722171a 100644 --- a/pytensor/tensor/rewriting/linalg/solvers.py +++ b/pytensor/tensor/rewriting/linalg/solvers.py @@ -17,6 +17,7 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky from pytensor.tensor.linalg.decomposition.lu import lu_factor +from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, lu_solve from pytensor.tensor.linalg.solvers.linear_control import ( @@ -162,6 +163,23 @@ def scalar_solve_to_division(fgraph, node): return [new_out] +@register_stabilize +@node_rewriter([blockwise_of(SolveBase)]) +def solve_of_inv_to_matmul(fgraph, node): + """Replace solve(matrix_inverse(X), b) with X @ b. + + If A = inv(X), then solve(A, b) finds x such that A @ x = b, + i.e., inv(X) @ x = b, so x = X @ b. + """ + A, b = node.inputs + + match A.owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + new_out = X @ b + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + def decompose_A(A, assume_a, lower): if assume_a == "gen": return lu_factor(A) diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py index 6f554da1fd..f7f8c6bd97 100644 --- a/tests/tensor/rewriting/linalg/test_solvers.py +++ b/tests/tensor/rewriting/linalg/test_solvers.py @@ -13,6 +13,7 @@ from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky from pytensor.tensor.linalg.decomposition.lu import LUFactor +from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, solve from pytensor.tensor.linalg.solvers.psd import CholeskySolve, cho_solve @@ -24,6 +25,7 @@ from pytensor.tensor.rewriting.linalg.solvers import ( reuse_decomposition_multiple_solves, scan_split_non_sequence_decomposition_and_solve, + solve_of_inv_to_matmul, ) from pytensor.tensor.type import matrix, tensor @@ -439,3 +441,26 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): resx1 = fn_opt(A_test, x0_test) rtol = 1e-7 if config.floatX == "float64" else 1e-4 np.testing.assert_allclose(resx0, resx1, rtol=rtol) + + +@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") +def test_solve_of_inv_to_matmul(b_ndim): + X = pt.dmatrix("X") + if b_ndim == 1: + b = pt.dvector("b") + else: + b = pt.dmatrix("b") + + out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) + + # Graph rewrite test + rewrite_name = solve_of_inv_to_matmul.__name__ + mode = get_default_mode() + + fn_opt = function([X, b], out, mode=mode.including(rewrite_name)) + opt_nodes = fn_opt.maker.fgraph.apply_nodes + + assert not any( + isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse) + for node in opt_nodes + ) From ae607a105ce69004b179aa9d0cc11c7cdbc3ad30 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Thu, 30 Apr 2026 23:45:27 +0200 Subject: [PATCH 2/3] implement the rewrite manually inside the test_solve_of_inv_to_matmul() --- tests/tensor/rewriting/linalg/test_solvers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py index f7f8c6bd97..da7c68bd22 100644 --- a/tests/tensor/rewriting/linalg/test_solvers.py +++ b/tests/tensor/rewriting/linalg/test_solvers.py @@ -9,6 +9,8 @@ from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import ancestors +from pytensor.graph.rewriting.utils import rewrite_graph +from pytensor.graph.traversal import io_toposort from pytensor.scan.op import Scan from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky @@ -25,7 +27,6 @@ from pytensor.tensor.rewriting.linalg.solvers import ( reuse_decomposition_multiple_solves, scan_split_non_sequence_decomposition_and_solve, - solve_of_inv_to_matmul, ) from pytensor.tensor.type import matrix, tensor @@ -454,13 +455,14 @@ def test_solve_of_inv_to_matmul(b_ndim): out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) # Graph rewrite test - rewrite_name = solve_of_inv_to_matmul.__name__ - mode = get_default_mode() + # We include 'stabilize' because solve_of_inv_to_matmul is registered there. + # This avoids dependency on the global config.mode (e.g. FAST_COMPILE). + rewritten_out = rewrite_graph(out, include=["stabilize"]) - fn_opt = function([X, b], out, mode=mode.including(rewrite_name)) - opt_nodes = fn_opt.maker.fgraph.apply_nodes + # Get all nodes in the rewritten graph + all_nodes = io_toposort([], [rewritten_out]) assert not any( isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse) - for node in opt_nodes + for node in all_nodes ) From 8fa4b92ff9eaa16c4256db05a97071e403a6add3 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Fri, 1 May 2026 13:45:41 +0200 Subject: [PATCH 3/3] update test funciton --- tests/tensor/rewriting/linalg/test_solvers.py | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py index da7c68bd22..fe36d886cf 100644 --- a/tests/tensor/rewriting/linalg/test_solvers.py +++ b/tests/tensor/rewriting/linalg/test_solvers.py @@ -9,13 +9,12 @@ from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import ancestors +from pytensor.graph.rewriting.basic import WalkingGraphRewriter from pytensor.graph.rewriting.utils import rewrite_graph -from pytensor.graph.traversal import io_toposort from pytensor.scan.op import Scan from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky from pytensor.tensor.linalg.decomposition.lu import LUFactor -from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, solve from pytensor.tensor.linalg.solvers.psd import CholeskySolve, cho_solve @@ -27,8 +26,10 @@ from pytensor.tensor.rewriting.linalg.solvers import ( reuse_decomposition_multiple_solves, scan_split_non_sequence_decomposition_and_solve, + solve_of_inv_to_matmul, ) from pytensor.tensor.type import matrix, tensor +from tests.unittest_tools import assert_equal_computations def test_generic_solve_to_solve_triangular(): @@ -447,22 +448,24 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): @pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") def test_solve_of_inv_to_matmul(b_ndim): X = pt.dmatrix("X") - if b_ndim == 1: - b = pt.dvector("b") - else: - b = pt.dmatrix("b") - + b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b") out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) - # Graph rewrite test - # We include 'stabilize' because solve_of_inv_to_matmul is registered there. - # This avoids dependency on the global config.mode (e.g. FAST_COMPILE). - rewritten_out = rewrite_graph(out, include=["stabilize"]) + # Just include the rewrite we are testing + rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul) + rewritten_out = rewrite_graph(out, custom_rewrite=rewriter) - # Get all nodes in the rewritten graph - all_nodes = io_toposort([], [rewritten_out]) + # Verify the rewrite + expected = X @ b + assert_equal_computations([rewritten_out], [expected]) - assert not any( - isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse) - for node in all_nodes - ) + # Numerical check + rng = np.random.default_rng(42) + X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype) + b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype) + + f_opt = function([X, b], rewritten_out) + res_opt = f_opt(X_val, b_val) + res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val) + + np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)