From cedd9fda39064b7f3fec87372d1f5326015da2ee Mon Sep 17 00:00:00 2001
From: tdophung
Date: Mon, 11 May 2026 10:58:04 -0700
Subject: [PATCH 1/4] [JAX] Size autotuned Triton grids per config (3x perm-kernel speedup)

The autotuned path in triton_call_lowering compiled all BLOCK_SIZE
configs but dispatched every one with the same fixed grid sized for the
smallest BLOCK_SIZE, so larger configs over-launched by the BLOCK_SIZE
ratio. Make grid accept a callable(meta)->tuple evaluated per config,
matching the jax-triton API.

Update _permute_kernel, _unpermute_kernel, and _sort_chunks_by_map_kernel
lowerings.

Measured 22.6ms -> 7.4ms (3.06x) on GB200 for sort_chunks at 524k
tokens, hidden=4096, fp32.

Signed-off-by: tdophung
---
 .../jax/triton_extensions/permutation.py | 27 +++++--
 .../jax/triton_extensions/utils.py        | 72 +++++++++++++------
 2 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py
index 98c54e52bb..faf720d0e6 100644
--- a/transformer_engine/jax/triton_extensions/permutation.py
+++ b/transformer_engine/jax/triton_extensions/permutation.py
@@ -589,10 +589,17 @@ def lowering(
         probs_stride_token = 0
         probs_stride_expert = 0

-        # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE))
-        # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements
+        # Per-config grid: size by autotune-selected BLOCK_SIZE. With a fixed
+        # grid sized for the smallest BLOCK_SIZE, larger configs over-launch by
+        # the BLOCK_SIZE ratio (every extra block masks out and exits) — the
+        # cause of the perm-kernel perf regression vs jax-triton's autotuner.
+        def grid(meta):
+            return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
+
+        # block_size is only a placeholder for the constexpr signature; the
+        # autotune loop overrides BLOCK_SIZE per config (and on the old-JAX
+        # fallback path, the first config's BLOCK_SIZE is used).
         block_size = _get_min_block_size(_permute_kernel)
-        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

         # Use input_output_aliases to alias pre-zeroed buffers to outputs.
         # This ensures padding positions contain zeros since the kernel only writes valid positions.
@@ -997,9 +1004,12 @@ def lowering(
         unpermuted_probs_stride_token = num_experts
         unpermuted_probs_stride_expert = 1

-        # Grid - use minimum BLOCK_SIZE from autotune configs
+        # Per-config grid: size by autotune-selected BLOCK_SIZE.
+        def grid(meta):
+            return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
+
+        # Placeholder for constexpr signature; autotune overrides per config.
         block_size = _get_min_block_size(_unpermute_kernel)
-        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

         return triton_call_lowering(
             ctx,
@@ -1720,9 +1730,12 @@ def lowering(
         probs_stride_token = 1
         permuted_probs_stride_token = 1

-        # Grid - use minimum BLOCK_SIZE from autotune configs
+        # Per-config grid: size by autotune-selected BLOCK_SIZE.
+        def grid(meta):
+            return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))
+
+        # Placeholder for constexpr signature; autotune overrides per config.
         block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
-        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

         # Declare input_output_aliases so XLA knows output slot 0 is claimed by
         # input 3 (output_buf). This prevents XLA from implicitly aliasing any
diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py
index 2a86321c34..864565b834 100644
--- a/transformer_engine/jax/triton_extensions/utils.py
+++ b/transformer_engine/jax/triton_extensions/utils.py
@@ -390,7 +390,16 @@ def triton_call_lowering(
         ctx: MLIR lowering context
         kernel_fn: Triton kernel function
         *array_args: Input arrays (from ctx)
-        grid: Grid dimensions (int or tuple)
+        grid: Grid dimensions. May be either:
+            - an int or tuple (fixed grid for every config), or
+            - a callable ``meta -> int|tuple`` (evaluated per autotune config).
+
+            Use the callable form for autotuned kernels whose grid depends on
+            ``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the
+            launch grid will not match the autotuner-selected config and the
+            kernel will either over-launch (waste) or under-cover. ``meta`` is
+            the merged dict ``{**constexprs, **config.kwargs}`` for the chosen
+            config — the same convention as jax-triton's ``triton_call``.
         input_output_aliases: Mapping of input to output aliases
         constexprs: Compile-time constants for the kernel. This includes
             both tl.constexpr arguments AND scalar runtime arguments (like
@@ -404,13 +413,12 @@ def lowering(ctx, x, *, block_size):
             from ..triton_extensions import triton_call_lowering
             n = ctx.avals_in[0].size
+
+            def grid(meta):
+                return (triton.cdiv(n, meta["BLOCK_SIZE"]),)
+
             return triton_call_lowering(
-                ctx, my_kernel, x,
-                grid=(triton.cdiv(n, block_size),),
-                constexprs={
-                    "n_elements": n,  # scalar arg (not tl.constexpr in kernel)
-                    "BLOCK_SIZE": block_size,  # tl.constexpr arg
-                },
+                ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n},
             )
     """
     # Get compute capability using gpu_triton
@@ -431,15 +439,23 @@ def lowering(ctx, x, *, block_size):
     tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
     signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}

-    # Normalize grid to 3D
-    if isinstance(grid, int):
-        grid_tuple = (grid, 1, 1)
-    elif len(grid) == 1:
-        grid_tuple = (grid[0], 1, 1)
-    elif len(grid) == 2:
-        grid_tuple = (grid[0], grid[1], 1)
+    # Normalize grid to 3D. When `grid` is a callable, defer evaluation until
+    # we know the per-config meta (so each autotune config gets its own grid,
+    # matching jax-triton's behavior).
+    def _normalize_grid(g):
+        if isinstance(g, int):
+            return (g, 1, 1)
+        if len(g) == 1:
+            return (g[0], 1, 1)
+        if len(g) == 2:
+            return (g[0], g[1], 1)
+        return tuple(g[:3])
+
+    grid_callable = grid if callable(grid) else None
+    if grid_callable is None:
+        grid_tuple = _normalize_grid(grid)
     else:
-        grid_tuple = grid[:3]
+        grid_tuple = None  # evaluated per-config below

     # Default values for the kernel
     actual_kernel_fn = kernel_fn
@@ -510,11 +526,18 @@ def lowering(ctx, x, *, block_size):
         for _ in list(ctx.avals_in) + list(ctx.avals_out):
             config_params.append(gpu_triton.create_array_parameter(0, 16))

+        # Per-config grid: evaluate `grid(meta)` if grid is a callable so
+        # the launch shape matches this config's BLOCK_SIZE (etc.).
+        if grid_callable is not None:
+            config_grid = _normalize_grid(grid_callable(config_constexprs))
+        else:
+            config_grid = grid_tuple
+
         config_call = gpu_triton.TritonKernelCall(
             config_kernel,
-            grid_tuple[0],
-            grid_tuple[1],
-            grid_tuple[2],
+            config_grid[0],
+            config_grid[1],
+            config_grid[2],
             config_params,
         )
@@ -571,11 +594,18 @@ def lowering(ctx, x, *, block_size):
         for _ in list(ctx.avals_in) + list(ctx.avals_out):
             kernel_params.append(gpu_triton.create_array_parameter(0, 16))

+        # Non-autotuned dispatch: evaluate `grid(meta)` once with the merged
+        # constexprs (which already reflect the single config we'll launch).
+        if grid_callable is not None:
+            single_grid = _normalize_grid(grid_callable(kernel_constexprs))
+        else:
+            single_grid = grid_tuple
+
         kernel_call = gpu_triton.TritonKernelCall(
             kernel,
-            grid_tuple[0],
-            grid_tuple[1],
-            grid_tuple[2],
+            single_grid[0],
+            single_grid[1],
+            single_grid[2],
             kernel_params,
         )

From 4af0ab03ce7764ebcc136f0a22e93e1db82181fc Mon Sep 17 00:00:00 2001
From: tdophung
Date: Mon, 11 May 2026 14:10:27 -0700
Subject: [PATCH 2/4] [JAX] Rename _normalize_grid's parameter for clarity

Signed-off-by: tdophung
---
 .../jax/triton_extensions/utils.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py
index 864565b834..b4d7865e87 100644
--- a/transformer_engine/jax/triton_extensions/utils.py
+++ b/transformer_engine/jax/triton_extensions/utils.py
@@ -442,14 +442,14 @@ def grid(meta):
     # Normalize grid to 3D. When `grid` is a callable, defer evaluation until
     # we know the per-config meta (so each autotune config gets its own grid,
     # matching jax-triton's behavior).
-    def _normalize_grid(g):
-        if isinstance(g, int):
-            return (g, 1, 1)
-        if len(g) == 1:
-            return (g[0], 1, 1)
-        if len(g) == 2:
-            return (g[0], g[1], 1)
-        return tuple(g[:3])
+    def _normalize_grid(grid_tuple):
+        if isinstance(grid_tuple, int):
+            return (grid_tuple, 1, 1)
+        if len(grid_tuple) == 1:
+            return (grid_tuple[0], 1, 1)
+        if len(grid_tuple) == 2:
+            return (grid_tuple[0], grid_tuple[1], 1)
+        return tuple(grid_tuple[:3])

     grid_callable = grid if callable(grid) else None
     if grid_callable is None:

From 0077c1ef0260d55477ee097f76544c39b6792c2a Mon Sep 17 00:00:00 2001
From: tdophung
Date: Mon, 11 May 2026 14:35:20 -0700
Subject: [PATCH 3/4] [JAX] Fix misleading old-JAX fallback comment in perm lowerings

Clarify that constexprs values override config.kwargs in the
non-autotune fallback path (utils.py merges {**first_cfg.kwargs,
**constexprs}). Three sites: _permute_kernel, _unpermute_kernel,
_sort_chunks_by_map_kernel.

Signed-off-by: tdophung
---
 .../jax/triton_extensions/permutation.py | 20 ++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py
index faf720d0e6..032ff884d7 100644
--- a/transformer_engine/jax/triton_extensions/permutation.py
+++ b/transformer_engine/jax/triton_extensions/permutation.py
@@ -596,9 +596,11 @@ def lowering(
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # block_size is only a placeholder for the constexpr signature; the
-        # autotune loop overrides BLOCK_SIZE per config (and on the old-JAX
-        # fallback path, the first config's BLOCK_SIZE is used).
+        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
+        # marks it constexpr in the kernel signature. The autotune loop
+        # overrides it per config; on the old-JAX fallback path the caller's
+        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
+        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_permute_kernel)

         # Use input_output_aliases to alias pre-zeroed buffers to outputs.
@@ -1008,7 +1010,11 @@ def lowering(
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # Placeholder for constexpr signature; autotune overrides per config.
+        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
+        # marks it constexpr in the kernel signature. The autotune loop
+        # overrides it per config; on the old-JAX fallback path the caller's
+        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
+        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_unpermute_kernel)

         return triton_call_lowering(
@@ -1734,7 +1740,11 @@ def lowering(
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # Placeholder for constexpr signature; autotune overrides per config.
+        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
+        # marks it constexpr in the kernel signature. The autotune loop
+        # overrides it per config; on the old-JAX fallback path the caller's
+        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
+        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_sort_chunks_by_map_kernel)

         # Declare input_output_aliases so XLA knows output slot 0 is claimed by

From 556b46d195e7fe173f7f0112d96c53934a1081a3 Mon Sep 17 00:00:00 2001
From: tdophung
Date: Thu, 14 May 2026 10:29:21 -0700
Subject: [PATCH 4/4] [JAX] Triton wrapper defaults match jax-triton (3.25ms speedup)

Change the num_warps default from 32 to 4 and the num_stages default
from 1 to 3 in triton_call_lowering, matching Triton's own
triton.Config defaults. Non-autotuned kernels (e.g.
_make_chunk_sort_map_kernel) were running with 1024 threads per block,
an 8x kernel slowdown. Also: assert that grid is a tuple or callable,
and trim now-redundant comments.

Signed-off-by: tdophung
---
 .../jax/triton_extensions/permutation.py | 30 ++++++----------
 .../jax/triton_extensions/utils.py        | 19 ++++++++----
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py
index 032ff884d7..22f983f078 100644
--- a/transformer_engine/jax/triton_extensions/permutation.py
+++ b/transformer_engine/jax/triton_extensions/permutation.py
@@ -589,18 +589,12 @@ def lowering(
         probs_stride_token = 0
         probs_stride_expert = 0

-        # Per-config grid: size by autotune-selected BLOCK_SIZE. With a fixed
-        # grid sized for the smallest BLOCK_SIZE, larger configs over-launch by
-        # the BLOCK_SIZE ratio (every extra block masks out and exits) — the
-        # cause of the perm-kernel perf regression vs jax-triton's autotuner.
+        # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+        # proper size. If the grid size is an overestimate it can significantly
+        # hurt performance.
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
-        # marks it constexpr in the kernel signature. The autotune loop
-        # overrides it per config; on the old-JAX fallback path the caller's
-        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
-        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_permute_kernel)

         # Use input_output_aliases to alias pre-zeroed buffers to outputs.
@@ -1006,15 +1000,12 @@ def lowering(
         unpermuted_probs_stride_token = num_experts
         unpermuted_probs_stride_expert = 1

-        # Per-config grid: size by autotune-selected BLOCK_SIZE.
+        # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+        # proper size. If the grid size is an overestimate it can significantly
+        # hurt performance.
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
-        # marks it constexpr in the kernel signature. The autotune loop
-        # overrides it per config; on the old-JAX fallback path the caller's
-        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
-        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_unpermute_kernel)

         return triton_call_lowering(
@@ -1736,15 +1727,12 @@ def lowering(
         probs_stride_token = 1
         permuted_probs_stride_token = 1

-        # Per-config grid: size by autotune-selected BLOCK_SIZE.
+        # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+        # proper size. If the grid size is an overestimate it can significantly
+        # hurt performance.
         def grid(meta):
             return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-        # block_size populates the BLOCK_SIZE entry in constexprs so the wrapper
-        # marks it constexpr in the kernel signature. The autotune loop
-        # overrides it per config; on the old-JAX fallback path the caller's
-        # value wins (utils.py merges {**first_cfg.kwargs, **constexprs}),
-        # which is why we pass the smallest config's BLOCK_SIZE explicitly.
         block_size = _get_min_block_size(_sort_chunks_by_map_kernel)

         # Declare input_output_aliases so XLA knows output slot 0 is claimed by

diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py
index b4d7865e87..332bc6ddb7 100644
--- a/transformer_engine/jax/triton_extensions/utils.py
+++ b/transformer_engine/jax/triton_extensions/utils.py
@@ -439,6 +439,11 @@ def grid(meta):
     tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
     signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}

+    assert callable(grid) or isinstance(grid, tuple), (
+        "Argument 'grid' must be a tuple or a callable but received: "
+        f"type={type(grid)}, value={grid}"
+    )
+
     # Normalize grid to 3D. When `grid` is a callable, defer evaluation until
     # we know the per-config meta (so each autotune config gets its own grid,
     # matching jax-triton's behavior).
@@ -457,12 +462,16 @@ def _normalize_grid(grid_tuple):
     else:
         grid_tuple = None  # evaluated per-config below

-    # Default values for the kernel
+    # Default kernel launch parameters. These apply to non-autotuned kernels
+    # and as a fallback when an autotuned config doesn't specify them. Values
+    # match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3,
+    # num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a
+    # larger default (e.g. num_warps=32) over-provisions threads per block,
+    # which slashes SM occupancy on non-autotuned kernels — measured as an 8×
+    # slowdown on `_make_chunk_sort_map_kernel` vs jax-triton.
     actual_kernel_fn = kernel_fn
-    num_warps = 32
-    num_stages = (
-        1  # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
-    )
+    num_warps = 4
+    num_stages = 3
     num_ctas = 1

     kernel_constexprs = constexprs if constexprs is not None else {}
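
The per-config grid evaluation from patch 1 can be modeled as a minimal
sketch in plain Python (no Triton or GPU needed). dispatch_configs and
the config dicts below are illustrative stand-ins, not the real
triton_call_lowering internals, and math.ceil stands in for triton.cdiv:

    import math

    def _normalize_grid(g):
        # Pad an int or a 1/2/3-tuple out to 3D, mirroring the patch's helper.
        if isinstance(g, int):
            g = (g,)
        return (tuple(g) + (1, 1))[:3]

    def dispatch_configs(grid, constexprs, configs):
        # One (grid, meta) pair per autotune config. meta merges the caller's
        # constexprs with the config's kwargs (config wins), and a callable
        # grid is evaluated against that meta: the jax-triton convention.
        for cfg_kwargs in configs:
            meta = {**constexprs, **cfg_kwargs}
            g = grid(meta) if callable(grid) else grid
            yield _normalize_grid(g), meta

    num_tokens, hidden_size = 8, 4096

    def grid(meta):
        # Same shape as the permutation lowerings' grid in patch 1.
        return (num_tokens, math.ceil(hidden_size / meta["BLOCK_SIZE"]))

    for g, meta in dispatch_configs(grid, {"hidden_size": hidden_size},
                                    [{"BLOCK_SIZE": 512}, {"BLOCK_SIZE": 4096}]):
        print(meta["BLOCK_SIZE"], g)

Run as-is, this prints (8, 8, 1) for BLOCK_SIZE=512 and (8, 1, 1) for
BLOCK_SIZE=4096: the larger config launches eight times fewer blocks
instead of masking out the excess, which is where the speedup comes from.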
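
The merge-order fact that patch 3 documents, as a runnable check. The
dict values here are hypothetical; only the merge expression
{**first_cfg.kwargs, **constexprs} is taken from utils.py:

    first_cfg_kwargs = {"BLOCK_SIZE": 1024, "num_warps": 8}       # hypothetical first config
    caller_constexprs = {"BLOCK_SIZE": 512, "hidden_size": 4096}  # what the lowering passes

    # Later keys win in a dict merge, so on the old-JAX fallback path the
    # caller-supplied BLOCK_SIZE overrides the first config's; that is why
    # the lowerings pass the smallest config's BLOCK_SIZE explicitly.
    merged = {**first_cfg_kwargs, **caller_constexprs}
    assert merged == {"BLOCK_SIZE": 512, "num_warps": 8, "hidden_size": 4096}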
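
The thread arithmetic behind patch 4's defaults: Triton launches
num_warps * 32 threads per block on NVIDIA hardware, so the old wrapper
default of num_warps=32 forced 1024 threads per block on every
non-autotuned kernel. A quick check:

    WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

    for num_warps in (4, 32):
        print(f"num_warps={num_warps:2d} -> {num_warps * WARP_SIZE:4d} threads per block")
    # num_warps= 4 ->  128 threads per block (triton.Config default)
    # num_warps=32 -> 1024 threads per block (old wrapper default)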