Make cuTile intrinsics 1-based#155

Closed
maleadt wants to merge 1 commit into main from tb/normalize

maleadt commented Mar 30, 2026

This is an experiment.

Summary: Move 1-based to 0-based index lowering from the DSL (operations.jl) into a compiler rewrite pass (index_lower_pass!). The load/store/gather operations now pass through Julia's natural 1-based indices, and the pass inserts subi(elem, 1) for each index element in load_partition_view and store_partition_view calls during compilation.
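To illustrate the shape of such a pass, here is a minimal sketch on a toy IR. It is written in Python for brevity and is not the actual `index_lower_pass!` from this branch: the `Op` record, the `%idx0_N` naming, and the list-based IR are all hypothetical. It only shows the idea of inserting one `subi(elem, 1)` per index element in front of each `load_partition_view` and `store_partition_view`.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Op:
    """Toy IR operation (hypothetical, for illustration only)."""
    result: Optional[str]   # SSA name produced, or None for stores
    name: str               # operation name
    operands: List[object]  # non-index operands (SSA names or constants)
    indices: List[str]      # index operands, 1-based before the pass

# Operations whose indices must be 0-based by the time they are emitted.
INDEXED_OPS = {"load_partition_view", "store_partition_view"}

def index_lower_pass(ops):
    """Return a new op list with a subi(elem, 1) inserted for every index element."""
    out = []
    counter = 0
    for op in ops:
        if op.name in INDEXED_OPS:
            lowered = []
            for idx in op.indices:
                counter += 1
                tmp = f"%idx0_{counter}"               # fresh name for the 0-based index
                out.append(Op(tmp, "subi", [idx, 1], []))
                lowered.append(tmp)
            op = Op(op.result, op.name, op.operands, lowered)
        out.append(op)
    return out

program = [
    Op("%bid", "addi", ["%blockId_x", 1], []),                  # bid(1), 1-based
    Op("%a", "load_partition_view", ["%pview_a"], ["%bid"]),
    Op(None, "store_partition_view", ["%pview_c", "%a"], ["%bid"]),
]
lowered = index_lower_pass(program)
```

Because the DSL no longer emits the subtraction itself, all index arithmetic before the load/store boundary stays in 1-based terms, and the conversion appears in exactly one place.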

The motivation here is to simplify the IR we emit (the layernorm example currently emits 4x as many SASS instructions as cuTile Python does). A large part of that IR cruft comes from the repetitive +1/-1 we do as part of the 0-based to 1-based index conversion. For example, bid(1) returns a 1-based index (via addi(blockId_x, 1)), and each load/store call then emits its own subi(..., 1) to undo it, resulting in 3 redundant constant+subi pairs:

  %1 = addi %blockId_x, %cst_1_i32          // bid(1) = blockId_x + 1
  %2 = subi %1, %cst_1_i32_7                // load a: undo +1
  load_view_tko ... %pview[%2]
  %3 = subi %1, %cst_1_i32_9                // load b: undo +1 again
  load_view_tko ... %pview_8[%3]
  %5 = subi %1, %cst_1_i32_12               // store c: undo +1 again
  store_view_tko ... %pview_13[%5]

On this branch, the indices passed around are kept 1-based, and a pass converts them late at the load/store boundary, resulting in significantly simpler IR:

  %1 = addi %blockId_x, %cst_1_i32          // bid(1) = blockId_x + 1
  load_view_tko ... %pview[%1]              // pass lowered index, addi+subi cancelled
  load_view_tko ... %pview_7[%1]
  store_view_tko ... %pview_10[%1]

Although this is nice, it doesn't improve performance, and it makes it harder to compare our IR to cuTile Python's. So I'm not sure we want this.

maleadt commented Apr 6, 2026

Algebraic simplifications have turned out to be powerful enough that this is no longer needed.
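As a rough illustration of the kind of simplification meant here, the sketch below cancels `subi(addi(x, c), c)` back to `x` on a toy IR. It is written in Python and is not the simplifier used in the project: the tuple encoding `(result, name, args)` and the `simplify` function are hypothetical.

```python
def simplify(ops):
    """Rewrite uses of subi(addi(x, c), c) to x. ops: list of (result, name, args)."""
    defs = {}    # SSA name -> (name, args) of the op that defines it
    alias = {}   # SSA name -> simpler SSA name that replaces it
    out = []
    for result, name, args in ops:
        # First redirect operands to any simpler value found so far.
        args = tuple(alias.get(a, a) for a in args)
        if name == "subi" and args[0] in defs:
            d_name, d_args = defs[args[0]]
            same_const = isinstance(args[1], int) and d_args[1] == args[1]
            if d_name == "addi" and same_const:
                alias[result] = d_args[0]   # (x + c) - c == x
                continue                    # drop the subi entirely
        if result is not None:
            defs[result] = (name, args)
        out.append((result, name, args))
    return out

# The pattern from the PR description, with constants inlined as integers.
program = [
    ("%1", "addi", ("%blockId_x", 1)),
    ("%2", "subi", ("%1", 1)),
    ("%a", "load_view_tko", ("%pview", "%2")),
    ("%3", "subi", ("%1", 1)),
    ("%b", "load_view_tko", ("%pview_8", "%3")),
    ("%5", "subi", ("%1", 1)),
    (None, "store_view_tko", ("%pview_13", "%5")),
]
simplified = simplify(program)
```

In this toy version the `addi` is left in place even though nothing uses it any more; a separate dead-code elimination step would be expected to remove it.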

maleadt closed this Apr 6, 2026