1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
@@ -25,6 +25,7 @@ add_executable(test_operator
test_normalization.cu
test_normalization_mxfp8.cu
test_memset.cu
test_splits_to_offsets.cu
test_multi_cast_transpose.cu
test_multi_padding.cu
test_multi_unpadding.cu
80 changes: 80 additions & 0 deletions tests/cpp/operator/test_splits_to_offsets.cu
@@ -0,0 +1,80 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

class SplitsToOffsetsTestSuite : public ::testing::TestWithParam<std::tuple<size_t, int64_t>> {};

TEST_P(SplitsToOffsetsTestSuite, TestSplitsToOffsets) {
const size_t num_tensors = std::get<0>(GetParam());
const int64_t logical_last_dim = std::get<1>(GetParam());

std::vector<int64_t> h_first_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
h_first_dims[i] = static_cast<int64_t>((i % 17) + 1);
}

std::vector<int64_t> h_expected(num_tensors + 1, 0);
for (size_t i = 0; i < num_tensors; ++i) {
h_expected[i + 1] = h_expected[i] + h_first_dims[i] * logical_last_dim;
}

std::vector<int64_t> h_output(num_tensors + 1, -1);

int64_t *d_first_dims = nullptr;
int64_t *d_output = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_first_dims, sizeof(int64_t) * num_tensors));
NVTE_CHECK_CUDA(cudaMalloc(&d_output, sizeof(int64_t) * (num_tensors + 1)));
NVTE_CHECK_CUDA(cudaMemcpy(d_first_dims, h_first_dims.data(), sizeof(int64_t) * num_tensors,
cudaMemcpyHostToDevice));

nvte_splits_to_offsets(d_first_dims, d_output, num_tensors, logical_last_dim, 0 /* stream */);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());

NVTE_CHECK_CUDA(cudaMemcpy(h_output.data(), d_output, sizeof(int64_t) * (num_tensors + 1),
cudaMemcpyDeviceToHost));

NVTE_CHECK_CUDA(cudaFree(d_first_dims));
NVTE_CHECK_CUDA(cudaFree(d_output));

for (size_t i = 0; i < h_output.size(); ++i) {
EXPECT_EQ(h_output[i], h_expected[i])
<< "Mismatch at index " << i << ": expected " << h_expected[i] << ", got " << h_output[i];
}
}

namespace {

std::vector<size_t> splits_to_offsets_num_tensors = {
1,
4,
255,
256,
257,
1024,
};

} // namespace

INSTANTIATE_TEST_SUITE_P(
OperatorTest, SplitsToOffsetsTestSuite,
::testing::Combine(::testing::ValuesIn(splits_to_offsets_num_tensors),
::testing::Values(static_cast<int64_t>(1), static_cast<int64_t>(7),
static_cast<int64_t>(128))),
[](const testing::TestParamInfo<SplitsToOffsetsTestSuite::ParamType> &info) {
std::string name = std::to_string(std::get<0>(info.param)) + "X" +
std::to_string(std::get<1>(info.param));
return name;
});
55 changes: 55 additions & 0 deletions transformer_engine/common/common.cu
@@ -87,6 +87,48 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}

__global__ void __launch_bounds__(kThreadsPerBlock)
splits_to_offsets_kernel(const int64_t *__restrict__ first_dims, int64_t *__restrict__ output,
size_t num_tensors, int64_t logical_last_dim) {
__shared__ int64_t block_scan[kThreadsPerBlock];
__shared__ int64_t chunk_prefix;

const size_t tid = threadIdx.x;
if (tid == 0) {
output[0] = 0;
chunk_prefix = 0;
}
__syncthreads();

for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) {
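    // Each iteration handles one chunk of up to kThreadsPerBlock splits;
    // chunk_prefix carries the running offset total from earlier chunks.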
const size_t idx = chunk_start + tid;
int64_t value = 0;
if (idx < num_tensors) {
value = first_dims[idx] * logical_last_dim;
}
block_scan[tid] = value;
__syncthreads();

// Inclusive scan in shared memory.
for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) {
const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0;
__syncthreads();
block_scan[tid] += addend;
__syncthreads();
}

if (idx < num_tensors) {
output[idx + 1] = chunk_prefix + block_scan[tid];
}
__syncthreads();

if (tid == kThreadsPerBlock - 1) {
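      // The last slot of the inclusive scan is this chunk's total; fold it into the running prefix.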
chunk_prefix += block_scan[tid];
}
__syncthreads();
}
}

} // namespace

#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
@@ -116,6 +158,19 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}

void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors,
int64_t logical_last_dim, cudaStream_t stream) {
NVTE_API_CALL(nvte_splits_to_offsets);
NVTE_CHECK(output != nullptr, "Output pointer must be allocated.");
NVTE_CHECK(num_tensors > 0, "num_tensors must be greater than 0.");
NVTE_CHECK(first_dims != nullptr, "first_dims pointer must be allocated.");
NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0.");

splits_to_offsets_kernel<<<1, kThreadsPerBlock, 0, stream>>>(first_dims, output, num_tensors,
logical_last_dim);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // extern "C"

void checkCuDriverContext(CUstream stream) {
@@ -427,6 +427,22 @@ int nvte_is_non_tn_fp8_gemm_supported();
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);

/*! \brief Compute scaled prefix-sum offsets for grouped tensors.
*
* Computes:
* output[0] = 0
* output[i + 1] = sum_{j=0..i}(first_dims[j] * logical_last_dim)
* for i in [0, num_tensors - 1].
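 *
 * For example, first_dims = {2, 3, 1} with logical_last_dim = 4 gives
 * output = {0, 8, 20, 24}.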
*
* \param[in] first_dims Pointer to device int64 array of size num_tensors.
* \param[out] output Pointer to device int64 array of size num_tensors + 1.
* \param[in] num_tensors Number of entries in first_dims.
* \param[in] logical_last_dim Scale factor applied to each first_dims entry.
* \param[in] stream CUDA stream to use for the operation.
*/
void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors,
int64_t logical_last_dim, cudaStream_t stream);

/*! \brief TE Grouped Tensor type
*
* NVTEGroupedTensor is a collection of tensors with potentially different shapes
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -435,6 +435,8 @@ size_t get_cublasLt_version();

size_t get_cudnn_version();

at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim);

/***************************************************************************************************
* Support THD format for Context Parallel
**************************************************************************************************/
18 changes: 18 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/misc.cpp
@@ -12,4 +12,22 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }

size_t get_cudnn_version() { return cudnnGetVersion(); }

at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim) {
NVTE_CHECK(first_dims.is_cuda(), "first_dims must be on CUDA.");
NVTE_CHECK(first_dims.scalar_type() == at::kLong, "first_dims must have dtype int64.");
NVTE_CHECK(first_dims.dim() == 1, "first_dims must be a 1D tensor.");
NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0.");

auto first_dims_contiguous = first_dims.contiguous();
const auto num_tensors = static_cast<size_t>(first_dims_contiguous.numel());
auto output = at::empty({static_cast<int64_t>(num_tensors) + 1},
first_dims_contiguous.options().dtype(at::kLong));

nvte_splits_to_offsets(static_cast<const int64_t *>(first_dims_contiguous.data_ptr()),
static_cast<int64_t *>(output.data_ptr()), num_tensors, logical_last_dim,
at::cuda::getCurrentCUDAStream());

return output;
}

} // namespace transformer_engine::pytorch
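
For illustration, a minimal sketch of how this extension entry point could be driven from other C++ extension code; the wrapper function, example values, and include path are hypothetical and not part of this change:

#include <torch/torch.h>

#include "extensions.h"  // declares transformer_engine::pytorch::splits_to_offsets

// Hypothetical caller: three splits with first dims {2, 3, 1} and a logical
// last dimension of 4 should produce offsets {0, 8, 20, 24} on the device.
at::Tensor example_offsets() {
  at::Tensor first_dims =
      torch::tensor({2, 3, 1}, torch::dtype(torch::kLong).device(torch::kCUDA));
  return transformer_engine::pytorch::splits_to_offsets(first_dims, /*logical_last_dim=*/4);
}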
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -445,6 +445,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets,
"Compute grouped tensor offsets from split sizes", py::arg("first_dims"),
py::arg("logical_last_dim"), py::call_guard<py::gil_scoped_release>());
m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
py::call_guard<py::gil_scoped_release>());

14 changes: 9 additions & 5 deletions transformer_engine/pytorch/csrc/quantizer.cpp
@@ -73,11 +73,15 @@ std::optional<at::Tensor> build_grouped_tensor_offsets(const size_t num_tensors,
"first_dims must have length ", num_tensors, ".");

const int64_t logical_last_dim_i64 = static_cast<int64_t>(logical_last_dim);
-  auto scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous();
-  // Single kernel needed for these ops.
-  auto cumsum = at::cumsum(scaled_first_dims, 0);
-  auto zero = at::zeros({1}, cumsum.options());
-  return at::cat({zero, cumsum});
+  const auto first_dims_contiguous = first_dims_tensor.contiguous();
+  auto tensor_offsets =
+      at::empty({static_cast<int64_t>(num_tensors) + 1}, first_dims_contiguous.options());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_splits_to_offsets(static_cast<const int64_t*>(first_dims_contiguous.data_ptr()),
+                           static_cast<int64_t*>(tensor_offsets.data_ptr()), num_tensors,
+                           logical_last_dim_i64, at::cuda::getCurrentCUDAStream());
+  });
+  return tensor_offsets;
}

at::TensorOptions grouped_tensor_data_options(const DType dtype) {