Fused kernel for calculating offsets from first dim splits #2755

Open
ksivaman wants to merge 6 commits into NVIDIA:main from ksivaman:fused_mul_zero_cumsum_kernel

Conversation

@ksivaman
Member

Description

Introduce a single kernel that calculates offsets from the first-dimension splits for the case where the last dimension does not vary. Previously, the scale, cumulative sum, and zero-concat steps ran as separate, unfused operations.
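The fused kernel's semantics can be sketched on the host as a simple prefix sum over the scaled splits (a minimal reference sketch; the function name is illustrative and not part of the PR's API):

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of the fused operation (illustrative name, not the PR's API):
// offsets[0] = 0; offsets[i + 1] = offsets[i] + first_dims[i] * logical_last_dim
std::vector<int64_t> offsets_reference(const std::vector<int64_t>& first_dims,
                                       int64_t logical_last_dim) {
  std::vector<int64_t> offsets(first_dims.size() + 1, 0);
  for (std::size_t i = 0; i < first_dims.size(); ++i) {
    offsets[i + 1] = offsets[i] + first_dims[i] * logical_last_dim;
  }
  return offsets;
}
```

For first_dims = {3, 5, 2} and logical_last_dim = 4 this yields {0, 12, 32, 40}; the fused kernel produces the same result in a single launch instead of three separate ops.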

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Introduce a single kernel that calculates offsets from first-dimension splits for the case where the last dimension does not vary.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci

@greptile-apps
Contributor

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This PR replaces a three-operation PyTorch sequence (scale → cumsum → cat-with-zero) with a single fused CUDA kernel (splits_to_offsets_kernel) that computes prefix-sum offsets for grouped tensors in one pass. The kernel uses a Hillis-Steele inclusive scan within a single 256-thread block, iterating over input elements in chunks when num_tensors > 256. A new C API (nvte_splits_to_offsets) and a PyTorch binding (splits_to_offsets) are exposed, and comprehensive tests covering boundary cases around kThreadsPerBlock (255, 256, 257) are included.

  • Kernel correctness: The __syncthreads() barriers are correctly placed — after the output writes (before chunk_prefix is updated) and after the update (before the next chunk's reads) — ensuring no shared-memory race on chunk_prefix across loop iterations.
  • C API validation (nvte_splits_to_offsets): Properly guards against null pointers, num_tensors == 0, and logical_last_dim <= 0.
  • PyTorch wrapper gap: The splits_to_offsets function in misc.cpp validates dtype, device, and rank, but does not guard against an empty (numel == 0) first_dims tensor; a call with an empty tensor will propagate an error from inside nvte_splits_to_offsets rather than producing a clear boundary-level message.
  • build_grouped_tensor_offsets: Cleanly migrated to the new kernel via NVTE_SCOPED_GIL_RELEASE, removing the multi-kernel PyTorch path.
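The chunked scan described in the summary can be simulated on the host (a sketch inferred from the review summary, not the kernel's actual source; the double-buffered inner loop stands in for the per-round __syncthreads() barrier):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side simulation of the chunked Hillis-Steele inclusive scan
// (illustrative; inferred from the review summary, not the kernel source).
constexpr std::size_t kThreadsPerBlock = 256;

std::vector<int64_t> splits_to_offsets_sim(const std::vector<int64_t>& first_dims,
                                           int64_t logical_last_dim) {
  const std::size_t n = first_dims.size();
  std::vector<int64_t> out(n + 1);
  out[0] = 0;                // "thread 0: output[0] = 0"
  int64_t chunk_prefix = 0;  // carried across 256-element chunks
  for (std::size_t base = 0; base < n; base += kThreadsPerBlock) {
    const std::size_t len = std::min(kThreadsPerBlock, n - base);
    // Load and scale: first_dims[i] * logical_last_dim.
    std::vector<int64_t> scan(len);
    for (std::size_t t = 0; t < len; ++t) {
      scan[t] = first_dims[base + t] * logical_last_dim;
    }
    // Hillis-Steele inclusive scan: log2(len) rounds of stride-doubling adds.
    // The double buffer replaces the per-round __syncthreads() of the kernel.
    for (std::size_t stride = 1; stride < len; stride *= 2) {
      std::vector<int64_t> next = scan;
      for (std::size_t t = stride; t < len; ++t) {
        next[t] = scan[t] + scan[t - stride];
      }
      scan.swap(next);
    }
    // Write offsets, then advance the carry (the "thread 255" step).
    for (std::size_t t = 0; t < len; ++t) {
      out[base + t + 1] = chunk_prefix + scan[t];
    }
    chunk_prefix += scan[len - 1];
  }
  return out;
}
```

For 257 splits of 2 with logical_last_dim = 3 the final offset is 257 × 6 = 1542, exercising the carry across the 256-thread boundary that the tests target.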

Confidence Score: 4/5

  • PR is safe to merge; the kernel logic and synchronization are correct, with one minor validation gap in the PyTorch wrapper.
  • The fused kernel correctly implements the Hillis-Steele prefix scan with proper __syncthreads() barriers, the C API has full input validation, and the test suite covers the critical boundary cases around kThreadsPerBlock. The only gap is the missing numel() > 0 guard in the public PyTorch splits_to_offsets wrapper, which produces a confusing error message for empty-tensor inputs, though not a crash or an incorrect result.
  • transformer_engine/pytorch/csrc/extensions/misc.cpp — missing empty-tensor guard before calling nvte_splits_to_offsets

Important Files Changed

Filename — Overview

  • transformer_engine/common/common.cu — Adds splits_to_offsets_kernel: a single-block Hillis-Steele inclusive-scan kernel that fuses scale, prefix sum, and zero-prepend into one pass. Synchronization is correct: __syncthreads() barriers are placed after the output writes and after the chunk_prefix update, ensuring no race on chunk_prefix across loop iterations. The C-API wrapper nvte_splits_to_offsets includes appropriate null-pointer and bounds checks.
  • transformer_engine/pytorch/csrc/extensions/misc.cpp — Adds the PyTorch-facing splits_to_offsets wrapper. Input validation is otherwise thorough (CUDA device, dtype, ndim, logical_last_dim > 0), but there is no guard against an empty first_dims tensor (numel == 0), which would surface a less-than-helpful error from inside nvte_splits_to_offsets.
  • transformer_engine/pytorch/csrc/quantizer.cpp — Replaces the three-op sequence (scale × cumsum × cat-with-zero) with a single call to nvte_splits_to_offsets inside an NVTE_SCOPED_GIL_RELEASE block. Logic is correct and the GIL release is consistent with the rest of the file.
  • tests/cpp/operator/test_splits_to_offsets.cu — New test covers boundary values around kThreadsPerBlock (255, 256, 257), larger sizes (1024), and multiple logical_last_dim values. The host-side reference computation is correct. Looks good.
  • transformer_engine/common/include/transformer_engine/transformer_engine.h — New public C API declaration for nvte_splits_to_offsets with clear Doxygen documentation of parameters, semantics, and the exact formula computed. No issues.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["splits_to_offsets(first_dims, logical_last_dim)
    [PyTorch Python API]"]
    B["splits_to_offsets()
    misc.cpp
    — validate: CUDA, int64, 1D, logical_last_dim > 0
    — allocate output[num_tensors + 1]"]
    C["nvte_splits_to_offsets()
    common.cu  [C API]
    — validate: non-null, num_tensors > 0, logical_last_dim > 0"]
    D["splits_to_offsets_kernel<<<1, 256>>>
    — thread 0: output[0] = 0, chunk_prefix = 0
    — chunk loop (stride 256):
        load first_dims × logical_last_dim → block_scan
        Hillis-Steele inclusive scan
        write output[idx+1] = chunk_prefix + block_scan[tid]
        thread 255: chunk_prefix += block_scan[255]"]
    E["build_grouped_tensor_offsets()
    quantizer.cpp  [internal path]"]
    F["output[0..num_tensors]
    device int64 prefix-sum tensor"]

    A --> B --> C --> D --> F
    E -->|"NVTE_SCOPED_GIL_RELEASE"| C

Comments Outside Diff (1)

  1. transformer_engine/pytorch/csrc/extensions/misc.cpp, lines 22-28

    Missing guard for empty input tensor

    first_dims is validated to be 1-D and on CUDA, but there is no check that it is non-empty. If first_dims.numel() == 0, num_tensors becomes 0 and the call to nvte_splits_to_offsets will throw "num_tensors must be greater than 0" — an error message that does not mention first_dims and may be confusing at the Python boundary.

    Either add an explicit guard at the PyTorch level, or return the identity result ([0]) for the zero-tensor edge case, which is the mathematically correct prefix sum for an empty input.
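A guard of the shape the comment suggests might look like this (hypothetical names and message; the real wrapper in misc.cpp would use the extension's own check macros rather than a bare throw):

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical boundary-level guard (illustrative only): reject an empty
// split list up front so the error message names first_dims directly,
// instead of surfacing "num_tensors must be greater than 0" from the C API.
void check_first_dims_nonempty(std::size_t numel) {
  if (numel == 0) {
    throw std::invalid_argument(
        "first_dims must be a non-empty 1-D tensor of per-tensor first dims");
  }
}
```

Alternatively, the wrapper could short-circuit and return a one-element [0] tensor, the identity prefix sum for an empty split list.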

Last reviewed commit: b06b489

ksivaman and others added 3 commits March 12, 2026 10:15
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
