Conversation
c834302 to
32f2ac3
Compare
transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/include/transformer_engine/transformer_engine.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
bd76659 to
5632303
Compare
ipanfilo
left a comment
There was a problem hiding this comment.
Looks good in general. Few small details
| NVTETensor* workspace, | ||
| bool accumulate, | ||
| hipStream_t stream); | ||
| cudaStream_t stream); |
There was a problem hiding this comment.
Nit: since those files are ROCm specific consider using HIP headers, datatypes and function calls
| ************************************************************************/ | ||
|
|
||
| #include "ck_grouped_gemm_common.h" | ||
| #include "ck_grouped_gemm_fp16_impl.h" |
There was a problem hiding this comment.
It seems packed even better than was expected. Now we have 1-1 link between grouped_gemm_fpXX.cpp and grouped_gemm_fpXX_impl.h. What is rationale behind splitting them?
There was a problem hiding this comment.
Agreed, not much use in splitting them anymore. Combined in 1b00bc0
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda.h> |
There was a problem hiding this comment.
Nit: since those files are ROCm specific consider using HIP headers, datatypes and function calls
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/include/transformer_engine/transformer_engine.h
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h
Outdated
Show resolved
Hide resolved
transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
Outdated
Show resolved
Hide resolved
| static std::unique_ptr<RunnerInterface> make_fp8_runner( | ||
| DType a_dtype, | ||
| DType b_dtype, | ||
| DType d_dtype, | ||
| const GroupedGemmRunContext& ctx) { | ||
| switch (detect_gpu_arch()) { | ||
| case GPUArch::GFX942: | ||
| return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); | ||
| case GPUArch::GFX950: | ||
| return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx); | ||
| default: | ||
| NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); | ||
| return nullptr; | ||
| } | ||
| } |
There was a problem hiding this comment.
There are 4+ small functions here, can they be merged into the dispatch function? (same for fp16)
| #pragma once | ||
|
|
||
| namespace transformer_engine { | ||
| namespace grouped_gemm { | ||
|
|
||
| bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, | ||
| DType b_dtype, | ||
| DType d_dtype, | ||
| const GroupedGemmRunContext& ctx); |
There was a problem hiding this comment.
Can this be moved into _common.h?
There was a problem hiding this comment.
This was how I originally had it, but I think this would contradict Ye's earlier comment because ck_grouped_gemm_fp8.cpp includes _common.h, and doesn't need to see the fp16_dispatch declaration?
There was a problem hiding this comment.
Hi @matthiasdiener , Aristotle was addressing my review comment:
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes https://github.com/ROCm/frameworks-internal/issues/15787
TODO:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: