Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,22 @@ list(APPEND transformer_engine_cuda_sources
list(APPEND transformer_engine_cuda_arch_specific_sources
fused_attn/flash_attn.cu
activation/gelu.cu
activation/gelu_dbias.cu
activation/gelu_grouped.cu
activation/gelu_grouped_dbias.cu
activation/glu.cu
activation/relu.cu
activation/relu_dbias.cu
activation/relu_grouped.cu
activation/relu_grouped_dbias.cu
activation/swiglu.cu
activation/swiglu_dbias.cu
activation/swiglu_grouped.cu
activation/swiglu_grouped_dbias.cu
cast/cast.cu
cast/cast_dbias.cu
cast/cast_grouped.cu
cast/cast_grouped_dbias.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
Expand Down Expand Up @@ -447,9 +459,18 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/gelu_dbias.cu
activation/gelu_grouped.cu
activation/gelu_grouped_dbias.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu)
activation/relu_dbias.cu
activation/relu_grouped.cu
activation/relu_grouped_dbias.cu
activation/swiglu.cu
activation/swiglu_dbias.cu
activation/swiglu_grouped.cu
activation/swiglu_grouped_dbias.cu)
endif()

foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
Expand Down
99 changes: 0 additions & 99 deletions transformer_engine/common/activation/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,13 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_gelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
Expand All @@ -90,63 +41,13 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_qgelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
Expand Down
34 changes: 34 additions & 0 deletions transformer_engine/common/activation/gelu_dbias.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

/* Fused backward pass: applies dGELU to the incoming gradient, quantizes the
 * result into `output`, and reduces a bias gradient into `dbias`.
 * `workspace` provides scratch space for the dbias reduction; sizing is
 * handled by the dispatch helper (presumably via a workspace-query call —
 * TODO confirm against the helper's contract). */
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;
  // Both dbias reduction and the activation backward are enabled in this entry point.
  dispatch::quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty, dgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}

/* Fused backward pass for the quick-GELU variant: applies dqGELU to the
 * incoming gradient, quantizes into `output`, and reduces a bias gradient
 * into `dbias` using `workspace` as scratch. Mirrors
 * nvte_quantize_dbias_dgelu but dispatches the dqgelu functor. */
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  // dbias reduction and activation backward both enabled.
  dispatch::quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty, dqgelu<fp32, fp32>>(
      input, activation_input, output, dbias, workspace, nullptr, stream);
}
53 changes: 53 additions & 0 deletions transformer_engine/common/activation/gelu_grouped.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

/* Grouped GELU forward: applies GELU element-wise to every member of the
 * grouped input tensor and writes (possibly quantized) results to `output`.
 * The nullptr argument is an unused auxiliary slot of the dispatch helper. */
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_gelu);
  using namespace transformer_engine;
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, gelu<fp32, fp32>>(input, output,
                                                                               nullptr, stream);
}

/* Grouped dGELU backward: applies the GELU derivative to `grad` (evaluated at
 * `input`) and writes the result to `output`. No dbias is produced, so the
 * dbias tensor and workspace are passed as null. */
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dgelu);
  using namespace transformer_engine;
  // This entry point performs no bias-gradient reduction.
  NVTEGroupedTensor no_dbias = nullptr;
  NVTETensor no_workspace = nullptr;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dgelu<fp32, fp32>>(grad, input, output, no_dbias,
                                                         no_workspace, nullptr, stream);
}

/* Grouped quick-GELU forward: element-wise qGELU over every member of the
 * grouped input, written (possibly quantized) to `output`. Mirrors
 * nvte_group_gelu with the qgelu functor. */
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_qgelu);
  using namespace transformer_engine;
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, qgelu<fp32, fp32>>(input, output,
                                                                                nullptr, stream);
}

/* Grouped dqGELU backward: applies the quick-GELU derivative to `grad`
 * (evaluated at `input`) and writes the result to `output`. No bias gradient
 * is computed, so dbias and workspace are null. */
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dqgelu);
  using namespace transformer_engine;
  // No bias-gradient reduction in this entry point.
  NVTEGroupedTensor no_dbias = nullptr;
  NVTETensor no_workspace = nullptr;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dqgelu<fp32, fp32>>(grad, input, output, no_dbias,
                                                          no_workspace, nullptr, stream);
}
36 changes: 36 additions & 0 deletions transformer_engine/common/activation/gelu_grouped_dbias.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

/* Grouped fused backward: dGELU applied to the incoming gradient, quantized
 * into `output`, with a bias gradient reduced into `dbias` per group.
 * `workspace` is scratch for the reduction. */
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
  using namespace transformer_engine;
  // Both the dbias reduction and the activation backward are enabled here.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dgelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                         workspace, nullptr, stream);
}

/* Grouped fused backward for quick-GELU: dqGELU applied to the incoming
 * gradient, quantized into `output`, with a per-group bias gradient reduced
 * into `dbias`. Mirrors nvte_group_quantize_dbias_dgelu with dqgelu. */
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  // dbias reduction and activation backward both enabled.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dqgelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                          workspace, nullptr, stream);
}
Loading
Loading