
Commit a9b4a4f

Refactoring Duplicate cuBLAS/hipBLAS Tests
1 parent 45bef45 commit a9b4a4f

File tree

3 files changed: +149 additions, -147 deletions

tests/python/relax/test_codegen_blas_common.py (new)
tests/python/relax/test_codegen_cublas.py
tests/python/relax/test_codegen_hipblas.py
tests/python/relax/test_codegen_blas_common.py (new file)

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared test utilities for cuBLAS and hipBLAS codegen tests."""
import numpy as np

import tvm
from tvm import relax
from tvm.relax.testing import get_relax_matmul_module


def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
    dev = tvm.device(target, 0)
    with tvm.transform.PassContext(
        config={
            "relax.backend.use_cuda_graph": cuda_graph,
            "relax.transform.apply_legalize_ops": legalize,
        }
    ):
        ex = tvm.compile(mod, target)
        vm = relax.VirtualMachine(ex, dev)
        f = vm["main"]
        inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]

        # For cuda graph, run the compiled function twice to make sure that we can launch the cached
        # graph on the second run.
        if cuda_graph:
            f(*inputs)

        return f(*inputs).numpy()


def to_concrete_shape(symbolic_shape, var_table):
    result = []
    for dim in symbolic_shape:
        if not isinstance(dim, tvm.tir.expr.Var):
            result.append(dim)
            continue

        if dim not in var_table:
            var_table[dim] = np.random.randint(10, 50)
        result.append(var_table[dim])

    return tuple(result)


def run_matmul_offload_test(
    x_shape,
    y_shape,
    transpose_y,
    epilogue,
    in_dtype,
    out_dtype,
    epilogue_table,
    partition_fn,
    target,
):
    """Shared test logic for matmul offload tests across different BLAS backends.

    Parameters
    ----------
    x_shape : tuple
        Shape of the first input tensor.
    y_shape : tuple
        Shape of the second input tensor.
    transpose_y : bool
        Whether to transpose the second input.
    epilogue : str
        Type of epilogue operation.
    in_dtype : str
        Input data type.
    out_dtype : str
        Output data type.
    epilogue_table : dict
        Mapping of epilogue names to (with_bias, activation) tuples.
    partition_fn : callable
        Function to partition the module for the specific BLAS backend.
    target : str
        Target device (e.g., "cuda" or "rocm").
    """
    with_bias, activation = epilogue_table[epilogue]
    var_table = {}
    concrete_x_shape = to_concrete_shape(x_shape, var_table)
    concrete_y_shape = to_concrete_shape(y_shape, var_table)
    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
    y = np.random.randn(*concrete_y_shape).astype(in_dtype)

    if transpose_y:
        y = np.swapaxes(y, -2, -1)
        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

    if with_bias:
        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
        args = (x, y, bias)
    else:
        bias = None
        args = (x, y)

    mod = get_relax_matmul_module(
        x_shape,
        y_shape,
        in_dtype,
        out_dtype,
        bias_shape=bias.shape if with_bias else None,
        transposed_y=transpose_y,
        activation=activation,
    )

    mod = partition_fn(mod)
    mod = relax.transform.RunCodegen()(mod)
    out = build_and_run(mod, args, target)
    ref = build_and_run(mod, args, "llvm", legalize=True)

    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
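Both backend tests pass an `_epilogue_table` into `run_matmul_offload_test`, but the table itself lies outside the hunks shown in this commit. A minimal sketch of the shape such a table takes, inferred only from the docstring above ("Mapping of epilogue names to (with_bias, activation) tuples"); the exact entries in the TVM test files may differ:

    # Illustrative epilogue table: name -> (with_bias, activation).
    # The activation is a callable forwarded to get_relax_matmul_module;
    # these entries are assumptions, not copied from the commit.
    from tvm.script import relax as R

    _epilogue_table = {
        "none": (False, None),      # plain matmul
        "bias": (True, None),       # matmul + bias add
        "relu": (True, R.nn.relu),  # matmul + bias + relu epilogue
    }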

tests/python/relax/test_codegen_cublas.py

Lines changed: 14 additions & 89 deletions
@@ -27,6 +27,8 @@
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder

+from test_codegen_blas_common import build_and_run, to_concrete_shape, run_matmul_offload_test
+
 try:
     import ml_dtypes
 except ImportError:
@@ -41,48 +43,13 @@ def reset_seed():
 pytestmark = tvm.testing.requires_cublas.marks()


-def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(
-        config={
-            "relax.backend.use_cuda_graph": cuda_graph,
-            "relax.transform.apply_legalize_ops": legalize,
-        }
-    ):
-        ex = tvm.compile(mod, target)
-        vm = relax.VirtualMachine(ex, dev)
-        f = vm["main"]
-        inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]
-
-        # For cuda graph, run the compiled function twice to make sure that we can launch the cached
-        # graph on the second run.
-        if cuda_graph:
-            f(*inputs)
-
-        return f(*inputs).numpy()
-
-
 def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_constants=False):
     mod = partition_for_cublas(mod, bind_constants=bind_constants)
     mod = relax.transform.RunCodegen()(mod)

     return build_and_run(mod, np_inputs, "cuda", cuda_graph)


-def _to_concrete_shape(symbolic_shape, var_table):
-    result = []
-    for dim in symbolic_shape:
-        if not isinstance(dim, tvm.tir.expr.Var):
-            result.append(dim)
-            continue
-
-        if dim not in var_table:
-            var_table[dim] = np.random.randint(10, 50)
-        result.append(var_table[dim])
-
-    return tuple(result)
-
-
 _vars = {
     "a": tvm.tir.expr.Var("a", "int64"),
     "b": tvm.tir.expr.Var("b", "int64"),
@@ -204,39 +171,18 @@ def test_matmul_offload(
     in_dtype,
     out_dtype,
 ):
-    with_bias, activation = _epilogue_table[epilogue]
-    var_table = {}
-    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
-    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
-    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
-    y = np.random.randn(*concrete_y_shape).astype(in_dtype)
-
-    if transpose_y:
-        y = np.swapaxes(y, -2, -1)
-        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
-
-    if with_bias:
-        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
-        args = (x, y, bias)
-    else:
-        bias = None
-        args = (x, y)
-
-    mod = get_relax_matmul_module(
+    run_matmul_offload_test(
         x_shape,
         y_shape,
+        transpose_y,
+        epilogue,
         in_dtype,
         out_dtype,
-        bias_shape=bias.shape if with_bias else None,
-        transposed_y=transpose_y,
-        activation=activation,
+        _epilogue_table,
+        partition_for_cublas,
+        "cuda",
     )

-    out = get_result_with_relax_cublas_offload(mod, args)
-    ref = build_and_run(mod, args, "llvm", legalize=True)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-

 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
@@ -265,39 +211,18 @@ def test_matmul_igemm_offload(
 ):
     in_dtype = "int8"
     out_dtype = "int32"
-    with_bias, activation = _epilogue_table[epilogue]
-    var_table = {}
-    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
-    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
-    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
-    y = np.random.randn(*concrete_y_shape).astype(in_dtype)
-
-    if transpose_y:
-        y = np.swapaxes(y, -2, -1)
-        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
-
-    if with_bias:
-        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
-        args = (x, y, bias)
-    else:
-        bias = None
-        args = (x, y)
-
-    mod = get_relax_matmul_module(
+    run_matmul_offload_test(
         x_shape,
         y_shape,
+        transpose_y,
+        epilogue,
         in_dtype,
         out_dtype,
-        bias_shape=bias.shape if with_bias else None,
-        transposed_y=transpose_y,
-        activation=activation,
+        _epilogue_table,
+        partition_for_cublas,
+        "cuda",
    )

-    out = get_result_with_relax_cublas_offload(mod, args)
-    ref = build_and_run(mod, args, "llvm", legalize=True)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-

 @tvm.testing.requires_cuda_compute_version(9)
 @pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
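The cuBLAS file keeps importing `build_and_run` and `to_concrete_shape` directly (first hunk above), presumably because other tests in it still build their own inputs. A small standalone sketch of how `to_concrete_shape` resolves the symbolic `_vars` shown in the context lines: the helper draws a random size once per variable and then reuses it, so shapes that share a variable stay consistent (the concrete values themselves vary from run to run; this snippet is illustrative, not part of the commit):

    # Illustrative only: resolve symbolic dims to concrete sizes.
    import tvm
    from test_codegen_blas_common import to_concrete_shape

    a = tvm.tir.expr.Var("a", "int64")
    var_table = {}
    x_shape = to_concrete_shape((a, 32), var_table)  # e.g. (23, 32)
    y_shape = to_concrete_shape((32, a), var_table)  # `a` resolves to the same value
    assert x_shape[0] == y_shape[1]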

tests/python/relax/test_codegen_hipblas.py

Lines changed: 8 additions & 58 deletions
@@ -25,6 +25,8 @@
 from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R

+from test_codegen_blas_common import run_matmul_offload_test
+
 try:
     import ml_dtypes
 except ImportError:
@@ -39,37 +41,6 @@ def reset_seed():
 pytestmark = tvm.testing.requires_hipblas.marks()


-def build_and_run(mod, inputs_np, target, legalize=False):
-    dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
-        ex = tvm.compile(mod, target)
-        vm = relax.VirtualMachine(ex, dev)
-        f = vm["main"]
-        inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]
-        return f(*inputs).numpy()
-
-
-def get_result_with_relax_cublas_offload(mod, np_inputs):
-    mod = partition_for_hipblas(mod)
-    mod = relax.transform.RunCodegen()(mod)
-
-    return build_and_run(mod, np_inputs, "rocm")
-
-
-def _to_concrete_shape(symbolic_shape, var_table):
-    result = []
-    for dim in symbolic_shape:
-        if not isinstance(dim, tvm.tir.expr.Var):
-            result.append(dim)
-            continue
-
-        if dim not in var_table:
-            var_table[dim] = np.random.randint(10, 50)
-        result.append(var_table[dim])
-
-    return tuple(result)
-
-
 _vars = {
     "a": tvm.tir.expr.Var("a", "int64"),
     "b": tvm.tir.expr.Var("b", "int64"),
@@ -118,39 +89,18 @@ def test_matmul_offload(
     in_dtype,
     out_dtype,
 ):
-    with_bias, activation = _epilogue_table[epilogue]
-    var_table = {}
-    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
-    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
-    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
-    y = np.random.randn(*concrete_y_shape).astype(in_dtype)
-
-    if transpose_y:
-        y = np.swapaxes(y, -2, -1)
-        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
-
-    if with_bias:
-        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
-        args = (x, y, bias)
-    else:
-        bias = None
-        args = (x, y)
-
-    mod = get_relax_matmul_module(
+    run_matmul_offload_test(
         x_shape,
         y_shape,
+        transpose_y,
+        epilogue,
         in_dtype,
         out_dtype,
-        bias_shape=bias.shape if with_bias else None,
-        transposed_y=transpose_y,
-        activation=activation,
+        _epilogue_table,
+        partition_for_hipblas,
+        "rocm",
     )

-    out = get_result_with_relax_cublas_offload(mod, args)
-    ref = build_and_run(mod, args, "llvm", legalize=True)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-

 def test_hipblas_partition_matmul_without_bias():
     # hipBLAS does not handle 2D bias (residual input)
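After this change, exercising either backend outside of pytest parametrization reduces to one call into the shared helper. A hedged sketch against the hipBLAS backend; it assumes a ROCm device with hipBLAS available, the epilogue-table layout sketched earlier, and an import path for `partition_for_hipblas` that may differ between TVM versions:

    # Illustrative direct call into the shared helper (not part of the commit).
    # The import path of partition_for_hipblas varies across TVM versions.
    from tvm.relax.backend.rocm.hipblas import partition_for_hipblas

    from test_codegen_blas_common import run_matmul_offload_test

    _epilogue_table = {"none": (False, None), "bias": (True, None)}

    run_matmul_offload_test(
        (32, 64),       # x_shape
        (64, 128),      # y_shape
        False,          # transpose_y
        "bias",         # epilogue
        "float16",      # in_dtype
        "float16",      # out_dtype
        _epilogue_table,
        partition_for_hipblas,
        "rocm",
    )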
