[JAX] Size autotuned Triton grids per config #2975
Conversation
The autotuned path in `triton_call_lowering` compiled all BLOCK_SIZE configs but dispatched every one with the same fixed grid sized for the smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE ratio. Make `grid` accept a callable `(meta) -> tuple` evaluated per config, matching the jax-triton API. Update `_permute_kernel`, `_unpermute_kernel`, and `_sort_chunks_by_map_kernel` lowerings.

Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k tokens, hidden=4096, fp32.

Signed-off-by: tdophung <tdophung@nvidia.com>
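A minimal sketch of the callable-grid form described in the PR (the sizes here are placeholders, and `cdiv` stands in for `triton.cdiv` so the sketch runs without Triton installed):

```python
from math import ceil

num_tokens, hidden_size = 1024, 4096   # placeholder sizes, not from the PR


def cdiv(a, b):
    # Stand-in for triton.cdiv in this sketch.
    return ceil(a / b)


def grid(meta):
    # Evaluated once per autotune config, so each BLOCK_SIZE gets a
    # grid that exactly covers the hidden dimension instead of one
    # sized for the smallest BLOCK_SIZE.
    return (num_tokens, cdiv(hidden_size, meta["BLOCK_SIZE"]))
```

For example, `grid({"BLOCK_SIZE": 512})` evaluates to `(1024, 8)`.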
Greptile Summary

This PR fixes a grid-sizing bug in the JAX Triton kernel lowering path where all autotuned configs were dispatched with a fixed grid sized for the smallest BLOCK_SIZE, causing larger configs to over-launch by the BLOCK_SIZE ratio. The fix makes `grid` accept a callable `(meta) -> tuple` that is evaluated per config with that config's constexprs, matching the jax-triton API.
Confidence Score: 5/5

Safe to merge — the fix is narrow, well-reasoned, and both the autotuned and fallback dispatch paths correctly resolve the grid before calling `TritonKernelCall`. The change is a targeted bug fix: the callable grid is evaluated with the correct config-specific constexprs on the autotuned path, and with the merged kernel constexprs on the non-autotuned path. The `_normalize_grid` helper is symmetric with the original normalization logic. The three permutation kernel sites close over stable local variables (`num_tokens`, `hidden_size`) with no loop-variable capture issues, and BLOCK_SIZE is always present in the meta dict passed to the callable on every code path. No files require special attention.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["triton_call_lowering(grid=...)"] --> B{Is grid callable?}
    B -- No --> C["_normalize_grid(grid) → fixed grid_tuple"]
    B -- Yes --> D["grid_callable = grid\ngrid_tuple = None"]
    C --> E{Is kernel autotuned?}
    D --> E
    E -- Yes --> F["For each config in kernel_fn.configs"]
    F --> G["config_constexprs =\n{**constexprs, **config.kwargs}"]
    G --> H{grid_callable?}
    H -- Yes --> I["_normalize_grid(grid_callable(config_constexprs)) → config_grid"]
    H -- No --> J["config_grid = grid_tuple"]
    I --> K["TritonKernelCall(config_kernel, config_grid)"]
    J --> K
    K --> L["TritonAutotunedKernelCall (all configs)"]
    E -- No --> M["kernel_constexprs = constexprs or merged fallback"]
    M --> N{grid_callable?}
    N -- Yes --> O["_normalize_grid(grid_callable(kernel_constexprs)) → single_grid"]
    N -- No --> P["single_grid = grid_tuple"]
    O --> Q["TritonKernelCall(kernel, single_grid)"]
    P --> Q
```
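The autotuned branch of the flow above can be sketched as follows. This is a hedged illustration, not the PR's code: `build_autotuned_calls` and `make_call` are hypothetical names (`make_call` stands in for constructing a `TritonKernelCall`), while the per-config merge `{**constexprs, **config.kwargs}` follows the flowchart:

```python
def build_autotuned_calls(grid, constexprs, configs, make_call):
    # Resolve the grid once per autotune config: a callable grid is
    # evaluated with that config's merged constexprs, so each
    # BLOCK_SIZE variant launches exactly the blocks it needs.
    grid_callable = grid if callable(grid) else None
    grid_tuple = None if grid_callable else tuple(grid)
    calls = []
    for config_kwargs in configs:
        config_constexprs = {**constexprs, **config_kwargs}
        if grid_callable is not None:
            config_grid = tuple(grid_callable(config_constexprs))
        else:
            # Non-callable grids keep the old fixed-grid behavior.
            config_grid = grid_tuple
        calls.append(make_call(config_constexprs, config_grid))
    return calls
```

With a callable grid, two configs with BLOCK_SIZE 256 and 512 over a size-4096 axis get grids of 16 and 8 blocks respectively, rather than both inheriting the 16-block grid.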
Reviews (3): Last reviewed commit: "[JAX] Fix misleading old-JAX fallback co..."
Clarify that constexprs values override config.kwargs in the non-autotune
fallback path (utils.py merges {**first_cfg.kwargs, **constexprs}). Three
sites: _permute_kernel, _unpermute_kernel, _sort_chunks_by_map_kernel.
Signed-off-by: tdophung <tdophung@nvidia.com>
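The merge-order point from the commit message above can be shown directly (the names follow the commit message, the values are illustrative):

```python
first_cfg_kwargs = {"BLOCK_SIZE": 128, "num_warps": 4}  # illustrative config.kwargs
constexprs = {"BLOCK_SIZE": 512}                        # caller-provided

# In a dict merge, later keys win, so constexprs overrides config.kwargs:
merged = {**first_cfg_kwargs, **constexprs}
assert merged == {"BLOCK_SIZE": 512, "num_warps": 4}
```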
jberchtold-nvidia left a comment:
Overall looks good! Left a few comments. Thanks!
```python
def grid(meta):
    return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

# block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
```
SR: some of these comments are more about this particular fix and reference the previous approach, which I think will be confusing in the future after this PR merges and we can no longer see the previous approach
I think we can simplify this and the comment above to something like "We use BLOCK_SIZE in grid calculation to ensure the grid is the proper size. If the grid size is an overestimate it can significantly hurt performance"
Same for other cases where def grid is defined below too
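The reviewer's point about over-sized grids hurting performance can be quantified with a small arithmetic sketch. Only `hidden=4096` comes from the PR description; the BLOCK_SIZE values are illustrative assumptions:

```python
from math import ceil

hidden_size = 4096              # from the PR description
block_sizes = (128, 256, 512)   # illustrative autotune configs

smallest = min(block_sizes)
for bs in block_sizes:
    fixed_grid = ceil(hidden_size / smallest)  # old: one grid sized for the smallest config
    per_config = ceil(hidden_size / bs)        # fixed: grid resolved per config
    print(f"BLOCK_SIZE={bs}: {fixed_grid} blocks launched, "
          f"{per_config} needed ({fixed_grid // per_config}x over-launch)")
```

With these assumed configs, the BLOCK_SIZE=512 variant launches 4x the blocks it needs under the old fixed grid.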
```python
if len(grid_tuple) == 2:
    return (grid_tuple[0], grid_tuple[1], 1)
return tuple(grid_tuple[:3])
```
SR: `assert callable(grid) or isinstance(grid, tuple), f"Argument 'grid' must be a tuple or a callable but received: type={type(grid)}, value={grid}"`
Description
Fixes # (issue)