diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 030023d949..06d85b6d84 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -212,10 +212,22 @@ list(APPEND transformer_engine_cuda_sources list(APPEND transformer_engine_cuda_arch_specific_sources fused_attn/flash_attn.cu activation/gelu.cu + activation/gelu_dbias.cu + activation/gelu_grouped.cu + activation/gelu_grouped_dbias.cu activation/glu.cu activation/relu.cu + activation/relu_dbias.cu + activation/relu_grouped.cu + activation/relu_grouped_dbias.cu activation/swiglu.cu + activation/swiglu_dbias.cu + activation/swiglu_grouped.cu + activation/swiglu_grouped_dbias.cu cast/cast.cu + cast/cast_dbias.cu + cast/cast_grouped.cu + cast/cast_grouped_dbias.cu gemm/cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform.cu hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -447,9 +459,18 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/gelu_dbias.cu + activation/gelu_grouped.cu + activation/gelu_grouped_dbias.cu activation/glu.cu activation/relu.cu - activation/swiglu.cu) + activation/relu_dbias.cu + activation/relu_grouped.cu + activation/relu_grouped_dbias.cu + activation/swiglu.cu + activation/swiglu_dbias.cu + activation/swiglu_grouped.cu + activation/swiglu_grouped_dbias.cu) endif() foreach(cuda_source IN LISTS nvte_sources_with_fast_math) diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index ea864813bf..6bd63672ca 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -13,14 +13,6 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } -void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_gelu); - using namespace transformer_engine; - constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>(input, output, nullptr, - stream); -} - void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); @@ -28,47 +20,6 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } -void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_dgelu); - using namespace transformer_engine; - NVTEGroupedTensor dbias = nullptr; - NVTETensor workspace = nullptr; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - grad, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, - const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTEGroupedTensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; @@ -90,15 +41,6 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } -void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_group_qgelu); - using namespace transformer_engine; - constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>(input, output, nullptr, - stream); -} - void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); @@ -106,47 +48,6 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } -void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_dqgelu); - using namespace transformer_engine; - NVTEGroupedTensor dbias = nullptr; - NVTETensor workspace = nullptr; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - grad, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dqgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, - const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTEGroupedTensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/gelu_dbias.cu b/transformer_engine/common/activation/gelu_dbias.cu new file mode 100644 index 0000000000..4eaa9e355b --- /dev/null +++ b/transformer_engine/common/activation/gelu_dbias.cu @@ -0,0 +1,34 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/gelu_grouped.cu b/transformer_engine/common/activation/gelu_grouped.cu new file mode 100644 index 0000000000..c3267356f8 --- /dev/null +++ b/transformer_engine/common/activation/gelu_grouped.cu @@ -0,0 +1,53 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_gelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + +void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dgelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + +void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_qgelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + +void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dqgelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/gelu_grouped_dbias.cu b/transformer_engine/common/activation/gelu_grouped_dbias.cu new file mode 100644 index 0000000000..e8b549f692 --- /dev/null +++ b/transformer_engine/common/activation/gelu_grouped_dbias.cu @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index fc9122b7ec..57222262f3 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -13,14 +13,6 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } -void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_relu); - using namespace transformer_engine; - constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>(input, output, nullptr, - stream); -} - void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); @@ -28,47 +20,6 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } -void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_drelu); - using namespace transformer_engine; - NVTEGroupedTensor dbias = nullptr; - NVTETensor workspace = nullptr; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - grad, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_drelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, - const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTEGroupedTensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias_drelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -90,15 +41,6 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) act_fn>(input, output, stream); } -void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_group_srelu); - using namespace transformer_engine; - constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>(input, output, nullptr, - stream); -} - void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); @@ -106,47 +48,6 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } -void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_dsrelu); - using namespace transformer_engine; - NVTEGroupedTensor dbias = nullptr; - NVTETensor workspace = nullptr; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - grad, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsrelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, - const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTEGroupedTensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu_dbias.cu b/transformer_engine/common/activation/relu_dbias.cu new file mode 100644 index 0000000000..bd14dc6c9e --- /dev/null +++ b/transformer_engine/common/activation/relu_dbias.cu @@ -0,0 +1,34 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/relu_grouped.cu b/transformer_engine/common/activation/relu_grouped.cu new file mode 100644 index 0000000000..93ce6b82fe --- /dev/null +++ b/transformer_engine/common/activation/relu_grouped.cu @@ -0,0 +1,53 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_relu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + +void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_drelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} + +void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_srelu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + +void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsrelu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/relu_grouped_dbias.cu b/transformer_engine/common/activation/relu_grouped_dbias.cu new file mode 100644 index 0000000000..2b9dcd35d4 --- /dev/null +++ b/transformer_engine/common/activation/relu_grouped_dbias.cu @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 12478af4cf..0b5b6069b6 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -13,14 +13,6 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { act_fn>(input, output, stream); } -void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_silu); - using namespace transformer_engine; - constexpr bool IS_ACT = true; - dispatch::group_quantize_fwd_helper>(input, output, nullptr, - stream); -} - void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); @@ -28,47 +20,6 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } -void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_dsilu); - using namespace transformer_engine; - NVTEGroupedTensor dbias = nullptr; - NVTETensor workspace = nullptr; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - grad, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsilu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, - const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTEGroupedTensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - - dispatch::group_quantize_bwd_helper>( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu_dbias.cu b/transformer_engine/common/activation/swiglu_dbias.cu new file mode 100644 index 0000000000..0e532acc57 --- /dev/null +++ b/transformer_engine/common/activation/swiglu_dbias.cu @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/swiglu_grouped.cu b/transformer_engine/common/activation/swiglu_grouped.cu new file mode 100644 index 0000000000..160ab66288 --- /dev/null +++ b/transformer_engine/common/activation/swiglu_grouped.cu @@ -0,0 +1,30 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_silu); + using namespace transformer_engine; + constexpr bool IS_ACT = true; + dispatch::group_quantize_fwd_helper>(input, output, nullptr, + stream); +} + +void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, + NVTEGroupedTensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dsilu); + using namespace transformer_engine; + NVTEGroupedTensor dbias = nullptr; + NVTETensor workspace = nullptr; + + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + grad, input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/activation/swiglu_grouped_dbias.cu b/transformer_engine/common/activation/swiglu_grouped_dbias.cu new file mode 100644 index 0000000000..83d15e8024 --- /dev/null +++ b/transformer_engine/common/activation/swiglu_grouped_dbias.cu @@ -0,0 +1,22 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, + const NVTEGroupedTensor activation_input, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::group_quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 61cfacd334..1e3c04573b 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -26,15 +26,6 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea dispatch::quantize_fwd_helper(input, output, nullptr, stream); } -void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize); - using namespace transformer_engine; - - constexpr bool IS_ACT = false; - dispatch::group_quantize_fwd_helper(input, output, quant_config, stream); -} - void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_noop); @@ -56,32 +47,6 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, dispatch::quantize_fwd_helper(input, output, quant_config, stream); } -void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = false; - constexpr const NVTETensor activation_input = nullptr; - - dispatch::quantize_bwd_helper( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - -void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_group_quantize_dbias); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = false; - constexpr const NVTEGroupedTensor activation_input = nullptr; - - dispatch::group_quantize_bwd_helper( - input, activation_input, output, dbias, workspace, nullptr, stream); -} - void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; @@ -89,14 +54,6 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str stream); } -void nvte_group_dequantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_group_dequantize); - using namespace transformer_engine; - dispatch::group_dequantize_helper(*convertNVTEGroupedTensorCheck(input), - convertNVTEGroupedTensorCheck(output), stream); -} - void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_configs, const size_t num_tensors, cudaStream_t stream) { @@ -130,19 +87,3 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); } } - -// Group quantize assumes contiguous inputs and outputs in memory allocation -// Note: this API assumes knowing split sections from the host, if split information -// comes from D2H copy, it will break cuda graph capture -void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, - const size_t *split_sections, const size_t num_tensors, - const NVTEQuantizationConfig quant_config, - cudaStream_t stream) { - NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax); - using namespace transformer_engine; - - constexpr bool IS_ACT = false; - - dispatch::group_quantize_fwd_host_aware_helper( - input, outputs, split_sections, num_tensors, quant_config, stream); -} diff --git a/transformer_engine/common/cast/cast_dbias.cu b/transformer_engine/common/cast/cast_dbias.cu new file mode 100644 index 0000000000..480e8ca744 --- /dev/null +++ b/transformer_engine/common/cast/cast_dbias.cu @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include "../common.h" +#include "dispatch/quantize.cuh" + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTETensor activation_input = nullptr; + + dispatch::quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/cast/cast_grouped.cu b/transformer_engine/common/cast/cast_grouped.cu new file mode 100644 index 0000000000..853634c811 --- /dev/null +++ b/transformer_engine/common/cast/cast_grouped.cu @@ -0,0 +1,47 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include "../common.h" +#include "dispatch/dequantize.cuh" +#include "dispatch/quantize.cuh" + +void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::group_quantize_fwd_helper(input, output, quant_config, stream); +} + +void nvte_group_dequantize(const NVTEGroupedTensor input, NVTEGroupedTensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_dequantize); + using namespace transformer_engine; + dispatch::group_dequantize_helper(*convertNVTEGroupedTensorCheck(input), + convertNVTEGroupedTensorCheck(output), stream); +} + +// Group quantize assumes contiguous inputs and outputs in memory allocation. +// Note: this API assumes knowing split sections from the host. If split information +// comes from D2H copy, it will break cuda graph capture. +void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + dispatch::group_quantize_fwd_host_aware_helper( + input, outputs, split_sections, num_tensors, quant_config, stream); +} diff --git a/transformer_engine/common/cast/cast_grouped_dbias.cu b/transformer_engine/common/cast/cast_grouped_dbias.cu new file mode 100644 index 0000000000..5290255a00 --- /dev/null +++ b/transformer_engine/common/cast/cast_grouped_dbias.cu @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include "../common.h" +#include "dispatch/quantize.cuh" + +void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTEGroupedTensor activation_input = nullptr; + + dispatch::group_quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 90e57a6fe8..3e6eb55b73 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -17,6 +17,7 @@ #include #include "../../common.h" +#include "../../util/ptx.cuh" #include "../../utils.cuh" namespace transformer_engine {