diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 9ea47aa64844..efdacb4e0e83 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -141,20 +141,44 @@ std::tuple)>> auto shape_b = opt_shape_b.value(); auto shape_c = opt_shape_c.value(); + auto permute_last_two_dims = [&](Expr expr) -> Expr { + auto opt_shape = get_shape(expr); + if (!opt_shape) return expr; + + size_t ndim = opt_shape.value().size(); + TVM_FFI_ICHECK_GE(ndim, 2); + + ffi::Optional> axes; + + if (ndim == 2) { + // Pass none axes to permute_dims for simple transpose of 2D tensors. + axes = std::nullopt; + } else { + ffi::Array axes_array; + for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i); + axes_array.Set(ndim - 1, ndim - 2); + axes_array.Set(ndim - 2, ndim - 1); + axes = ffi::Optional>(axes_array); + } + return permute_dims(std::move(expr), axes); + }; + + auto transpose_shape_last_two_dims = [&](ffi::Array& shape) { + PrimExpr last_dim_shape = shape[shape.size() - 1]; + shape.Set(shape.size() - 1, shape[shape.size() - 2]); + shape.Set(shape.size() - 2, last_dim_shape); + }; + if (matches.count(pat_permuted_matmul_on_lhs)) { - expr_a = permute_dims(expr_a, std::nullopt); - expr_b = permute_dims(expr_b, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_a.size(), 2); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - shape_a = {shape_a[1], shape_a[0]}; - shape_b = {shape_b[1], shape_b[0]}; + expr_a = permute_last_two_dims(expr_a); + expr_b = permute_last_two_dims(expr_b); + transpose_shape_last_two_dims(shape_a); + transpose_shape_last_two_dims(shape_b); } else if (matches.count(pat_permuted_matmul_on_rhs)) { - expr_b = permute_dims(expr_b, std::nullopt); - expr_c = permute_dims(expr_c, std::nullopt); - TVM_FFI_ICHECK_EQ(shape_b.size(), 2); - TVM_FFI_ICHECK_EQ(shape_c.size(), 2); - shape_b = {shape_b[1], shape_b[0]}; - shape_c = {shape_c[1], shape_c[0]}; + expr_b = permute_last_two_dims(expr_b); + expr_c = permute_last_two_dims(expr_c); + transpose_shape_last_two_dims(shape_b); + transpose_shape_last_two_dims(shape_c); } // If two of the three are compile-time, group those two values @@ -166,13 +190,7 @@ std::tuple)>> } // Otherwise, select the order that reduces the total number of - // operations required, assuming a naive matmul. - - // Matmul on LHS: ([N,R]*[R,M]) * [M,batch] - // Matmul on RHS: [N,R] * ([R,M]*[M,batch]) - // - // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)` - // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch` + // operations required, assuming a naive matmul (see below). if (shape_a.size() == 1) { shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; @@ -192,13 +210,41 @@ std::tuple)>> shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)}; } - auto size_N = shape_a[shape_a.size() - 2]; - auto size_R = shape_a[shape_a.size() - 1]; - auto size_M = shape_c[shape_c.size() - 2]; - auto size_B = shape_c[shape_c.size() - 1]; + PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A + PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B + PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B + PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C + + auto calculate_batch = [](ffi::Array& shape) { + PrimExpr batch = 1; + for (size_t i = 0; i < shape.size() - 2; ++i) { + batch *= shape[i]; + } + return batch; + }; + + PrimExpr batch_A = calculate_batch(shape_a); + PrimExpr batch_B = calculate_batch(shape_b); + PrimExpr batch_C = calculate_batch(shape_c); - auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M; - auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B; + // Compare naive matmul FLOPs for two evaluation orders of + // matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C) + // + // Matrix dims (last two axes of each operand): + // A: [N, R] B: [R, M] C: [M, B_last] + // Batch prefixes (product of all leading axes): + // batch_A, batch_B, batch_C + // + // LHS first — matmul(matmul(A, B), C): + // inner matmul(A, B): batch_A * batch_B * N * R * M + // outer matmul(., C): batch_A * batch_B * batch_C * N * M * B_last + // total: batch_A * batch_B * N * M * (R + batch_C * B_last) + PrimExpr ops_with_lhs_first = (size_R + batch_C * size_B) * size_N * size_M * batch_A * batch_B; + // RHS first — matmul(A, matmul(B, C)): + // inner matmul(B, C): batch_B * batch_C * R * M * B_last + // outer matmul(A, .): batch_A * batch_B * batch_C * N * R * B_last + // total: batch_B * batch_C * R * B_last * (M + batch_A * N) + PrimExpr ops_with_rhs_first = (size_M + batch_A * size_N) * size_R * size_B * batch_B * batch_C; arith::Analyzer analyzer; analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( @@ -214,8 +260,7 @@ std::tuple)>> return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); } - // If we cannot determine which order is best, keep the existing - // order. + // If we cannot determine which order is best, keep the existing order. return expr; }; diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index a086f3abdb8d..6adf1184581b 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -17,8 +17,11 @@ import inspect +import numpy as np import pytest +import torch +import tvm import tvm.testing from tvm import relax from tvm.script import ir as I @@ -234,8 +237,26 @@ class TestIdempotentRHSDynamic(Base): Expected = TestRHSDynamic.Expected -class TestLHSDynamicWithBatch(Base): - """Prefer (x*A)*B instead of x*(A*B)""" +class TestDynamicWithBatchSymbolic1(Base): + """Keep existing order when batch_size and lora_r are both symbolic. + + Before computes `x @ (A @ B)` with + `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`. + + RHS first (fuse A@B once, no batch on inner matmul): + 16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size) + + LHS first (both matmuls scale with batch_size): + batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the LHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the RHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. + """ @I.ir_module class Before: @@ -250,6 +271,31 @@ def main( out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) return out + Expected = Before + + +class TestDynamicWithBatchConcrete1LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. + + batch_size=4, lora_r=16: + LHS first: 48*4*16 = 3072 + RHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + @I.ir_module class Expected: @R.function @@ -258,15 +304,70 @@ def main( A: R.Tensor([16, "lora_r"]), B: R.Tensor(["lora_r", 32]), ) -> R.Tensor(["batch_size", 1, 32]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) - x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out -class TestRHSDynamicWithBatch(Base): - """Prefer A*(B*x) instead of (A*B)*x""" +class TestDynamicWithBatchConcrete1RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. + + batch_size=64, lora_r=16: + LHS first: 48*64*16 = 49152 + RHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 1, 16]), + A: R.Tensor([16, "lora_r"]), + B: R.Tensor(["lora_r", 32]), + ) -> R.Tensor(["batch_size", 1, 32]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([16, 32]) = R.matmul(A, B) + out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight) + return out + + +class TestDynamicWithBatchSymbolic2(Base): + """Keep existing order when batch_size and lora_r are both symbolic. + + Before computes `(A @ B) @ x` with + `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`. + + LHS first (fuse A@B once, no batch on inner matmul): + 32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size) + + RHS first (both matmuls scale with batch_size): + batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r + + When `batch_size` and `lora_r` are known at compile-time: + - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r + batch_size), + the RHS first is preferred. + - satisfy the inequality 512*(lora_r + batch_size) < 48*batch_size*lora_r, + the LHS first is preferred. + + Without bounds on `batch_size` and `lora_r`, neither side is provably cheaper. + """ @I.ir_module class Before: @@ -281,6 +382,31 @@ def main( out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) return out + Expected = Before + + +class TestDynamicWithBatchConcrete2RHSFirst(Base): + """With concrete shapes, RHS first is provably cheaper. + + batch_size=4, lora_r=16: + RHS first: 48*4*16 = 3072 + LHS first: 512*(16 + 4) = 10240 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(4) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out + @I.ir_module class Expected: @R.function @@ -289,11 +415,48 @@ def main( A: R.Tensor([32, "lora_r"]), B: R.Tensor(["lora_r", 16]), ) -> R.Tensor(["batch_size", 32, 1]): - lora_r = T.int64() - batch_size = T.int64() - x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) - x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x) - return x + batch_size = T.int64(4) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + +class TestDynamicWithBatchConcrete2LHSFirst(Base): + """With concrete shapes, LHS first is provably cheaper. + + batch_size=64, lora_r=16: + RHS first: 48*64*16 = 49152 + LHS first: 512*(16 + 64) = 40960 + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) + weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 16, 1]), + A: R.Tensor([32, "lora_r"]), + B: R.Tensor(["lora_r", 16]), + ) -> R.Tensor(["batch_size", 32, 1]): + batch_size = T.int64(64) + lora_r = T.int64(16) # noqa: F841 + weight: R.Tensor([32, 16]) = R.matmul(A, B) + out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x) + return out class TestNoOpForFullyDynamicOnLHS(Base): @@ -513,5 +676,73 @@ def main( return x +class TestAdjustMatmulOrderAttentionBlock: + """AdjustMatmulOrder preserves numerics on a batched attention block. + + Covers ND `permute_dims` (swap last two axes) inside `matmul(q, kt)`, + regression for issue #19576. + """ + + def _build_attention_module(self, batch, seq, dim): + """Minimal batched attention block exercising ND permute_dims + matmul.""" + bb = relax.BlockBuilder() + x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim), "float32")) + wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32")) + wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32")) + wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32")) + wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32")) + with bb.function("main", [x, wq, wk, wv, wo]): + with bb.dataflow(): + q = bb.emit(relax.op.matmul(x, wq)) + k = bb.emit(relax.op.matmul(x, wk)) + v = bb.emit(relax.op.matmul(x, wv)) + kt = bb.emit(relax.op.permute_dims(k, axes=[0, 2, 1])) + scores = bb.emit(relax.op.matmul(q, kt)) + scale = bb.emit(relax.const(1.0 / np.sqrt(dim), "float32")) + scores = bb.emit(relax.op.multiply(scores, scale)) + attn = bb.emit(relax.op.nn.softmax(scores, axis=-1)) + out = bb.emit(relax.op.matmul(attn, v)) + proj = bb.emit_output(relax.op.matmul(out, wo)) + bb.emit_func_output(proj) + return bb.finalize() + + def _run_relax_main(self, mod, inputs): + exe = relax.build(mod, target="llvm") + vm = relax.VirtualMachine(exe, device=tvm.cpu()) + args = [tvm.runtime.tensor(arr, device=tvm.cpu()) for arr in inputs] + return vm["main"](*args).numpy() + + def _torch_attention_ref(self, x_np, w_np, dim): + x = torch.from_numpy(x_np) + w = torch.from_numpy(w_np) + with torch.no_grad(): + q = torch.matmul(x, w) + k = torch.matmul(x, w) + v = torch.matmul(x, w) + scores = torch.matmul(q, k.transpose(-2, -1)) + scores = scores * (1.0 / np.sqrt(dim)) + attn = torch.nn.functional.softmax(scores, dim=-1) + out = torch.matmul(attn, v) + out = torch.matmul(out, w) + return out.detach().numpy() + + @pytest.mark.parametrize("batch,seq,dim", [(2, 16, 64)]) + def test_attention_block_numerics(self, batch, seq, dim): + mod = self._build_attention_module(batch, seq, dim) + mod_opt = relax.transform.AdjustMatmulOrder()(mod) + + x_np = np.random.randn(batch, seq, dim).astype("float32") + w_np = np.random.randn(dim, dim).astype("float32") + inputs = [x_np, w_np, w_np, w_np, w_np] + + ref = self._torch_attention_ref(x_np, w_np, dim) + out_before = self._run_relax_main(mod, inputs) + out_after = self._run_relax_main(mod_opt, inputs) + + tvm.testing.assert_allclose(out_before, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_after, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out_before, out_after, rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": tvm.testing.main()