Add FP8 Support For CK Tile Group GEMM #475

Open
aris134 wants to merge 7 commits into dev from amartin/ck-grouped-gemm-fp8
@aris134 aris134 commented Mar 6, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes https://github.com/ROCm/frameworks-internal/issues/15787

TODO:

  • Add support for gfx942/gfx950
  • Performance analysis and tuning

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enables mixed precision (fp8/bf8) support for CK tile grouped GEMM with tensor quantization on gfx942/gfx950

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@aris134 aris134 self-assigned this Mar 6, 2026
@aris134 aris134 force-pushed the amartin/ck-grouped-gemm-fp8 branch 2 times, most recently from c834302 to 32f2ac3 Compare March 11, 2026 15:51
@aris134 aris134 marked this pull request as ready for review March 24, 2026 16:07
@aris134 aris134 requested a review from ipanfilo March 27, 2026 15:59
@aris134 aris134 requested a review from ipanfilo March 30, 2026 20:26
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label Mar 31, 2026
@matthiasdiener matthiasdiener self-requested a review March 31, 2026 18:51
@aris134 aris134 requested a review from matthiasdiener March 31, 2026 22:08
@aris134 aris134 requested a review from matthiasdiener April 1, 2026 14:37
@aris134 aris134 requested a review from matthiasdiener April 1, 2026 18:28
@aris134 aris134 requested a review from matthiasdiener April 1, 2026 19:37
Contributor

@matthiasdiener matthiasdiener left a comment

LGTM!

@aris134 aris134 force-pushed the amartin/ck-grouped-gemm-fp8 branch from bd76659 to 5632303 Compare April 2, 2026 15:08
Collaborator

@ipanfilo ipanfilo left a comment

Looks good in general. A few small details:

  NVTETensor* workspace,
  bool accumulate,
- hipStream_t stream);
+ cudaStream_t stream);
Collaborator

Nit: since these files are ROCm-specific, consider using HIP headers, datatypes, and function calls.

Author

Done in 1b00bc0

************************************************************************/

#include "ck_grouped_gemm_common.h"
#include "ck_grouped_gemm_fp16_impl.h"
Collaborator

This seems packed even better than expected. Now we have a 1-to-1 link between grouped_gemm_fpXX.cpp and grouped_gemm_fpXX_impl.h. What is the rationale behind splitting them?

Author

Agreed, not much use in splitting them anymore. Combined in 1b00bc0


#pragma once

#include <cuda.h>
Collaborator

Nit: since these files are ROCm-specific, consider using HIP headers, datatypes, and function calls.

Author

Done in 1b00bc0

Comment on lines +372 to +386
static std::unique_ptr<RunnerInterface> make_fp8_runner(
    DType a_dtype,
    DType b_dtype,
    DType d_dtype,
    const GroupedGemmRunContext& ctx) {
  switch (detect_gpu_arch()) {
    case GPUArch::GFX942:
      return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx);
    case GPUArch::GFX950:
      return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx);
    default:
      NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}");
      return nullptr;
  }
}
Contributor

@matthiasdiener matthiasdiener Apr 3, 2026

There are 4+ small functions here, can they be merged into the dispatch function? (same for fp16)

Author

Done in 391f22a
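
For readers following the thread, here is a rough sketch of what folding the per-architecture helpers into one dispatch function could look like. It is not the code from 391f22a, which is not shown here. The types below are hypothetical stand-ins for the PR's `GPUArch`, `DType`, `GroupedGemmRunContext`, and `RunnerInterface`; `Fp8Runner` and `fp8_dispatch_sketch` are invented names; the architecture is passed as a parameter instead of calling `detect_gpu_arch()`; and a plain exception stands in for `NVTE_ERROR`, so the sketch compiles without ROCm.

```cpp
#include <memory>
#include <stdexcept>

// Hypothetical stand-ins: the real types live in the PR and are not shown in this thread.
enum class GPUArch { GFX942, GFX950, Unknown };
enum class DType { kFloat8E4M3, kFloat8E5M2, kBFloat16 };
struct GroupedGemmRunContext {};

struct RunnerInterface {
  virtual ~RunnerInterface() = default;
  virtual bool run(const GroupedGemmRunContext& ctx) = 0;
};

// Placeholder runner; a real one would launch the CK Tile kernel tuned for Arch.
template <GPUArch Arch>
struct Fp8Runner final : RunnerInterface {
  Fp8Runner(DType a, DType b, DType d) : a_(a), b_(b), d_(d) {}
  bool run(const GroupedGemmRunContext&) override { return true; }
  DType a_, b_, d_;
};

// Single entry point: the architecture switch and the runner construction happen
// inline, instead of going through separate make_fp8_runner* helpers.
// `arch` is a parameter here (an assumption of this sketch) so it runs on any host.
bool fp8_dispatch_sketch(GPUArch arch, DType a, DType b, DType d,
                         const GroupedGemmRunContext& ctx) {
  std::unique_ptr<RunnerInterface> runner;
  switch (arch) {
    case GPUArch::GFX942:
      runner = std::make_unique<Fp8Runner<GPUArch::GFX942>>(a, b, d);
      break;
    case GPUArch::GFX950:
      runner = std::make_unique<Fp8Runner<GPUArch::GFX950>>(a, b, d);
      break;
    default:
      // Stand-in for NVTE_ERROR in the original snippet.
      throw std::runtime_error(
          "ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}");
  }
  return runner->run(ctx);
}
```

With this shape, an unsupported architecture fails in one place and the caller sees only one function per precision; the same pattern would presumably apply to the fp16 path.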

Comment on lines +7 to +15
#pragma once

namespace transformer_engine {
namespace grouped_gemm {

bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype,
                                        DType b_dtype,
                                        DType d_dtype,
                                        const GroupedGemmRunContext& ctx);
Contributor

Can this be moved into _common.h?

Author

@aris134 aris134 Apr 3, 2026

This was how I originally had it, but I think this would contradict Ye's earlier comment because ck_grouped_gemm_fp8.cpp includes _common.h, and doesn't need to see the fp16_dispatch declaration?

Collaborator

Hi @matthiasdiener, Aristotle was addressing my review comment:

Image

Labels

ci-level 1 CI test level 1

4 participants