Skip to content
14 changes: 12 additions & 2 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2238,13 +2238,23 @@ def test_grouped_linear_accuracy(
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize(
"fp8_model_params",
all_boolean if IS_HIP_EXTENSION else [False],
)
@pytest.mark.parametrize(
"recipe",
(fp8_recipes + [None]) if IS_HIP_EXTENSION else [None],
)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
delay_wgrad_compute,
):
Expand All @@ -2254,8 +2264,8 @@ def test_grouped_linear_accuracy_cutlass(
num_gemms,
bs,
model,
None,
False,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
False,
delay_wgrad_compute,
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ else()
comm_gemm_overlap/rocm_comm_gemm_overlap.cpp
fused_attn_rocm/fused_attn.cpp
gemm/rocm_gemm.cu
gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
338 changes: 0 additions & 338 deletions transformer_engine/common/gemm/ck_grouped_gemm.cpp

This file was deleted.

Loading
Loading