diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index 56880a428d..5e73675f4f 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -25,6 +25,7 @@ add_executable(test_operator
                test_normalization.cu
                test_normalization_mxfp8.cu
                test_memset.cu
+               test_splits_to_offsets.cu
                test_multi_cast_transpose.cu
                test_multi_padding.cu
                test_multi_unpadding.cu
diff --git a/tests/cpp/operator/test_splits_to_offsets.cu b/tests/cpp/operator/test_splits_to_offsets.cu
new file mode 100644
index 0000000000..faac4b7b6f
--- /dev/null
+++ b/tests/cpp/operator/test_splits_to_offsets.cu
@@ -0,0 +1,80 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cstdint>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+
+#include <transformer_engine/transformer_engine.h>
+#include "../test_common.h"
+
+class SplitsToOffsetsTestSuite : public ::testing::TestWithParam<std::tuple<size_t, int64_t>> {};
+
+TEST_P(SplitsToOffsetsTestSuite, TestSplitsToOffsets) {
+  const size_t num_tensors = std::get<0>(GetParam());
+  const int64_t logical_last_dim = std::get<1>(GetParam());
+
+  std::vector<int64_t> h_first_dims(num_tensors);
+  for (size_t i = 0; i < num_tensors; ++i) {
+    h_first_dims[i] = static_cast<int64_t>((i % 17) + 1);
+  }
+
+  std::vector<int64_t> h_expected(num_tensors + 1, 0);
+  for (size_t i = 0; i < num_tensors; ++i) {
+    h_expected[i + 1] = h_expected[i] + h_first_dims[i] * logical_last_dim;
+  }
+
+  std::vector<int64_t> h_output(num_tensors + 1, -1);
+
+  int64_t *d_first_dims = nullptr;
+  int64_t *d_output = nullptr;
+  NVTE_CHECK_CUDA(cudaMalloc(&d_first_dims, sizeof(int64_t) * num_tensors));
+  NVTE_CHECK_CUDA(cudaMalloc(&d_output, sizeof(int64_t) * (num_tensors + 1)));
+  NVTE_CHECK_CUDA(cudaMemcpy(d_first_dims, h_first_dims.data(), sizeof(int64_t) * num_tensors,
+                             cudaMemcpyHostToDevice));
+
+  nvte_splits_to_offsets(d_first_dims, d_output, num_tensors, logical_last_dim, 0 /* stream */);
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
+
+  NVTE_CHECK_CUDA(cudaMemcpy(h_output.data(), d_output, sizeof(int64_t) * (num_tensors + 1),
+                             cudaMemcpyDeviceToHost));
+
+  NVTE_CHECK_CUDA(cudaFree(d_first_dims));
+  NVTE_CHECK_CUDA(cudaFree(d_output));
+
+  for (size_t i = 0; i < h_output.size(); ++i) {
+    EXPECT_EQ(h_output[i], h_expected[i])
+        << "Mismatch at index " << i << ": expected " << h_expected[i] << ", got " << h_output[i];
+  }
+}
+
+namespace {
+
+std::vector<size_t> splits_to_offsets_num_tensors = {
+    1,
+    4,
+    255,
+    256,
+    257,
+    1024,
+};
+
+}  // namespace
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest, SplitsToOffsetsTestSuite,
+    ::testing::Combine(::testing::ValuesIn(splits_to_offsets_num_tensors),
+                       ::testing::Values(static_cast<int64_t>(1), static_cast<int64_t>(7),
+                                         static_cast<int64_t>(128))),
+    [](const testing::TestParamInfo<SplitsToOffsetsTestSuite::ParamType> &info) {
+      std::string name = std::to_string(std::get<0>(info.param)) + "X" +
+                         std::to_string(std::get<1>(info.param));
+      return name;
+    });
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
index 0ec40dc01c..1bdd80a369 100644
--- a/transformer_engine/common/common.cu
+++ b/transformer_engine/common/common.cu
@@ -87,6 +87,48 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   reinterpret_cast(ptr)[idx] = data.value;
 }
 
+__global__ void __launch_bounds__(kThreadsPerBlock)
+    splits_to_offsets_kernel(const int64_t *__restrict__ first_dims, int64_t *__restrict__ output,
+                             size_t num_tensors, int64_t logical_last_dim) {
+  __shared__ int64_t block_scan[kThreadsPerBlock];
+  __shared__ int64_t chunk_prefix;
+
+  const size_t tid = threadIdx.x;
+  if (tid == 0) {
+    output[0] = 0;
+    chunk_prefix = 0;
+  }
+  __syncthreads();
+
+  for (size_t chunk_start = 0; chunk_start < num_tensors; chunk_start += kThreadsPerBlock) {
+    const size_t idx = chunk_start + tid;
+    int64_t value = 0;
+    if (idx < num_tensors) {
+      value = first_dims[idx] * logical_last_dim;
+    }
+    block_scan[tid] = value;
+    __syncthreads();
+
+    // Inclusive scan in shared memory.
+    for (size_t offset = 1; offset < kThreadsPerBlock; offset <<= 1) {
+      const int64_t addend = (tid >= offset) ? block_scan[tid - offset] : 0;
+      __syncthreads();
+      block_scan[tid] += addend;
+      __syncthreads();
+    }
+
+    if (idx < num_tensors) {
+      output[idx + 1] = chunk_prefix + block_scan[tid];
+    }
+    __syncthreads();
+
+    if (tid == kThreadsPerBlock - 1) {
+      chunk_prefix += block_scan[tid];
+    }
+    __syncthreads();
+  }
+}
+
 }  // namespace
 
 #define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
@@ -116,6 +158,19 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream
   MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
   MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
 }
+
+void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors,
+                            int64_t logical_last_dim, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_splits_to_offsets);
+  NVTE_CHECK(output != nullptr, "Output pointer must be allocated.");
+  NVTE_CHECK(num_tensors > 0, "num_tensors must be greater than 0.");
+  NVTE_CHECK(first_dims != nullptr, "first_dims pointer must be allocated.");
+  NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0.");
+
+  splits_to_offsets_kernel<<<1, kThreadsPerBlock, 0, stream>>>(first_dims, output, num_tensors,
+                                                               logical_last_dim);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
 }  // extern "C"
 
 void checkCuDriverContext(CUstream stream) {
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index e316f8be8c..b7461a85d1 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -427,6 +427,22 @@ int nvte_is_non_tn_fp8_gemm_supported();
  */
 void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
 
+/*! \brief Compute scaled prefix-sum offsets for grouped tensors.
+ *
+ * Computes:
+ *   output[0] = 0
+ *   output[i + 1] = sum_{j=0..i}(first_dims[j] * logical_last_dim)
+ *   for i in [0, num_tensors - 1].
+ *
+ *  \param[in]  first_dims        Pointer to device int64 array of size num_tensors.
+ *  \param[out] output            Pointer to device int64 array of size num_tensors + 1.
+ *  \param[in]  num_tensors       Number of entries in first_dims.
+ *  \param[in]  logical_last_dim  Scale factor applied to each first_dims entry.
+ *  \param[in]  stream            CUDA stream to use for the operation.
+ */
+void nvte_splits_to_offsets(const int64_t *first_dims, int64_t *output, size_t num_tensors,
+                            int64_t logical_last_dim, cudaStream_t stream);
+
 /*! \brief TE Grouped Tensor type
  *
  * NVTEGroupedTensor is a collection of tensors with potentially different shapes
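For reference, the contract documented for nvte_splits_to_offsets above is an exclusive prefix sum of first_dims scaled by logical_last_dim. The host-side sketch below illustrates it with made-up sample values (first_dims = {2, 3, 5}, logical_last_dim = 4); it is illustrative only and not part of the patch.

// Illustrative host-side reference of the nvte_splits_to_offsets contract.
// Sample inputs are hypothetical; expected output is {0, 8, 20, 40}.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> first_dims = {2, 3, 5};
  const int64_t logical_last_dim = 4;

  // output[0] = 0; output[i + 1] = output[i] + first_dims[i] * logical_last_dim
  std::vector<int64_t> offsets(first_dims.size() + 1, 0);
  for (size_t i = 0; i < first_dims.size(); ++i) {
    offsets[i + 1] = offsets[i] + first_dims[i] * logical_last_dim;
  }

  for (const int64_t v : offsets) {
    std::printf("%lld ", static_cast<long long>(v));  // prints: 0 8 20 40
  }
  std::printf("\n");
  return 0;
}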
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index e4d4e5094c..078b384a8b 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -435,6 +435,8 @@ size_t get_cublasLt_version();
 
 size_t get_cudnn_version();
 
+at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim);
+
 /***************************************************************************************************
  * Support THD format for Context Parallel
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp
index d667a61d44..c5707fa53c 100644
--- a/transformer_engine/pytorch/csrc/extensions/misc.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp
@@ -12,4 +12,22 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }
 
 size_t get_cudnn_version() { return cudnnGetVersion(); }
 
+at::Tensor splits_to_offsets(const at::Tensor &first_dims, int64_t logical_last_dim) {
+  NVTE_CHECK(first_dims.is_cuda(), "first_dims must be on CUDA.");
+  NVTE_CHECK(first_dims.scalar_type() == at::kLong, "first_dims must have dtype int64.");
+  NVTE_CHECK(first_dims.dim() == 1, "first_dims must be a 1D tensor.");
+  NVTE_CHECK(logical_last_dim > 0, "logical_last_dim must be greater than 0.");
+
+  auto first_dims_contiguous = first_dims.contiguous();
+  const auto num_tensors = static_cast<size_t>(first_dims_contiguous.numel());
+  auto output = at::empty({static_cast<int64_t>(num_tensors) + 1},
+                          first_dims_contiguous.options().dtype(at::kLong));
+
+  nvte_splits_to_offsets(static_cast<const int64_t *>(first_dims_contiguous.data_ptr()),
+                         static_cast<int64_t *>(output.data_ptr()), num_tensors, logical_last_dim,
+                         at::cuda::getCurrentCUDAStream());
+
+  return output;
+}
+
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 8302a13010..5a12c885f9 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -445,6 +445,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
   m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
+  m.def("splits_to_offsets", &transformer_engine::pytorch::splits_to_offsets,
+        "Compute grouped tensor offsets from split sizes", py::arg("first_dims"),
+        py::arg("logical_last_dim"), py::call_guard<py::gil_scoped_release>());
   m.def("get_num_cublas_streams", &nvte_get_num_compute_streams, "Get number of compute streams",
         py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 27dc87697f..c904057e97 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -73,11 +73,15 @@ std::optional<at::Tensor> build_grouped_tensor_offsets(const size_t num_tensors,
              "first_dims must have length ", num_tensors, ".");
 
   const int64_t logical_last_dim_i64 = static_cast<int64_t>(logical_last_dim);
-  auto scaled_first_dims = (first_dims_tensor * logical_last_dim_i64).contiguous();
-  // Single kernel needed for these ops.
-  auto cumsum = at::cumsum(scaled_first_dims, 0);
-  auto zero = at::zeros({1}, cumsum.options());
-  return at::cat({zero, cumsum});
+  const auto first_dims_contiguous = first_dims_tensor.contiguous();
+  auto tensor_offsets =
+      at::empty({static_cast<int64_t>(num_tensors) + 1}, first_dims_contiguous.options());
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_splits_to_offsets(static_cast<const int64_t *>(first_dims_contiguous.data_ptr()),
+                           static_cast<int64_t *>(tensor_offsets.data_ptr()), num_tensors,
+                           logical_last_dim_i64, at::cuda::getCurrentCUDAStream());
+  });
+  return tensor_offsets;
 }
 
 at::TensorOptions grouped_tensor_data_options(const DType dtype) {
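Note on the quantizer.cpp hunk above: it replaces three ATen calls (scale, at::cumsum, and at::cat with a leading zero) with a single kernel launch on the current stream, avoiding the intermediate tensors and extra launches. The sketch below is a hypothetical standalone equivalence check, assuming a CUDA build of libtorch plus the header added by this patch; it is not part of the patch.

// Hypothetical equivalence check (not part of this patch): the removed ATen
// composition and the new nvte_splits_to_offsets kernel should agree.
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <transformer_engine/transformer_engine.h>

// Old path: scale, cumulative sum, then prepend a zero.
at::Tensor offsets_via_aten(const at::Tensor &first_dims, int64_t logical_last_dim) {
  auto scaled = (first_dims * logical_last_dim).contiguous();
  auto cumsum = at::cumsum(scaled, 0);
  auto zero = at::zeros({1}, cumsum.options());
  return at::cat({zero, cumsum});
}

// New path: a single kernel launch on the current CUDA stream.
at::Tensor offsets_via_kernel(const at::Tensor &first_dims, int64_t logical_last_dim) {
  auto contig = first_dims.contiguous();
  auto out = at::empty({contig.numel() + 1}, contig.options());
  nvte_splits_to_offsets(static_cast<const int64_t *>(contig.data_ptr()),
                         static_cast<int64_t *>(out.data_ptr()),
                         static_cast<size_t>(contig.numel()), logical_last_dim,
                         at::cuda::getCurrentCUDAStream());
  return out;
}

int main() {
  auto first_dims = torch::tensor({2, 3, 5}, torch::dtype(torch::kLong).device(torch::kCUDA));
  auto a = offsets_via_aten(first_dims, /*logical_last_dim=*/4);
  auto b = offsets_via_kernel(first_dims, 4);
  TORCH_CHECK(a.equal(b), "offsets mismatch");  // both should be [0, 8, 20, 40]
  return 0;
}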