SM80 CuTe DSL KDA Support [WIP] #90

Open
harvestingmoon wants to merge 1 commit into inclusionAI:main from harvestingmoon:feature/sm80

@harvestingmoon

📌 Description

  • 3 CuTe DSL MMA kernels: QK, delta-H (WH), fwd-O (QH)
  • Host-side algorithm matching FLA KDA: asymmetric gating, block triangular solve
  • Persistent JIT caching via CUTE_DSL_CACHE_DIR
  • Validation: max diff ~0.87 vs FLA Triton (bf16 noise floor)

Only forward prefill is complete so far. I am currently running benchmarks to confirm that it is within the threshold. Testing is on a single RTX 3060 for now, but I will probably move to an A100 server if possible.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

Reviewer Notes

- 3 CuTe DSL MMA kernels: QK, delta-H (WH), fwd-O (QH)
- Host-side algorithm matching FLA KDA: asymmetric gating, block triangular solve
- Persistent JIT caching via CUTE_DSL_CACHE_DIR
- Validation: max diff ~0.87 vs FLA Triton (bf16 noise floor)

@harvestingmoon harvestingmoon changed the title from "SM80 CuTe DSL KDA forward prefill" to "SM80 CuTe DSL KDA Support" on Jun 11, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds Ampere (SM80) fully-fused KDA forward prefill support, introducing CuTe DSL SM80 kernels for QK, delta-H, and fused output operations alongside test and benchmark scripts. However, several critical issues must be addressed: the kda_fo and kda_dh kernel calls in ampere_fused_fwd.py are missing required arguments (scale and decay), and kda_dh incorrectly overwrites h_state. Additionally, multi-batch and multi-head support is broken due to ignored batch_idx variables in kda_fused_fwd_sm80.py and hardcoded batch/head indices in chunk_delta_h_sm80.py and fwd_o_sm80.py. Finally, please resolve potential reshaping errors when sequence length T is not a multiple of chunk size, remove accidentally committed garbage files, and replace hardcoded absolute paths in the test scripts with relative paths.
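On the reshaping point in the summary above, one common way to avoid the failure is to pad the sequence up to the next multiple of the chunk size before splitting it into chunks. The following is a minimal sketch in plain Python, not code from this PR: `T` and `BT` follow the names used in the review, while `pad_to_chunks` is a hypothetical helper introduced only for illustration.

```python
def pad_to_chunks(T: int, BT: int) -> tuple[int, int]:
    """Return (padded_length, num_chunks) for a sequence of length T.

    Hypothetical helper: rounds T up to the next multiple of the chunk
    size BT so that a reshape to (num_chunks, BT, ...) is always valid.
    """
    if T <= 0 or BT <= 0:
        raise ValueError("T and BT must be positive")
    pad = (-T) % BT              # 0 when T is already a multiple of BT
    padded = T + pad
    return padded, padded // BT


if __name__ == "__main__":
    print(pad_to_chunks(128, 64))   # already aligned, no padding needed
    print(pad_to_chunks(130, 64))   # rounded up to three chunks
```

If padding is used, the padded rows would presumably need to be masked or sliced off afterwards so that they do not contribute to the state or the output.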

Comment on lines +221 to +222
fo_ch, (B, ct, H, V_dim, K_dim), stream)
torch.cuda.synchronize()

critical

The call to kda_fo is missing the scale argument. Since scale is a required positional argument in FwdOSM80.__call__, passing stream as the 6th argument will bind it to scale, leaving the actual stream argument missing. This will cause a runtime/compilation error.

        kda_fo(q_gated[t_beg:t_end].contiguous(), h_state, g_zero[t_beg:t_end].contiguous(),
               fo_ch, (B, ct, H, V_dim, K_dim), scale, stream)
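The binding problem described in this comment can be reproduced with plain Python, independent of CuTe DSL. In the sketch below, `FwdOStandIn` is a hypothetical stand-in whose only resemblance to `FwdOSM80.__call__` is the argument order stated in the review; `bind_report` is likewise an illustrative helper. How the real JIT-compiled callable reports the error may differ.

```python
import inspect


class FwdOStandIn:
    """Hypothetical stand-in; only the argument order mirrors the review comment."""

    def __call__(self, q, h, g, o, problem_size, scale, stream):
        return {"scale": scale, "stream": stream}


def bind_report(fn, *args):
    """Return the parameter each positional argument binds to,
    or the TypeError message if binding fails."""
    try:
        bound = inspect.signature(fn).bind(*args)
    except TypeError as exc:
        return str(exc)
    return dict(bound.arguments)


kda_fo = FwdOStandIn()
shape = (1, 2, 3, 4, 5)

# Six positional arguments: "STREAM" lands in the `scale` slot, `stream` is missing.
broken = bind_report(kda_fo, "q", "h", "g", "o", shape, "STREAM")

# Seven positional arguments: every parameter receives the intended value.
fixed = bind_report(kda_fo, "q", "h", "g", "o", shape, 0.125, "STREAM")

if __name__ == "__main__":
    print(broken)
    print(fixed["scale"], fixed["stream"])
```

Passing `scale=` and `stream=` as keyword arguments, where the callee allows it, is one way to make this class of mistake harder to introduce.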

Ai_full[3*BC:4*BC, 2*BC:3*BC] = -Ai_blocks[3] @ L[3*BC:4*BC, 2*BC:3*BC] @ Ai_blocks[2]
w_ch = (Ai_full @ (k_chunk * bf.mean(dim=1, keepdim=True) * torch.exp(g_chunk * LN2))).unsqueeze(1).to(torch.bfloat16).contiguous()

kda_dh(kg_ch, w_ch, zg_ch, h_state, h_state, beta, (B, ct, H, V_dim, K_dim), stream)

critical

There are two critical issues with this kda_dh call:

  1. State Overwriting: Passing h_state as both the input state h and the output o causes kda_dh to overwrite the first 64 columns of h_state with WH_acc, destroying the original state values. When h_state += update is called later, it adds the update to WH_acc instead of the original h_state, leading to incorrect state updates. Please allocate a temporary buffer for wh_acc and pass it as the output tensor to kda_dh.
  2. Missing decay Argument: The call is missing the decay argument. Since decay is a required positional argument in ChunkDeltaHFwdSM80.__call__, passing stream as the 8th argument will bind it to decay, leaving the actual stream argument missing. This will cause a runtime/compilation error.
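The state-overwriting hazard can be illustrated without a GPU. The sketch below uses Python lists in place of tensors; `wh_into` is a hypothetical stand-in for a kernel that writes its result into an output buffer, and `update` is passed in directly rather than computed the way the PR computes it.

```python
def wh_into(out, h, w):
    """Stand-in for the kernel: writes w * h element-wise into `out`."""
    for i in range(len(h)):
        out[i] = w * h[i]


def step_aliased(h_state, w, update):
    """Output aliases the input state, as in the reviewed call."""
    wh_into(h_state, h_state, w)       # h_state is clobbered with w * h
    for i in range(len(h_state)):
        h_state[i] += update[i]        # adds to w * h, not to the original h
    return h_state


def step_with_temp(h_state, w, update):
    """Output goes to a separate buffer, as the review suggests."""
    wh_acc = [0.0] * len(h_state)      # temporary buffer for the kernel output
    wh_into(wh_acc, h_state, w)
    for i in range(len(h_state)):
        h_state[i] += update[i]        # original h is still intact here
    return h_state, wh_acc


if __name__ == "__main__":
    print(step_aliased([1.0, 2.0], 0.5, [10.0, 10.0]))
    print(step_with_temp([1.0, 2.0], 0.5, [10.0, 10.0]))
```

With the aliased call the final state is built on the overwritten values, whereas the temporary-buffer version keeps the original state and the kernel output separate.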

Comment on lines +55 to +56
sQ[(row, k_idx, 0)] = q[(chunk_idx * BT + row, head_idx, k_idx)]
sK[(row, k_idx, 0)] = k[(chunk_idx * BT + row, head_idx, k_idx)]

critical

Although batch_idx is retrieved from cute.arch.block_idx(), it is never used to index q or k. As a result, all batches will read from the memory corresponding to batch 0, causing incorrect results and race conditions when B > 1. Please incorporate batch_idx into the sequence dimension indexing.

            sQ[(row, k_idx, 0)] = q[(batch_idx * S + chunk_idx * BT + row, head_idx, k_idx)]
            sK[(row, k_idx, 0)] = k[(batch_idx * S + chunk_idx * BT + row, head_idx, k_idx)]
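To see why the batch offset matters, the row indices visited by all thread blocks can be enumerated on the host. This is a sketch under the assumption, implied by the suggestion above, that `q` and `k` are laid out with batches concatenated along the sequence dimension and that `S` is the per-batch sequence length; `rows_touched` is a hypothetical helper.

```python
from collections import Counter


def rows_touched(B: int, S: int, BT: int, use_batch: bool) -> list[int]:
    """List the flattened row index read by every (batch, chunk, row) triple."""
    rows = []
    for batch_idx in range(B):
        base = batch_idx * S if use_batch else 0   # offset dropped in the reviewed code
        for chunk_idx in range(S // BT):
            for row in range(BT):
                rows.append(base + chunk_idx * BT + row)
    return rows


with_batch = Counter(rows_touched(B=2, S=128, BT=64, use_batch=True))
without_batch = Counter(rows_touched(B=2, S=128, BT=64, use_batch=False))

if __name__ == "__main__":
    # With the offset, every row of both batches is visited exactly once.
    print(len(with_batch), max(with_batch.values()))
    # Without it, only batch 0's rows are visited, each by two different blocks.
    print(len(without_batch), max(without_batch.values()))
```

The same counting applies to stores: when two blocks map to the same row, their writes collide, which is the race the review describes for `B > 1`.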

row = coord[0]
col = coord[1]
val = tCsC[i]
o[(chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val))

critical

The output tensor o is indexed using chunk_idx * BT + row, which completely ignores batch_idx. As a result, all batches will write to the memory corresponding to batch 0, causing race conditions and incorrect results when B > 1. Please incorporate batch_idx into the sequence dimension indexing.

Suggested change
- o[(chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val))
+ o[(batch_idx * S + chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val))

for row in cutlass.range_constexpr(BV):
    for i in cutlass.range_constexpr(4):
        ki = i*32+lane_id
        sH[(row,ki,0)] = BFloat16(h_state[(0,0,v_offset+row,ki)])

critical

The indexing of h_state has hardcoded 0, 0 for the batch and head dimensions (e.g., h_state[(0,0,v_offset+row,ki)]). This completely breaks multi-batch and multi-head support. You should map the grid to include batch and head indices, and use them to index h_state and o_state (line 89).
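One possible way to map the grid, sketched here on the host in plain Python, is to launch one block per (batch, head, V-tile) and decode the three indices from a linear block id. The names `decode_block` and `num_v_tiles` are hypothetical, and the actual kernel may instead use a multi-dimensional grid and read the indices directly from `cute.arch.block_idx()`.

```python
def decode_block(block_id: int, H: int, num_v_tiles: int) -> tuple[int, int, int]:
    """Split a linear block id into (batch_idx, head_idx, v_tile).

    Assumed launch size: B * H * num_v_tiles blocks, V-tile fastest-varying.
    """
    bh, v_tile = divmod(block_id, num_v_tiles)
    batch_idx, head_idx = divmod(bh, H)
    return batch_idx, head_idx, v_tile


B, H, NUM_V_TILES = 2, 3, 2
triples = [decode_block(i, H, NUM_V_TILES) for i in range(B * H * NUM_V_TILES)]

if __name__ == "__main__":
    for block_id, triple in enumerate(triples):
        print(block_id, triple)
    # Inside the kernel the state would then be indexed along the lines of
    # h_state[(batch_idx, head_idx, v_offset + row, ki)] instead of (0, 0, ...).
```

Every block id decodes to a distinct triple, so each block would own a disjoint slice of `h_state` and `o_state`.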

Comment on lines +1 to +2

SSUUMMMMAARRYY OOFF LLEESSSS CCOOMMMMAANNDDSS

medium

This file appears to be accidentally committed due to a shell command or git redirection error. It contains the help output of the less command. Please delete this file.

Comment on lines +1 to +2

SSUUMMMMAARRYY OOFF LLEESSSS CCOOMMMMAANNDDSS

medium

This file appears to be accidentally committed due to a shell command or git redirection error. It contains the help output of the less command. Please delete this file.

Comment thread test_delta_h_compile.py
Comment on lines +3 to +4
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA")
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention")

medium

Hardcoded absolute paths like "/mnt/d/Programming/New folder (2)/cuLA" are used to modify sys.path. This will break the tests on any other machine or CI/CD environment. Please use relative paths or rely on installing the package in editable mode (pip install -e .).
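If `sys.path` manipulation is kept rather than switching to an editable install, the entries can be derived from the test file's own location. The sketch below assumes the test scripts sit in the repository root next to `third_party/`, which is inferred from the paths above rather than confirmed; `repo_paths` is a hypothetical helper.

```python
import sys
from pathlib import Path


def repo_paths(test_file: str) -> list[str]:
    """Build sys.path entries relative to the given test file.

    Assumes the test file lives in the repository root, next to third_party/.
    """
    root = Path(test_file).resolve().parent
    return [str(root), str(root / "third_party" / "flash-linear-attention")]


if __name__ == "__main__":
    # In a test script, __file__ replaces the machine-specific absolute path.
    sys.path[:0] = repo_paths(__file__)
    print(sys.path[:2])
```

This keeps the two entries in the same order as the original `sys.path.insert(0, ...)` intent while working from any checkout location.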

Comment thread test_qk_runtime.py
Comment on lines +3 to +4
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA")
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention")

medium

Hardcoded absolute paths like "/mnt/d/Programming/New folder (2)/cuLA" are used to modify sys.path. This will break the tests on any other machine or CI/CD environment. Please use relative paths or rely on installing the package in editable mode (pip install -e .).

Comment thread test_sm80_compile2.py
Comment on lines +3 to +4
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA")
sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention")

medium

Hardcoded absolute paths like "/mnt/d/Programming/New folder (2)/cuLA" are used to modify sys.path. This will break the tests on any other machine or CI/CD environment. Please use relative paths or rely on installing the package in editable mode (pip install -e .).

@harvestingmoon harvestingmoon changed the title from "SM80 CuTe DSL KDA Support" to "SM80 CuTe DSL KDA Support [WIP]" on Jun 11, 2026