25 changes: 18 additions & 7 deletions transformer_engine/jax/triton_extensions/permutation.py
@@ -589,10 +589,13 @@ def lowering(
     probs_stride_token = 0
     probs_stride_expert = 0

-    # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE))
-    # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements
+    # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+    # proper size. If the grid size is an overestimate, it can significantly
+    # hurt performance.
+    def grid(meta):
+        return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-    block_size = _get_min_block_size(_permute_kernel)
-    grid = (num_tokens, triton.cdiv(hidden_size, block_size))

     # Use input_output_aliases to alias pre-zeroed buffers to outputs.
     # This ensures padding positions contain zeros since the kernel only writes valid positions.
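A hedged sketch of the overestimate this comment guards against, using made-up sizes (the hidden_size and the autotune search space below are hypothetical):

def cdiv(a, b):
    # Ceiling division, same as triton.cdiv.
    return (a + b - 1) // b

hidden_size = 8192
autotune_block_sizes = [64, 128, 256, 512]  # hypothetical search space

# Old approach: grid sized from the smallest block, whatever config wins.
fixed_cols = cdiv(hidden_size, min(autotune_block_sizes))  # 128 column-blocks

# New approach: grid sized from the BLOCK_SIZE the autotuner actually picks.
tuned_cols = cdiv(hidden_size, 512)  # 16 column-blocks

# With the fixed grid, 112 of the 128 column-blocks per token launch only
# to exit immediately; the callable grid launches exactly the 16 needed.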
@@ -997,9 +1000,13 @@ def lowering(
     unpermuted_probs_stride_token = num_experts
     unpermuted_probs_stride_expert = 1

-    # Grid - use minimum BLOCK_SIZE from autotune configs
+    # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+    # proper size. If the grid size is an overestimate, it can significantly
+    # hurt performance.
+    def grid(meta):
+        return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-    block_size = _get_min_block_size(_unpermute_kernel)
-    grid = (num_tokens, triton.cdiv(hidden_size, block_size))

     return triton_call_lowering(
         ctx,
@@ -1720,9 +1727,13 @@ def lowering(
     probs_stride_token = 1
     permuted_probs_stride_token = 1

-    # Grid - use minimum BLOCK_SIZE from autotune configs
+    # We use BLOCK_SIZE in the grid calculation to ensure the grid is the
+    # proper size. If the grid size is an overestimate, it can significantly
+    # hurt performance.
+    def grid(meta):
+        return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"]))

-    block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
-    grid = (num_tokens, triton.cdiv(hidden_size, block_size))

     # Declare input_output_aliases so XLA knows output slot 0 is claimed by
     # input 3 (output_buf). This prevents XLA from implicitly aliasing any
93 changes: 66 additions & 27 deletions transformer_engine/jax/triton_extensions/utils.py
@@ -390,7 +390,16 @@ def triton_call_lowering(
         ctx: MLIR lowering context
         kernel_fn: Triton kernel function
         *array_args: Input arrays (from ctx)
-        grid: Grid dimensions (int or tuple)
+        grid: Grid dimensions. May be either:
+            - an int or tuple (fixed grid for every config), or
+            - a callable ``meta -> int|tuple`` (evaluated per autotune config).
+
+            Use the callable form for autotuned kernels whose grid depends on
+            ``BLOCK_SIZE`` (or any other autotuned constexpr); otherwise the
+            launch grid will not match the autotuner-selected config and the
+            kernel will either over-launch (wasted blocks) or under-cover.
+            ``meta`` is the merged dict ``{**constexprs, **config.kwargs}``
+            for the chosen config, as in jax-triton's ``triton_call``.
         input_output_aliases: Mapping of input to output aliases
         constexprs: Compile-time constants for the kernel. This includes both
             tl.constexpr arguments AND scalar runtime arguments (like
@@ -404,13 +413,12 @@
         def lowering(ctx, x, *, block_size):
             from ..triton_extensions import triton_call_lowering
             n = ctx.avals_in[0].size
+
+            def grid(meta):
+                return (triton.cdiv(n, meta["BLOCK_SIZE"]),)
+
             return triton_call_lowering(
-                ctx, my_kernel, x,
-                grid=(triton.cdiv(n, block_size),),
-                constexprs={
-                    "n_elements": n,  # scalar arg (not tl.constexpr in kernel)
-                    "BLOCK_SIZE": block_size,  # tl.constexpr arg
-                },
+                ctx, my_kernel, x, grid=grid, constexprs={"n_elements": n},
             )
     """
     # Get compute capability using gpu_triton
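A small sketch of the meta convention described in the docstring above; the config values here are invented:

constexprs = {"n_elements": 1_000_000}  # caller-supplied scalars/constexprs
config_kwargs = {"BLOCK_SIZE": 256}     # from the winning autotune config

meta = {**constexprs, **config_kwargs}  # what grid(meta) receives

def grid(meta):
    return ((meta["n_elements"] + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)

assert grid(meta) == (3907,)  # cdiv(1_000_000, 256)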
@@ -431,22 +439,39 @@ def lowering(ctx, x, *, block_size):
     tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
     signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}

-    # Normalize grid to 3D
-    if isinstance(grid, int):
-        grid_tuple = (grid, 1, 1)
-    elif len(grid) == 1:
-        grid_tuple = (grid[0], 1, 1)
-    elif len(grid) == 2:
-        grid_tuple = (grid[0], grid[1], 1)
-    else:
-        grid_tuple = grid[:3]
+    assert callable(grid) or isinstance(grid, tuple), (
+        "Argument 'grid' must be a tuple or a callable but received: "
+        f"type={type(grid)}, value={grid}"
+    )
+
+    # Normalize grid to 3D. When `grid` is a callable, defer evaluation until
+    # we know the per-config meta (so each autotune config gets its own grid,
+    # matching jax-triton's behavior).
+    def _normalize_grid(grid_tuple):
+        if isinstance(grid_tuple, int):
+            return (grid_tuple, 1, 1)
+        if len(grid_tuple) == 1:
+            return (grid_tuple[0], 1, 1)
+        if len(grid_tuple) == 2:
+            return (grid_tuple[0], grid_tuple[1], 1)
+        return tuple(grid_tuple[:3])
+    grid_callable = grid if callable(grid) else None
+    if grid_callable is None:
+        grid_tuple = _normalize_grid(grid)
+    else:
+        grid_tuple = None  # evaluated per-config below
+
-    # Default values for the kernel
+    # Default kernel launch parameters. These apply to non-autotuned kernels
+    # and as a fallback when an autotuned config doesn't specify them. Values
+    # match Triton's own `triton.Config` defaults (num_warps=4, num_stages=3,
+    # num_ctas=1) and jax-triton's `get_or_create_triton_kernel`. Using a
+    # larger default (e.g. num_warps=32) over-provisions threads per block,
+    # which slashes SM occupancy on non-autotuned kernels; measured as an
+    # 8× slowdown on `_make_chunk_sort_map_kernel` vs jax-triton.
     actual_kernel_fn = kernel_fn
-    num_warps = 32
-    num_stages = (
-        1  # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas
-    )
+    num_warps = 4
+    num_stages = 3
     num_ctas = 1
     kernel_constexprs = constexprs if constexprs is not None else {}
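Back-of-envelope arithmetic behind the occupancy note in the comment above; the 2048-thread-per-SM limit is a typical figure for recent NVIDIA GPUs, not something stated in the diff:

WARP_SIZE = 32
MAX_THREADS_PER_SM = 2048  # typical recent NVIDIA limit; assumption

def resident_blocks(num_warps):
    # Threads-per-SM bound only; ignores register and shared-memory limits.
    return MAX_THREADS_PER_SM // (num_warps * WARP_SIZE)

print(resident_blocks(32))  # 2 blocks per SM with the old default
print(resident_blocks(4))   # 16 blocks per SM with Triton's default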

@@ -510,11 +535,18 @@ def lowering(ctx, x, *, block_size):
     for _ in list(ctx.avals_in) + list(ctx.avals_out):
         config_params.append(gpu_triton.create_array_parameter(0, 16))

+    # Per-config grid: evaluate `grid(meta)` if grid is a callable so
+    # the launch shape matches this config's BLOCK_SIZE (etc.).
+    if grid_callable is not None:
+        config_grid = _normalize_grid(grid_callable(config_constexprs))
+    else:
+        config_grid = grid_tuple
+
     config_call = gpu_triton.TritonKernelCall(
         config_kernel,
-        grid_tuple[0],
-        grid_tuple[1],
-        grid_tuple[2],
+        config_grid[0],
+        config_grid[1],
+        config_grid[2],
         config_params,
     )
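An illustrative run of the per-config path above, assuming the _normalize_grid helper from this diff and two hypothetical configs:

def grid(meta):
    return (4096, (8192 + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"])

for config_constexprs in ({"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 512}):
    config_grid = _normalize_grid(grid(config_constexprs))
    print(config_constexprs, "->", config_grid)
# {'BLOCK_SIZE': 128} -> (4096, 64, 1)
# {'BLOCK_SIZE': 512} -> (4096, 16, 1)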

@@ -571,11 +603,18 @@ def lowering(ctx, x, *, block_size):
     for _ in list(ctx.avals_in) + list(ctx.avals_out):
         kernel_params.append(gpu_triton.create_array_parameter(0, 16))

+    # Non-autotuned dispatch: evaluate `grid(meta)` once with the merged
+    # constexprs (which already reflect the single config we'll launch).
+    if grid_callable is not None:
+        single_grid = _normalize_grid(grid_callable(kernel_constexprs))
+    else:
+        single_grid = grid_tuple
+
     kernel_call = gpu_triton.TritonKernelCall(
         kernel,
-        grid_tuple[0],
-        grid_tuple[1],
-        grid_tuple[2],
+        single_grid[0],
+        single_grid[1],
+        single_grid[2],
         kernel_params,
     )
