Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ def _gdn_decode_kernel(
conv_state_indices=infer_state.b_buffer_idx,
)

# Recurrent processing with fused gating
# FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally
# Recurrent processing with fused gating; the kernel reads the
# q/k/v/a/b column views directly via per-token strides (no copies)
query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True)
core_attn_out, _ = fused_recurrent_gated_delta_rule(
q=query,
Expand Down
91 changes: 74 additions & 17 deletions lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_q_tok: tl.constexpr,
stride_k_tok: tl.constexpr,
stride_v_tok: tl.constexpr,
stride_a_tok: tl.constexpr,
stride_b_tok: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
Expand Down Expand Up @@ -94,15 +99,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)

p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_q = q + bos * stride_q_tok + i_h * K + o_k
p_k = k + bos * stride_k_tok + i_h * K + o_k
p_v = v + bos * stride_v_tok + i_hv * V + o_v
if FUSE_GATING:
# Fused gating: load per-head constants once, compute g/beta inline per token
b_A_log = tl.load(A_log + i_hv).to(tl.float32)
b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32)
p_a_raw = a_raw + bos * HV + i_hv
p_b_raw = b_raw + bos * HV + i_hv
p_a_raw = a_raw + bos * stride_a_tok + i_hv
p_b_raw = b_raw + bos * stride_b_tok + i_hv
else:
if IS_BETA_HEADWISE:
p_beta = beta + (bos * HV + i_hv) * V + o_v
Expand Down Expand Up @@ -193,13 +198,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

p_q += H * K
p_k += H * K
p_q += stride_q_tok
p_k += stride_k_tok
p_o += HV * V
p_v += HV * V
p_v += stride_v_tok
if FUSE_GATING:
p_a_raw += HV
p_b_raw += HV
p_a_raw += stride_a_tok
p_b_raw += stride_b_tok
else:
if not IS_KDA:
p_g += HV
Expand All @@ -208,6 +213,43 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta += HV * (V if IS_BETA_HEADWISE else 1)


def _token_stride(x: torch.Tensor, inner_numel: int, cu_seqlens) -> int:
"""Per-token element stride of x addressed as [tokens, ...inner dims...].

The kernel reads token ``i`` at ``base + i * token_stride`` with the inner
dims packed, which supports column views of one wider projection output
(token stride larger than inner_numel). Returns -1 if x's layout cannot be
addressed that way (caller must fall back to .contiguous()).
"""
if x.dim() == 2:
# [tokens, inner] (a_raw / b_raw)
return x.stride(0) if x.stride(1) == 1 else -1
# 4D q/k/v
if cu_seqlens is not None:
# varlen layout [1, tokens, head, dim]
token_dim = 1
elif x.shape[1] == 1:
# decode layout [tokens, 1, head, dim]
token_dim = 0
else:
# [B, T>1, head, dim]: a single token stride only exists if contiguous
return inner_numel if x.is_contiguous() else -1
if x.stride(-1) == 1 and x.stride(-2) == x.shape[-1]:
return x.stride(token_dim)
return -1


def _ensure_token_strided(x: torch.Tensor, inner_numel: int, cu_seqlens):
"""Return (tensor, token_stride); copies to contiguous only when needed."""
if x is None:
return None, 0
stride = _token_stride(x, inner_numel, cu_seqlens)
if stride < 0:
x = x.contiguous()
stride = inner_numel
return x, stride


def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
Expand All @@ -232,6 +274,11 @@ def fused_recurrent_gated_delta_rule_fwd(
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
q, stride_q_tok = _ensure_token_strided(q, H * K, cu_seqlens)
k, stride_k_tok = _ensure_token_strided(k, H * K, cu_seqlens)
v, stride_v_tok = _ensure_token_strided(v, HV * V, cu_seqlens)
a_raw, stride_a_tok = _ensure_token_strided(a_raw, HV, cu_seqlens)
b_raw, stride_b_tok = _ensure_token_strided(b_raw, HV, cu_seqlens)
BK = triton.next_power_of_2(K)
if T == 1:
# Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16)
Expand Down Expand Up @@ -261,20 +308,23 @@ def fused_recurrent_gated_delta_rule_fwd(
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)

# Strides for read indices
# Strides for read indices. The kernel advances along a row with `+ i_t`
# (token stride 1), so 2D index tensors must have contiguous rows.
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
else:
assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.

Suggested change
assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"
if ssm_state_indices.stride(-1) != 1:
raise ValueError("2D ssm_state_indices must have contiguous rows")

stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

# Strides for write indices (if provided)
# Strides for write indices (if provided); same contiguous-row requirement
if ssm_state_write_indices is None:
stride_write_indices_seq, stride_write_indices_tok = 1, 1
elif ssm_state_write_indices.ndim == 1:
stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1
else:
assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.

Suggested change
assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"
if ssm_state_write_indices.stride(-1) != 1:
raise ValueError("2D ssm_state_write_indices must have contiguous rows")

stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride()

grid = (NK, NV, N * HV)
Expand Down Expand Up @@ -305,6 +355,11 @@ def fused_recurrent_gated_delta_rule_fwd(
V=V,
BK=BK,
BV=BV,
stride_q_tok=stride_q_tok,
stride_k_tok=stride_k_tok,
stride_v_tok=stride_v_tok,
stride_a_tok=stride_a_tok,
stride_b_tok=stride_b_tok,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
Expand Down Expand Up @@ -348,10 +403,12 @@ def forward(
b_raw: torch.Tensor | None = None,
out: torch.Tensor | None = None,
):
# q/k/v/a_raw/b_raw may be non-contiguous column views of one projection
# output; the kernel handles them via per-token strides (no copies).
o, final_state = fused_recurrent_gated_delta_rule_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
q=q,
k=k,
v=v,
g=g.contiguous() if g is not None else None,
beta=beta.contiguous() if beta is not None else None,
scale=scale,
Expand All @@ -364,8 +421,8 @@ def forward(
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
A_log=A_log,
dt_bias=dt_bias,
a_raw=a_raw.contiguous() if a_raw is not None else None,
b_raw=b_raw.contiguous() if b_raw is not None else None,
a_raw=a_raw,
b_raw=b_raw,
out=out,
)

Expand Down
125 changes: 125 additions & 0 deletions unit_tests/models/qwen3next/test_fused_recurrent_strided.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import pytest
import torch

from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import (
fused_recurrent_gated_delta_rule,
)

# Every tensor in these tests is allocated on the "cuda" device, so skip the
# whole module (at collection time) on machines without CUDA.
if not torch.cuda.is_available():
    pytest.skip("CUDA required", allow_module_level=True)


@pytest.mark.parametrize("batch", [1, 2, 16])
def test_decode_strided_views_match_contiguous(batch):
    """Decode layout: column views of one projection output vs. packed copies.

    q/k/v/a/b are all sliced out of a single ``[batch, total_dim]`` buffer, so
    each has a token stride of ``total_dim``. Feeding those views to the kernel
    must give bit-identical outputs and states to feeding contiguous copies.
    """
    torch.manual_seed(0)
    H, HV, K, V = 2, 8, 128, 128
    key_dim, value_dim = H * K, HV * V
    # Column layout of the projection output: q | k | v | z | b | a
    section_sizes = [key_dim, key_dim, value_dim, value_dim, HV, HV]
    total_dim = sum(section_sizes)
    cache_slots = 64

    mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16)
    query, key, value, _z, b_raw, a_raw = torch.split(mixed, section_sizes, dim=-1)
    q = query.view(batch, 1, H, K)
    k = key.view(batch, 1, H, K)
    v = value.view(batch, 1, HV, V)

    A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16)
    # One distinct cache slot per request in the batch.
    slot_ids = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32)

    shared_kwargs = dict(
        inplace_final_state=True,
        ssm_state_indices=slot_ids,
        use_qk_l2norm_in_kernel=True,
        A_log=A_log,
        dt_bias=dt_bias,
    )

    def decode_once(inputs, state):
        # `state` is updated in place (inplace_final_state=True).
        q_, k_, v_, a_, b_ = inputs
        out, _ = fused_recurrent_gated_delta_rule(
            q=q_,
            k=k_,
            v=v_,
            initial_state=state,
            a_raw=a_,
            b_raw=b_,
            **shared_kwargs,
        )
        return out

    strided_inputs = (q, k, v, a_raw, b_raw)
    packed_inputs = tuple(t.contiguous() for t in strided_inputs)

    # Reference run on contiguous copies first, then the strided views.
    state_ref = ssm_state.clone()
    out_ref = decode_once(packed_inputs, state_ref)
    state_strided = ssm_state.clone()
    out_strided = decode_once(strided_inputs, state_strided)

    assert torch.equal(out_ref, out_strided)
    assert torch.equal(state_ref, state_strided)


def test_varlen_strided_views_match_contiguous():
    """Varlen layout ``[1, tokens, H, K]``: column views vs. contiguous copies.

    Three sequences are packed along the token axis; q/k/v/a/b are column
    views of one projection buffer. Outputs and per-token states must be
    bit-identical to the run on contiguous copies.
    """
    torch.manual_seed(1)
    H, HV, K, V = 2, 8, 128, 128
    key_dim, value_dim = H * K, HV * V
    # Column layout of the projection output: q | k | v | z | b | a
    section_sizes = [key_dim, key_dim, value_dim, value_dim, HV, HV]
    total_dim = sum(section_sizes)

    seqlens = [3, 5, 1]
    # Cumulative token offsets [0, 3, 8, 9]; also the per-sequence slot ranges.
    boundaries = [0]
    for length in seqlens:
        boundaries.append(boundaries[-1] + length)
    tokens = boundaries[-1]
    cu = torch.tensor(boundaries, device="cuda", dtype=torch.long)

    mixed = torch.randn(tokens, total_dim, device="cuda", dtype=torch.bfloat16)
    query, key, value, _z, b_raw, a_raw = torch.split(mixed, section_sizes, dim=-1)
    q = query.view(1, tokens, H, K)
    k = key.view(1, tokens, H, K)
    v = value.view(1, tokens, HV, V)

    A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1

    # ssm_state_indices must be passed here. Per the original note, the varlen
    # branch without continuous batching indexes h0 by token offset (bos)
    # rather than sequence index and reads out of bounds for every sequence
    # after the first (a latent upstream bug; production call sites all pass
    # ssm_state_indices). With inplace_final_state the kernel stores one state
    # per token, so the table is 2D [N, max_seqlen] and maps each (seq, token)
    # to its own slot; a sequence's initial state comes from its token-0 slot.
    slot_table = torch.zeros(len(seqlens), max(seqlens), device="cuda", dtype=torch.int32)
    for row, (start, stop) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        slot_table[row, : stop - start] = torch.arange(start, stop, device="cuda", dtype=torch.int32)
    init_state = torch.randn(tokens, HV, K, V, device="cuda", dtype=torch.bfloat16)

    def run_varlen(q_, k_, v_, a_, b_):
        # Fresh copy per run: the kernel writes states in place.
        state = init_state.clone()
        out, _ = fused_recurrent_gated_delta_rule(
            q=q_,
            k=k_,
            v=v_,
            initial_state=state,
            inplace_final_state=True,
            cu_seqlens=cu,
            ssm_state_indices=slot_table,
            use_qk_l2norm_in_kernel=True,
            A_log=A_log,
            dt_bias=dt_bias,
            a_raw=a_,
            b_raw=b_,
        )
        return out, state

    strided_inputs = (q, k, v, a_raw, b_raw)
    out_ref, final_ref = run_varlen(*(t.contiguous() for t in strided_inputs))
    out_strided, final_strided = run_varlen(*strided_inputs)

    assert torch.equal(out_ref, out_strided)
    assert torch.equal(final_ref, final_strided)


# Allow running this file directly as a script: hand it to pytest in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Loading