
[JAX] Size autotuned Triton grids per config #2975

Open

tdophung wants to merge 4 commits into NVIDIA:main from tdophung:jax-triton-grid-autotune

Conversation

@tdophung tdophung (Collaborator) commented May 11, 2026

Description

The autotuned path in triton_call_lowering compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make grid accept a callable(meta)->tuple evaluated per config, matching the jax-triton API. Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings. Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pass the grid in callable form to the Triton kernel call
  • Make the grid passed into the Triton lowering call a callable(meta)->tuple for all Triton permutation kernels (a sketch follows below)
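
For illustration, a minimal sketch of the new lowering-side pattern, assuming a simplified triton_call_lowering signature (the grid closure itself matches the diff excerpt quoted in the review below; the wrapper name and everything else is abbreviated):

    import triton

    def _lower_permute(num_tokens, hidden_size, *args):
        # Hypothetical wrapper standing in for the real lowerings in
        # transformer_engine/jax/triton_extensions/permutation.py.
        def grid(meta):
            # meta carries the constexprs of the config being compiled, so each
            # autotune config launches exactly the blocks its BLOCK_SIZE needs.
            return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

        return triton_call_lowering(*args, grid=grid)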

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung changed the title [JAX] Size autotuned Triton grids per config (3x perm-kernel speedup on JAX side) [JAX] Size autotuned Triton grids per config May 11, 2026
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 11, 2026 21:11
greptile-apps bot (Contributor) commented May 11, 2026

Greptile Summary

This PR fixes a grid-sizing bug in the JAX Triton kernel lowering path where all autotuned configs were dispatched with a fixed grid sized for the smallest BLOCK_SIZE, causing larger configs to over-launch by the BLOCK_SIZE ratio. The fix makes grid accept a callable(meta) -> tuple evaluated per autotune config, matching jax-triton's API.

  • utils.py: triton_call_lowering now accepts a callable grid; a _normalize_grid helper normalizes the result to a 3-D tuple, called once per config in the autotuned loop and once with the merged constexprs on the non-autotuned (fallback) path (see the sketch after this list).
  • permutation.py: _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel lowerings each replace the fixed grid tuple with a def grid(meta) closure that reads meta["BLOCK_SIZE"], so every autotuned config is launched with the correct number of blocks.
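
A minimal sketch of that helper and the per-config resolution, reconstructed from the diff excerpts and the flowchart below; the 2-D branch is verbatim from the quoted diff, while the scalar and 1-D branches, the helper's parameter names, and the hypothetical _resolve_config_grids wrapper are assumptions:

    def _normalize_grid(grid):
        # Normalize whatever the caller (or the grid callable) returned into
        # the 3-D launch grid Triton expects.
        if isinstance(grid, int):
            return (grid, 1, 1)
        grid_tuple = tuple(grid)
        if len(grid_tuple) == 1:
            return (grid_tuple[0], 1, 1)
        if len(grid_tuple) == 2:
            return (grid_tuple[0], grid_tuple[1], 1)
        return tuple(grid_tuple[:3])

    def _resolve_config_grids(kernel_fn, constexprs, grid_callable, grid_tuple):
        # Autotuned path: resolve the grid once per config (per the flowchart),
        # so each BLOCK_SIZE gets a launch grid sized for it.
        grids = []
        for config in kernel_fn.configs:
            config_constexprs = {**constexprs, **config.kwargs}
            grids.append(
                _normalize_grid(grid_callable(config_constexprs))
                if grid_callable is not None
                else grid_tuple
            )
        return grids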

Confidence Score: 5/5

Safe to merge — the fix is narrow, well-reasoned, and both the autotuned and fallback dispatch paths correctly resolve the grid before calling TritonKernelCall.

The change is a targeted bug fix: the callable grid is evaluated with the correct config-specific constexprs on the autotuned path, and with the merged kernel constexprs on the non-autotuned path. The _normalize_grid helper is symmetric with the original normalization logic. The three permutation kernel sites close over stable local variables (num_tokens, hidden_size) with no loop-variable capture issues, and BLOCK_SIZE is always present in the meta dict passed to the callable on every code path.
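
For completeness, a sketch of the non-autotuned fallback path described above, following the flowchart's node labels; the merge order comes from the review thread below, TritonKernelCall is the kernel-call object named in the flowchart, and the exact code is an assumption:

    def _fallback_kernel_call(kernel, first_cfg, constexprs, grid_callable, grid_tuple):
        # No autotuning: resolve the grid exactly once, with explicit
        # constexprs taking precedence over the first config's kwargs.
        kernel_constexprs = {**first_cfg.kwargs, **constexprs}
        single_grid = (
            _normalize_grid(grid_callable(kernel_constexprs))
            if grid_callable is not None
            else grid_tuple
        )
        return TritonKernelCall(kernel, single_grid)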

No files require special attention.

Important Files Changed

Filename | Overview
transformer_engine/jax/triton_extensions/utils.py | Adds callable grid support to triton_call_lowering; evaluated per config on the autotuned path and once with merged constexprs on the fallback path. Logic is correct and well-commented.
transformer_engine/jax/triton_extensions/permutation.py | Three kernel lowerings updated to use a per-config callable grid closing over num_tokens and hidden_size; correctly surfaces BLOCK_SIZE from each config's meta dict.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["triton_call_lowering(grid=...)"] --> B{Is grid callable?}
    B -- No --> C["_normalize_grid(grid) → fixed grid_tuple"]
    B -- Yes --> D["grid_callable = grid\ngrid_tuple = None"]
    C --> E{Is kernel autotuned?}
    D --> E
    E -- Yes --> F["For each config in kernel_fn.configs"]
    F --> G["config_constexprs =\n{**constexprs, **config.kwargs}"]
    G --> H{grid_callable?}
    H -- Yes --> I["_normalize_grid(grid_callable(config_constexprs)) → config_grid"]
    H -- No --> J["config_grid = grid_tuple"]
    I --> K["TritonKernelCall(config_kernel, config_grid)"]
    J --> K
    K --> L["TritonAutotunedKernelCall (all configs)"]
    E -- No --> M["kernel_constexprs = constexprs or merged fallback"]
    M --> N{grid_callable?}
    N -- Yes --> O["_normalize_grid(grid_callable(kernel_constexprs)) → single_grid"]
    N -- No --> P["single_grid = grid_tuple"]
    O --> Q["TritonKernelCall(kernel, single_grid)"]
    P --> Q

Reviews (3): Last reviewed commit: "[JAX] Fix misleading old-JAX fallback co..."

Comment thread: transformer_engine/jax/triton_extensions/permutation.py (Outdated)
Clarify that constexprs values override config.kwargs in the non-autotune
fallback path (utils.py merges {**first_cfg.kwargs, **constexprs}). Three
sites: _permute_kernel, _unpermute_kernel, _sort_chunks_by_map_kernel.
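
A tiny illustration of that precedence, with made-up values (later keys win in a Python dict merge):

    first_cfg_kwargs = {"BLOCK_SIZE": 128}   # from the first autotune config
    constexprs = {"BLOCK_SIZE": 256}         # explicit caller-provided value
    merged = {**first_cfg_kwargs, **constexprs}
    assert merged["BLOCK_SIZE"] == 256       # constexprs override config.kwargs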

Signed-off-by: tdophung <tdophung@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia (Collaborator) left a comment


Overall looks good! Left a few comments. Thanks!

def grid(meta):
    return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

# block_size populates the BLOCK_SIZE entry in constexprs so the wrapper

SR: some of these comments are more about this particular fix and reference the previous approach, which I think will be confusing in the future after this PR merges and we can no longer see the previous approach

I think we can simplify this and the comment above to something like "We use BLOCK_SIZE in grid calculation to ensure the grid is the proper size. If the grid size is an overestimate it can significantly hurt performance"


Same for other cases where def grid is defined below too

if len(grid_tuple) == 2:
    return (grid_tuple[0], grid_tuple[1], 1)
return tuple(grid_tuple[:3])


SR: assert callable(grid) or isinstance(grid, tuple), f"Argument 'grid' must be a tuple or a callable but received: type={type(grid)}, value={grid}"
