SM80 CuTe DSL KDA Support [WIP] #90
Conversation
- 3 CuTe DSL MMA kernels: QK, delta-H (WH), fwd-O (QH) - Host-side algorithm matching FLA KDA: asymmetric gating, block triangular solve - Persistent JIT caching via CUTE_DSL_CACHE_DIR - Validation: max diff ~0.87 vs FLA Triton (bf16 noise floor)
There was a problem hiding this comment.
Code Review
This pull request adds Ampere (SM80) fully-fused KDA forward prefill support, introducing CuTe DSL SM80 kernels for QK, delta-H, and fused output operations alongside test and benchmark scripts. However, several critical issues must be addressed: the kda_fo and kda_dh kernel calls in ampere_fused_fwd.py are missing required arguments (scale and decay), and kda_dh incorrectly overwrites h_state. Additionally, multi-batch and multi-head support is broken due to ignored batch_idx variables in kda_fused_fwd_sm80.py and hardcoded batch/head indices in chunk_delta_h_sm80.py and fwd_o_sm80.py. Finally, please resolve potential reshaping errors when sequence length T is not a multiple of chunk size, remove accidentally committed garbage files, and replace hardcoded absolute paths in the test scripts with relative paths.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| fo_ch, (B, ct, H, V_dim, K_dim), stream) | ||
| torch.cuda.synchronize() |
There was a problem hiding this comment.
The call to kda_fo is missing the scale argument. Since scale is a required positional argument in FwdOSM80.__call__, passing stream as the 6th argument will bind it to scale, leaving the actual stream argument missing. This will cause a runtime/compilation error.
kda_fo(q_gated[t_beg:t_end].contiguous(), h_state, g_zero[t_beg:t_end].contiguous(),\n fo_ch, (B, ct, H, V_dim, K_dim), scale, stream)| Ai_full[3*BC:4*BC, 2*BC:3*BC] = -Ai_blocks[3] @ L[3*BC:4*BC, 2*BC:3*BC] @ Ai_blocks[2] | ||
| w_ch = (Ai_full @ (k_chunk * bf.mean(dim=1, keepdim=True) * torch.exp(g_chunk * LN2))).unsqueeze(1).to(torch.bfloat16).contiguous() | ||
|
|
||
| kda_dh(kg_ch, w_ch, zg_ch, h_state, h_state, beta, (B, ct, H, V_dim, K_dim), stream) |
There was a problem hiding this comment.
There are two critical issues with this kda_dh call:\n1. State Overwriting: Passing h_state as both the input state h and the output o causes kda_dh to overwrite the first 64 columns of h_state with WH_acc, destroying the original state values. When h_state += update is called later, it adds the update to WH_acc instead of the original h_state, leading to incorrect state updates. Please allocate a temporary buffer for wh_acc and pass it as the output tensor to kda_dh.\n2. Missing decay Argument: The call is missing the decay argument. Since decay is a required positional argument in ChunkDeltaHFwdSM80.__call__, passing stream as the 8th argument will bind it to decay, leaving the actual stream argument missing. This will cause a runtime/compilation error.
| sQ[(row, k_idx, 0)] = q[(chunk_idx * BT + row, head_idx, k_idx)] | ||
| sK[(row, k_idx, 0)] = k[(chunk_idx * BT + row, head_idx, k_idx)] |
There was a problem hiding this comment.
Although batch_idx is retrieved from cute.arch.block_idx(), it is never used to index q or k. As a result, all batches will read from the memory corresponding to batch 0, causing incorrect results and race conditions when B > 1. Please incorporate batch_idx into the sequence dimension indexing.
sQ[(row, k_idx, 0)] = q[(batch_idx * S + chunk_idx * BT + row, head_idx, k_idx)]\n sK[(row, k_idx, 0)] = k[(batch_idx * S + chunk_idx * BT + row, head_idx, k_idx)]| row = coord[0] | ||
| col = coord[1] | ||
| val = tCsC[i] | ||
| o[(chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val)) |
There was a problem hiding this comment.
The output tensor o is indexed using chunk_idx * BT + row, which completely ignores batch_idx. As a result, all batches will write to the memory corresponding to batch 0, causing race conditions and incorrect results when B > 1. Please incorporate batch_idx into the sequence dimension indexing.
| o[(chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val)) | |
| o[(batch_idx * S + chunk_idx * BT + row, head_idx, col)] = BFloat16(Float32(val)) |
| for row in cutlass.range_constexpr(BV): | ||
| for i in cutlass.range_constexpr(4): | ||
| ki = i*32+lane_id | ||
| sH[(row,ki,0)] = BFloat16(h_state[(0,0,v_offset+row,ki)]) |
There was a problem hiding this comment.
|
|
||
| SSUUMMMMAARRYY OOFF LLEESSSS CCOOMMMMAANNDDSS |
|
|
||
| SSUUMMMMAARRYY OOFF LLEESSSS CCOOMMMMAANNDDSS |
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA") | ||
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention") |
There was a problem hiding this comment.
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA") | ||
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention") |
There was a problem hiding this comment.
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA") | ||
| sys.path.insert(0, "/mnt/d/Programming/New folder (2)/cuLA/third_party/flash-linear-attention") |
There was a problem hiding this comment.
📌 Description
right now only forward prefill has been completed, i am currently running benchmarks to ensure that it is within the threshold. Currently running the test on a single RTX3060 but I will probably use an A100 server to test out if possible
🚀 Pull Request Checklist
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
⚡ Performance
Reviewer Notes