
[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749

Draft
jberchtold-nvidia wants to merge 12 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-refactor
Conversation


@jberchtold-nvidia jberchtold-nvidia commented Mar 10, 2026

Description

This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.
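To make the described semantics concrete, here is a NumPy reference sketch of the forward grouped GEMM the new parameters encode: rows of `lhs` are partitioned into ragged groups (what the new API calls `lhs_first_dims`), and each group multiplies its own weight slice. This is an illustrative stand-in, not the Transformer Engine kernel.

```python
import numpy as np

def grouped_gemm_reference(lhs, rhs, group_sizes):
    # Partition lhs rows into G ragged groups and multiply each group
    # by its own (K, N) weight slice. Illustrative only.
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    return np.concatenate(
        [lhs[offsets[g]:offsets[g] + s] @ rhs[g] for g, s in enumerate(group_sizes)],
        axis=0,
    )

G, K, N = 3, 4, 5
group_sizes = np.array([2, 0, 3], dtype=np.int32)  # ragged M; a group may be empty
lhs = np.ones((group_sizes.sum(), K))
rhs = np.ones((G, K, N))
out = grouped_gemm_reference(lhs, rhs, group_sizes)
print(out.shape)  # (5, 5)
```

Note that an empty group contributes a zero-row slice, so the concatenated output shape is simply the sum of the group sizes by `N`.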

Addresses issue: #2648

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • grouped_gemm API signature change: replaced the single group_sizes positional argument with six keyword
    arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an
    optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a
    uniform/non-ragged dimension)
  • Removed explicit M/N/K parameters from C++ FFI: matrix dimensions are now derived automatically from XLA buffer
    shapes inside the C++ handler, eliminating manual dimension computation in Python
  • Removed is_grouped_dense_wgrad flag: the wgrad vs. forward distinction is now inferred from which dimension
    arrays are non-empty (non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a
    (num_groups, M, N) output)
  • New C++ config struct GroupedGemmV2Config: consolidates lhs_is_trans, rhs_is_trans, and scaling_mode into a
    single FFI attribute struct, replacing individual attribute bindings
  • New C++ helper make_grouped_tensor() overload: accepts first_dims/last_dims buffers, converts int32 group-size
    arrays to int64 in partitioned int64_workspace slots, and returns updated workspace offset to avoid aliasing
  • dense.py updated: _grouped_dense_fwd_rule and _grouped_dense_bwd_rule updated to pass group_sizes via the
    appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
  • Tests updated: TestGroupedDense test cases migrated to the new keyword-argument API with explicit empty_gs =
    jnp.empty((0,), jnp.int32) sentinels for non-ragged axes
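The wgrad inference described above (non-empty `rhs_first_dims` means a ragged K contraction and a `(num_groups, M, N)` output) can be sketched with a NumPy reference; the shapes and the ragged-axis convention here are assumptions based on this change list, not the actual kernel.

```python
import numpy as np

def grouped_wgrad_reference(lhs, rhs, group_sizes):
    # Wgrad-style case: the contraction (K) axis is ragged, so each
    # group yields a full (M, N) product and the output is stacked
    # into (num_groups, M, N). Illustrative only.
    offsets = np.concatenate([[0], np.cumsum(group_sizes)])
    return np.stack([
        lhs[offsets[g]:offsets[g] + s].T @ rhs[offsets[g]:offsets[g] + s]
        for g, s in enumerate(group_sizes)
    ])

group_sizes = np.array([2, 3], dtype=np.int32)
lhs = np.ones((5, 4))  # activations, ragged along the leading (K) axis
rhs = np.ones((5, 6))  # incoming gradient, ragged along the same axis
print(grouped_wgrad_reference(lhs, rhs, group_sizes).shape)  # (2, 4, 6)
```

In the refactored API this call shape would correspond to passing the same group-size array as both `lhs_first_dims` and `rhs_first_dims`, with the remaining four dimension arguments left as empty `(0,)` sentinels.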

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 10, 2026 17:24

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR refactors the JAX Grouped GEMM interface to replace the single group_sizes array with six separate per-tensor dimension arrays (lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims). It also removes the M/N/K static FFI attributes by deriving them at runtime from the 2D buffer shapes, replaces the is_grouped_dense_wgrad boolean flag with a semantic check on whether rhs dims are ragged, and introduces backward-compatible struct-based attribute decoding (GroupedGemmV2Config / GroupedGemmConfig) for the C++ FFI handlers.

Key issues found:

  • Latent correctness bug (gemm.cpp): any_ragged = is_lhs_ragged || is_rhs_ragged excludes output dims, so int64_sizes_ptr is never populated when only out_first_dims / out_last_dims are non-empty, but set_group_sizes_only is still called with the uninitialized pointer. The same active_gs_ptr selection chain does not include output buffers as fallbacks.
  • Incorrect M extraction for lhs_is_trans=True in non-wgrad path (gemm.py abstract_eval + gemm.cpp): Both the Python abstract eval and the C++ GroupedGemmV2FFI unconditionally use shape[0] for M in the non-wgrad branch, which equals K when lhs_is_trans=True.
  • Divide-by-zero risk when all dims are empty (gemm.py): If all six dimension arrays are empty sentinels, num_gemms=0, which propagates to alpha/beta size 0 and a C++ integer division by zero (rhs_data.dimensions()[0] / num_gemms) in the V2 path.
  • Removed K_lhs == K_rhs assertion (gemm.py): The early validation that contracting dimensions of LHS and RHS match was silently dropped, deferring shape-mismatch detection to deep inside cuBLAS.
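A hypothetical guard sketching how the last two issues could be fenced before dispatch; the 2D-buffer convention (that `shape[0]` of `lhs` is K when `lhs_is_trans=True`) is taken from the review comment above, and the function name is invented for illustration.

```python
def validate_grouped_gemm_dims(lhs_shape, rhs_shape, lhs_is_trans, rhs_is_trans, num_gemms):
    # Early checks restoring what the review flags as missing. Sketch only.
    if num_gemms <= 0:  # guards the later `... / num_gemms` division in C++
        raise ValueError(f"num_gemms must be positive, got {num_gemms}")
    # Per the review, shape[0] of a 2D lhs buffer is K when lhs_is_trans=True,
    # so M must come from shape[1] in that case.
    k_lhs = lhs_shape[0] if lhs_is_trans else lhs_shape[1]
    k_rhs = rhs_shape[1] if rhs_is_trans else rhs_shape[0]
    if k_lhs != k_rhs:  # the dropped K_lhs == K_rhs assertion
        raise ValueError(f"contracting dims mismatch: K_lhs={k_lhs}, K_rhs={k_rhs}")
    m = lhs_shape[1] if lhs_is_trans else lhs_shape[0]
    return m, k_lhs

print(validate_grouped_gemm_dims((8, 4), (4, 16), False, False, 2))  # (8, 4)
```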

Confidence Score: 3/5

  • PR introduces a latent correctness bug (uninitialized pointer in V2 FFI for out-only-ragged calls) and removes an important K-dimension consistency check; existing paths are safe but the new API surface has unfenced failure modes.
  • The refactoring is mechanically sound for all current in-tree call sites (FWD, DGRAD, WGRAD always set lhs_first_dims alongside any out_first_dims), so no existing test is expected to regress. However, three newly introduced issues — the out-only-ragged int64_sizes_ptr bug, the wrong-M bug for lhs_is_trans=True in non-wgrad, and the divide-by-zero for all-empty dims — reduce confidence. The removal of the K_lhs == K_rhs assertion also weakens the API contract without replacement validation.
  • Pay close attention to transformer_engine/jax/csrc/extensions/gemm.cpp (both GroupedGemmV2FFI and GroupedGemmFFI) and transformer_engine/jax/cpp_extensions/gemm.py (abstract_eval and the num_gemms computation).

Important Files Changed

  • transformer_engine/jax/csrc/extensions/gemm.cpp — Replaces the single group_sizes buffer and M/N/K static attrs with six per-dimension buffers; introduces GroupedGemmConfig/GroupedGemmV2Config structs for backward-compatible dict attributes. Latent bug: any_ragged only covers lhs/rhs dims, so int64_sizes_ptr stays uninitialized when only output dims are ragged. Also, the non-wgrad M derivation (lhs_data.dimensions()[0]) is wrong when lhs_is_trans=True.
  • transformer_engine/jax/cpp_extensions/gemm.py — Core Python grouped GEMM refactor: replaces group_sizes with six *_first_dims/*_last_dims parameters, removes M/N/K static args, and derives output shapes from 2D buffer dimensions. Issues: non-wgrad M = lhs_data_aval.shape[0] is incorrect for lhs_is_trans=True; the K_lhs == K_rhs consistency assert was removed; num_gemms=0 can cause divide-by-zero in the V2 path.
  • transformer_engine/jax/dense.py — Call-site update to the new grouped_gemm API: FWD and DGRAD pass lhs_first_dims=group_sizes, out_first_dims=group_sizes; WGRAD passes lhs_first_dims=group_sizes, rhs_first_dims=group_sizes. Changes are mechanical and correct for current usage patterns.
  • tests/jax/test_custom_call_compute.py — Tests updated to use the new keyword-argument API with explicit empty sentinel arrays; no new test coverage added for the extended multi-dim API (e.g. out-only ragged, lhs_last_dims, rhs_last_dims scenarios).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs,\nlhs_first_dims, lhs_last_dims,\nrhs_first_dims, rhs_last_dims,\nout_first_dims, out_last_dims, ...)"] --> B{Is rhs ragged?\nrhs_first_dims.size > 0\nor rhs_last_dims.size > 0}

    B -->|Yes| C["WGRAD path\nlhs_is_trans=True, rhs_is_trans=False\nlhs_flatten_axis=1, rhs_flatten_axis=1\nout_shape=(num_gemms, M, N)"]
    B -->|No| D["FWD / DGRAD path\nDerive lhs_is_trans from contracting_dims\nout_shape=(M_total, N)"]

    C --> E{can_use_v2?}
    D --> E

    E -->|BF16 + NO_SCALING + no bias| F["V2 path\nalpha=ones(G), beta=zeros(G)\nGroupedGemmV2FFI"]
    E -->|Otherwise| G["Legacy path\ngroup_offset=zeros(1)\nGroupedGemmFFI\nnvte_multi_tensor_gemm loop"]

    F --> H{any_ragged?\nis_lhs_ragged ∥ is_rhs_ragged}
    H -->|Yes| I["nvte_convert_int32_to_int64\nactive_gs_ptr from lhs/rhs dims"]
    H -->|No| J["⚠ Skip int64 conversion\nactive_gs_ptr = nullptr"]

    I --> K{is_rhs_ragged?}
    J --> K

    K -->|Yes| L["WGRAD branch\nrhs/lhs set_group_sizes_only\nout shape: (num_gemms*M, N)"]
    K -->|No| M["FWD/DGRAD branch\nlhs set_group_sizes_only\nout set_group_sizes_only\n⚠ uses int64_sizes_ptr\nout shape: (M_total, N)"]

    G --> N["dim_list_host D2H copy\nor async GroupedGemmGetGroupSizes\nnvte_multi_tensor_gemm per-group loop"]

Last reviewed commit: ed9c8e4

@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 35171af to 88bb7da Compare March 10, 2026 18:56
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from 20fadc7 to 025f598 Compare March 10, 2026 23:26
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-refactor branch from a427b9e to 089e530 Compare March 10, 2026 23:59
jberchtold-nvidia and others added 3 commits March 10, 2026 17:04
@jberchtold-nvidia
Collaborator Author

/te-ci
