diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 98c54e52bb..22f983f078 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -589,10 +589,13 @@ def lowering( probs_stride_token = 0 probs_stride_expert = 0 - # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE)) - # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_permute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) # Use input_output_aliases to alias pre-zeroed buffers to outputs. # This ensures padding positions contain zeros since the kernel only writes valid positions. @@ -997,9 +1000,13 @@ def lowering( unpermuted_probs_stride_token = num_experts unpermuted_probs_stride_expert = 1 - # Grid - use minimum BLOCK_SIZE from autotune configs + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_unpermute_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) return triton_call_lowering( ctx, @@ -1720,9 +1727,13 @@ def lowering( probs_stride_token = 1 permuted_probs_stride_token = 1 - # Grid - use minimum BLOCK_SIZE from autotune configs + # We use BLOCK_SIZE in the grid calculation to ensure the grid is the + # proper size. If the grid size is an overestimate it can significantly + # hurt performance. + def grid(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + block_size = _get_min_block_size(_sort_chunks_by_map_kernel) - grid = (num_tokens, triton.cdiv(hidden_size, block_size)) # Declare input_output_aliases so XLA knows output slot 0 is claimed by # input 3 (output_buf). This prevents XLA from implicitly aliasing any diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2a86321c34..332bc6ddb7 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -390,7 +390,16 @@ def triton_call_lowering( ctx: MLIR lowering context kernel_fn: Triton kernel function *array_args: Input arrays (from ctx) - grid: Grid dimensions (int or tuple) + grid: Grid dimensions. May be either: + - an int or tuple (fixed grid for every config), or + - a callable ``meta -> int|tuple`` (evaluated per autotune config). + + Use the callable form for autotuned kernels whose grid depends on + ``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the + launch grid will not match the autotuner-selected config and the + kernel will either over-launch (waste) or under-cover. ``meta`` is + the merged dict ``{**constexprs, **config.kwargs}`` for the chosen + config — the same convention as jax-triton's ``triton_call``. input_output_aliases: Mapping of input to output aliases constexprs: Compile-time constants for the kernel. This includes both tl.constexpr arguments AND scalar runtime arguments (like @@ -404,13 +413,12 @@ def triton_call_lowering( def lowering(ctx, x, *, block_size): from ..triton_extensions import triton_call_lowering n = ctx.avals_in[0].size + + def grid(meta): + return (triton.cdiv(n, meta["BLOCK_SIZE"]),) + return triton_call_lowering( - ctx, my_kernel, x, - grid=(triton.cdiv(n, block_size),), - constexprs={ - "n_elements": n, # scalar arg (not tl.constexpr in kernel) - "BLOCK_SIZE": block_size, # tl.constexpr arg - }, + ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n}, ) """ # Get compute capability using gpu_triton @@ -431,22 +439,39 @@ def lowering(ctx, x, *, block_size): tensor_arg_names = [n for n in arg_names if n not in constexpr_names] signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)} - # Normalize grid to 3D - if isinstance(grid, int): - grid_tuple = (grid, 1, 1) - elif len(grid) == 1: - grid_tuple = (grid[0], 1, 1) - elif len(grid) == 2: - grid_tuple = (grid[0], grid[1], 1) - else: - grid_tuple = grid[:3] + assert callable(grid) or isinstance(grid, tuple), ( + "Argument 'grid' must be a tuple or a callable but received: " + f"type={type(grid)}, value={grid}" + ) - # Default values for the kernel + # Normalize grid to 3D. When `grid` is a callable, defer evaluation until + # we know the per-config meta (so each autotune config gets its own grid, + # matching jax-triton's behavior). + def _normalize_grid(grid_tuple): + if isinstance(grid_tuple, int): + return (grid_tuple, 1, 1) + if len(grid_tuple) == 1: + return (grid_tuple[0], 1, 1) + if len(grid_tuple) == 2: + return (grid_tuple[0], grid_tuple[1], 1) + return tuple(grid_tuple[:3]) + + grid_callable = grid if callable(grid) else None + if grid_callable is None: + grid_tuple = _normalize_grid(grid) + else: + grid_tuple = None # evaluated per-config below + + # Default kernel launch parameters. These apply to non-autotuned kernels + # and as a fallback when an autotuned config doesn't specify them. Values + # match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3, + # num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a + # larger default (e.g. num_warps=32) over-provisions threads per block, + # which slashes SM occupancy on non-autotuned kernels — measured as an 8× + # slowdown on `_make_chunk_sort_map_kernel` vs jax-triton. actual_kernel_fn = kernel_fn - num_warps = 32 - num_stages = ( - 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas - ) + num_warps = 4 + num_stages = 3 num_ctas = 1 kernel_constexprs = constexprs if constexprs is not None else {} @@ -510,11 +535,18 @@ def lowering(ctx, x, *, block_size): for _ in list(ctx.avals_in) + list(ctx.avals_out): config_params.append(gpu_triton.create_array_parameter(0, 16)) + # Per-config grid: evaluate `grid(meta)` if grid is a callable so + # the launch shape matches this config's BLOCK_SIZE (etc.). + if grid_callable is not None: + config_grid = _normalize_grid(grid_callable(config_constexprs)) + else: + config_grid = grid_tuple + config_call = gpu_triton.TritonKernelCall( config_kernel, - grid_tuple[0], - grid_tuple[1], - grid_tuple[2], + config_grid[0], + config_grid[1], + config_grid[2], config_params, ) @@ -571,11 +603,18 @@ def lowering(ctx, x, *, block_size): for _ in list(ctx.avals_in) + list(ctx.avals_out): kernel_params.append(gpu_triton.create_array_parameter(0, 16)) + # Non-autotuned dispatch: evaluate `grid(meta)` once with the merged + # constexprs (which already reflect the single config we'll launch). + if grid_callable is not None: + single_grid = _normalize_grid(grid_callable(kernel_constexprs)) + else: + single_grid = grid_tuple + kernel_call = gpu_triton.TritonKernelCall( kernel, - grid_tuple[0], - grid_tuple[1], - grid_tuple[2], + single_grid[0], + single_grid[1], + single_grid[2], kernel_params, )