Skip to content

[TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation(ShiftScale) + Quant kernels#15102

Open
luyiyun1021 wants to merge 6 commits into
NVIDIA:mainfrom
luyiyun1021:ltx2-adaln-fusion
Open

[TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation(ShiftScale) + Quant kernels#15102
luyiyun1021 wants to merge 6 commits into
NVIDIA:mainfrom
luyiyun1021:ltx2-adaln-fusion

Conversation

@luyiyun1021

@luyiyun1021 luyiyun1021 commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Description

Fuses LTX-2 NVFP4 transformer block's AdaLN modulation patterns into a single templated CUDA kernel. Replaces torch.compile-generated Triton prep kernels plus standalone bf16-to-fp4 quantize kernels with one fused launch per modulation site. The kernel is shape-polymorphic across video and audio tokens and composes six behaviors via templates: residual add, gate weighting, RMSNorm, AdaLN shift_scale, single vs dual shift_scale output, and inline fp4 quantization with per-block scale factors.

Patterns Replaced

Each block modulation pattern becomes one cpp op (the modulating ones also have a _quant variant that emits FP4 + per-block SF inline):

Op Pattern Use site
fused_dit_rmsnorm_shift_scale[_quant] RmsNorm → shift_scale(shift, scale) KA: post-pre-norm
fused_dit_resid_rmsnorm_shift_scale_dual[_quant] Residual → RmsNorm → 2× independent shift_scale KB: attention out, dual branches
fused_dit_gate_resid_rmsnorm_shift_scale[_quant] Gate·Residual → RmsNorm → shift_scale KC: FFN input, gated
fused_dit_gate_resid_rmsnorm[_quant] Gate·Residual → RmsNorm (no shift_scale) KD: tail of block
fused_dit_gate_resid Gate·Residual (no RmsNorm, no shift_scale) FFN output gate (vx + ff(vx_scaled)·gate)

Design

One templated kernel fusedDiTGateResidNormShiftScaleKernel<D, ROWS_PER_BLOCK, BLOCK_SIZE, HAS_RESIDUAL, HAS_GATE, HAS_NORM, HAS_SHIFT_SCALE, NUM_OUT, HAS_QUANT> (cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu). Six compile-time flags compose the block-prefix pipeline; 9 explicit launchFusedDiTNorm instantiations cover every call site:

Variant RESID GATE NORM SHIFT_SCALE NUM_OUT dtypes Pattern / use site
KA 1 bf16, fp4 RmsNorm → shift_scale (post-pre-norm)
KB 2 bf16, fp4 resid → RmsNorm → 2× shift_scale (attn out)
KC 1 bf16, fp4 gate·resid → RmsNorm → shift_scale (FFN in)
KD 1 bf16, fp4 gate·resid → RmsNorm (block tail)
gate_resid 1 bf16 gate·resid (FFN out)

Modulators are passed as (table, ts) pair-form and composed inline in-kernel — fp32 table narrowed to bf16 then one __hadd2 — so no separate prep kernel runs. FP4 variants emit packed FP4 + 128×4 swizzled SF inline (no standalone quantize launch).

Tile + load tactics (auto-dispatched by shape — no env var, no caller knob)

launchFusedDiTNorm picks one of two tiles and one of two X-load mechanisms from (D, variant, token alignment) at launch:

Shape / variant Tile X load Why this tactic
video D=4096, KA/KB/KC Pipelined — 4-row/288t, warp-specialized, circular TMA TMA bulk large grid amortizes warp-spec setup; 4-row CTA reuses one modulator fetch ×4; multi-stage SMEM hides HBM latency
video D=4096, KD + gate_resid Default — 1-row/256t TMA bulk no shift_scale → compute-light/HBM-bound; warp-spec overhead doesn't pay, but a single bulk load still frees the LSU
audio D=2048, all Default — 1-row/256t cp.async grid too small (T≈252) and kernels are 2–4 µs → mbarrier setup cost exceeds cp.async's per-thread issue

Pipeline depth (pipelined only): bf16 = 3 stages (HBM-latency-bound, deeper hides more), fp4 = 2 stages (Phase-2 quant is compute-heavy; extra stages only add SMEM pressure).

attn load (default tile): also TMA-bulk when HAS_RESIDUAL && HAS_QUANT (KD-quant) — its load overlaps the heavy quant Phase 2; bf16 keeps per-thread LDG (already overlaps the TMA-X load on a separate LSU pipe).

Auto-dispatch predicate for the pipelined tile (all must hold): D==4096, variant has HAS_SHIFT_SCALE || !HAS_GATE (KA/KB/KC; KD and gate_resid fail it), and num_tokens % 4 == 0 && tokens_per_batch ≥ 4 && tokens_per_batch % 4 == 0. Default tile is 1-row/256t, chosen by NCU sweep (2 CTAs/SM at ≤80 regs, best per-element rate among {2r256, 1r256, 1r512}).

Pipelined vs default mechanics:

Default Pipelined
CTA 1 row × 256t 4 rows × 288t (__launch_bounds__(288,2))
Warps uniform 8 × 32 1 producer + 8 consumer (warp-specialized)
SMEM single-buffered TMA load circular 2–3 stages, mbarrier-driven
Sync __syncthreads() mbarrier.try_wait.parity + consumer-only bar.sync 1, 256
Modulator cache per-row per-batch, reused across 4 rows

Perf (CUDA-Graph, true SM time, video D=4096): pipelined is −10…−25% vs default on KA/KB/KC bf16, −11…−15% on fp4. KD and audio always use the default tile.

Per-kernel Microbenchmark

Pure GPU time via CUDA Graph capture + replay. V1 = (B=1, T=15360, D=4096), V2 = (B=2, T=15360, D=4096). cpp = this PR's fused kernel (auto-selected variant per the dispatch predicate above); compile = torch.compile of an equivalent eager expression that ends in torch.ops.trtllm.tunable_fp4_quantize with is_sf_swizzled_layout=True (matches the production Linear path). bf16 outputs verified by cosine ≥ 0.99 against an fp32 reference; FP4 outputs verified at the unit-test layer. Bench harness: tmp/feature/ltx2-adaln-refactor/bench_kabcd_vs_compile.py.

Variant Shape cpp (μs) compile (μs) speedup
KA-quant V1 73 117 1.60×
KA-quant V2 144 241 1.68×
KB-quant V1 146 231 1.59×
KB-quant V2 281 494 1.76×
KC-quant V1 96 149 1.55×
KC-quant V2 187 307 1.65×
KD-quant V1 71 129 1.82×
KD-quant V2 143 265 1.85×

End-to-end (LTX-2 single-stage, 1 GPU, audio+video, guidance_scale=1.0, 768×1280×121 frames, 40 denoise steps)

5-run average, warmup excluded. B200. Denoise per-step is the pipeline's own Denoising done loop timing (40 steps, excludes VAE decode + text encode); full E2E is the whole generate() wall including VAE + text encode.

OFF (torch.compile baseline) ON (this PR) Δ Gain
denoise loop, 40 steps (ms) 13746 13184 -562 ms -4.09%
denoise per-step (ms) 343.7 329.6 -14.1 ms -4.09%
full E2E incl VAE (ms) 18955 18282 -673 ms -3.55%

The fused AdaLN kernels touch only the denoise loop, so the -4.09% denoise gain is the kernel-level signal; the full-E2E -3.55% is diluted by the fixed VAE decode + text encode cost shared by both arms.

Per-step kernel A/B (nsys, step 10, ON vs OFF, CUDA-Graph node trace)

Net change in the kernels this PR touches (AdaLN modulation + inline FP4 quant). GEMM, attention (cuDNN SDPA), and QK/split-norm kernels are excluded — this PR does not touch them. torch.compile triton-numbering suffixes (_view_N) are stripped before aggregation so renumbered-but-identical kernels match.

Category Count GPU time (ms)
Added — fused C++ AdaLN kernels (default + pipelined) 480 +26.1
Removed — torch.compile triton rms-norm/shift_scale prep 480 -26.9
Changedquantize_with_block_size (FP4 quant now inline) 840 → 336 -17.9
Net (touched region) -17.6 ms/step

The dominant win is the FP4 quantize collapsing from 840 to 336 launches: the KA/KB/KC/KD quant variants now emit packed FP4 + per-block SF inline inside the fused kernel, eliminating 504 standalone quantize_with_block_size launches plus their bf16 round-trips through HBM. The ~480 added fused kernels are roughly time-neutral against the ~480 triton rms/shift_scale prep kernels they replace.

Test Coverage

Four unit tests under tests/unittest/_torch/thop/parallel_hw_agnostic/ exercise each op against an eager PyTorch reference. The hidden_dim ∈ {2048, 4096} × batch_size ∈ {1, 2} × tokens_per_batch parametrize sweep covers both the default tile and the pipelined variant naturally (no env-var injection):

  • test_fused_dit_rmsnorm_shift_scale.py (KA bf16 + quant; bf16 tpb ∈ {1, 16, 512}, quant tpb ∈ {126, 128})
  • test_fused_dit_resid_rmsnorm_shift_scale_dual.py (KB bf16 + quant; bf16 tpb ∈ {16, 512}, quant tpb ∈ {126, 128})
  • test_fused_dit_gate_resid_rmsnorm_shift_scale.py (KC bf16 + quant; bf16 tpb ∈ {16, 512}, quant tpb ∈ {126, 128})
  • test_fused_dit_gate_resid_rmsnorm_quant.py (KD bf16 + quant; KD always uses the default tile)

The tpb=126 cases exercise the default tile (126 % 4 = 2); tpb={16, 128, 512} exercise the pipelined variant on D=4096 KA/KB/KC.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • New Features

    • Added fused RMSNorm + optional residual/gate/modulation kernels with BF16 and NVFP4 (FP4+scale) output modes, exposed as new CUDA-backed operators for faster inference.
    • Model runtime: optional fusion toggle and FP4-aware pathways; attention and linear modules now accept/propagate FP4-quantized tensors.
  • Tests

    • Added end-to-end unit tests covering BF16 and NVFP4 variants, shape/stride checks, and quantized GEMM correctness.

@luyiyun1021 luyiyun1021 requested review from a team as code owners June 8, 2026 11:13
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-13265][feat] Fuse LTX-2 NVFP4 AdaLN modulation kernels (KA/KB/KC/KD) [TRTLLM-13265][feat] Fuse LTX-2 NVFP4 AdaLN modulation kernels Jun 8, 2026
@coderabbitai

coderabbitai Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Adds fused CUDA kernels and a templated launcher for DiT-style RMSNorm (BF16 with optional residual/gate/modulate) plus NVFP4 FP4+SF output, exposes them as multiple Torch ops, integrates fused dispatch into LTX-2 model code and NVFP4 linear/attention paths, and adds fake registrations and unit tests.

Changes

Fused DiT RMSNorm kernels and model integration

Layer / File(s) Summary
CUDA kernel implementation & launcher
cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu
Device kernels (fusedDiTNormKernel and pipelined variant), helper to combine modulators, per-row RMS via warp reductions, BF16 vs FP4 output paths, and host launcher with templated instantiations.
Kernel header / API
cpp/tensorrt_llm/kernels/fusedDiTNormKernel.h
Adds AdaLNNormParams public struct and templated launchFusedDiTNorm declaration with compile-time flags and constraints.
Build list
cpp/tensorrt_llm/thop/CMakeLists.txt
Adds four new fused-op .cpp sources to the th_common target.
Torch ops — RMSNorm + shift/scale
cpp/tensorrt_llm/thop/fusedDiTRmsNormShiftScaleOp.cpp
Implements bf16 and quantized fused RMSNorm + AdaLN shift/scale ops, populates AdaLNNormParams, allocates outputs, and registers Torch schemas/impls.
Torch ops — residual + gate + RMSNorm
cpp/tensorrt_llm/thop/fusedDiTResidGateRmsNormOp.cpp
Implements bf16 and FP4 variants for residual+gate RMSNorm, including layout/stride validation and quant output packing.
Torch ops — gated residual + RMSNorm + modulation
cpp/tensorrt_llm/thop/fusedDiTGateResidualRmsNormModulateOp.cpp
Implements fused gated-residual + RMSNorm + AdaLN modulation ops (bf16 and quant), handling timestep strides and FP4/SF packing.
Torch ops — residual + RMSNorm + dual shift/scale
cpp/tensorrt_llm/thop/fusedDiTResidualRmsNormShiftScaleDualOp.cpp
Implements dual-path shift/scale fused ops with bf16 and FP4 outputs (two FP4+SF pairs when quantized).
Python fake/meta registrations
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
Adds shape inference fakes for quantized and non-quantized fused ops, computing flattened M and swizzled FP4 scale shapes.
Python dispatch & FP4 helpers
tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py
Adds NVFP4 input-scale detection, supported-dim predicate, FP4 reshape/wrapping, and fused vs eager dispatch helpers for gate/resid/modulate and dual-shift-scale flows.
Transformer integration
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
Adds _fuse_adaln gating, supplies table/ts pairs, defers residuals when fusing, and rewrites AV/FFN paths to consume fused kernels when available.
FP4 integration in linear/attention
tensorrt_llm/_torch/modules/linear.py, tensorrt_llm/_torch/visual_gen/modules/attention.py
NVFP4LinearMethod.apply() flattens higher-rank Fp4QuantizedTensor inputs for GEMM; Attention.forward accepts Fp4QuantizedTensor hidden states.
Unit tests
tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_*.py
Adds BF16 and NVFP4 tests with PyTorch references, shape/stride edge cases, negative tests for unsupported dims, and FP4 quantization validation via fp4_quantize + nvfp4_gemm.

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

VisualGen, api-compatible

Suggested reviewers

  • mikeiovine
  • leslie-fang25
  • Wanli-Jiang
  • venkywonka
  • nv-guomingz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 51.22% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description check ✅ Passed PR description is comprehensive and well-structured with clear sections covering implementation details, design patterns, performance metrics, and test coverage.
Title check ✅ Passed The title clearly and specifically describes the main change: fusing LTX-2 Gate, Residual, Norm, AdaLN modulation, and quantization kernels into a single CUDA implementation.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@luyiyun1021 luyiyun1021 changed the title [TRTLLM-13265][feat] Fuse LTX-2 NVFP4 AdaLN modulation kernels [TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation + Quant kernels Jun 9, 2026

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py (1)

18-402: 🛠️ Refactor suggestion | 🟠 Major | 🏗️ Heavy lift

Move these TRT-LLM-specific fused helpers out of ltx2_core.

This file sits in the upstream-mirrored LTX-2 subtree, but the new surface here is TensorRT-LLM-specific custom-op dispatch, NVFP4 handling, and fusion policy. Keeping that logic in ltx2_core will make future upstream syncs much harder; please move it into a TensorRT-LLM-owned adapter module and keep this subtree close to upstream. Based on learnings: files under tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/ are ported directly from the upstream Lightricks LTX-2 repository and should remain faithful to upstream.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py` around
lines 18 - 402, This code introduces TensorRT-LLM-specific fused kernels and
NVFP4 handling (e.g. get_nvfp4_input_scale, apply_fused_adaln_modulate,
apply_fused_gate_resid_rms_modulate, apply_fused_resid_gate_rms_quant,
apply_fused_resid_rms_dual_shift_scale and Fp4QuantizedTensor usage) inside the
upstream-mirrored ltx2_core subtree; extract these TRT-LLM-specific helpers into
a new TensorRT-LLM owned adapter module, move all TensorRT-only ops/quant logic
and NVFP4 helpers there, update imports in this file to call into the adapter
(or provide thin shim wrappers delegating to the new module), and ensure the
original ltx2_core file retains only upstream-faithful logic so upstream syncs
remain clean.

Source: Learnings

🧹 Nitpick comments (2)
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)

468-473: ⚡ Quick win

Drop the benchmark-only env override before merge.

TLLM_DISABLE_FUSE_ADALN changes execution behavior through an undocumented process env var, and the comment already marks it as temporary. Leaving it in both constructors bakes a hidden debug switch into production code; either remove it or promote it to a documented config/benchmark knob.

Also applies to: 1297-1301

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines
468 - 473, Remove the temporary benchmark-only environment override that toggles
self._fuse_adaln (the block importing os as _os_bench and checking
TLLM_DISABLE_FUSE_ADALN) from the transformer_ltx2 constructors; locate the same
pattern around the second constructor (the block at ~1297-1301) and delete it or
replace it with a proper, documented configuration parameter (e.g., a
constructor arg or a config object) that explicitly controls fuse_adaln instead
of relying on an undocumented process env var. Ensure the final behavior uses
the explicit config/constructor value to set self._fuse_adaln and remove any
leftover _os_bench import or env checks.
tensorrt_llm/_torch/visual_gen/modules/attention.py (1)

479-484: ⚡ Quick win

Restore an explicit type annotation for hidden_states.

The signature dropped type information. Please annotate it as torch.Tensor | Fp4QuantizedTensor (or equivalent) to keep static typing intact.

Suggested fix
     def forward(
         self,
-        hidden_states,
+        hidden_states: torch.Tensor | Fp4QuantizedTensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:

As per coding guidelines, "Always annotate functions."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/modules/attention.py` around lines 479 - 484,
The forward method lost an explicit type for hidden_states—restore it by
annotating the parameter as a union type (e.g. hidden_states: torch.Tensor |
Fp4QuantizedTensor) so static typing is preserved; update the forward signature
in the attention module (function forward) to use that union for hidden_states
while leaving encoder_hidden_states and freqs as-is and ensure any imports or
forward references for Fp4QuantizedTensor are added or aliased if necessary.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu`:
- Around line 958-1046: The launcher dispatches HAS_QUANT specializations
unconditionally but those kernels require runtime SM >= 10; add a runtime
compute-capability guard at the top of launchFusedDiTNorm that queries
cudaGetDeviceProperties (or cudaGetDevice) and, when HAS_QUANT is true and
deviceProp.major < 10 (or equivalent SM check), fail fast (TLLM_THROW or
TLLM_CHECK_WITH_INFO) rather than launching; apply the same check before
dispatching both the pipelined path (fusedDiTNormKernelPipelined) and the
regular fusedDiTNormKernel launches so quantized specializations are never
launched on unsupported SMs.

In `@tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py`:
- Around line 1154-1236: The fake/meta registrations for the quantized ops were
added but the bf16 op schemas are missing, causing FakeTensor tracing to fail
when apply_fused_adaln_modulate, apply_fused_gate_resid_rms_modulate, or
apply_fused_resid_rms_dual_shift_scale dispatch the bf16 paths; add
`@torch.library.register_fake` handlers for the bf16 op names
trtllm::fused_dit_rmsnorm_shift_scale,
trtllm::fused_dit_gate_resid_rms_modulate, and
trtllm::fused_dit_resid_rms_shift_scale_dual that mirror the non-quantized
signatures used by those functions and return tensors with the same
shapes/dtypes as the inputs (e.g., torch.empty_like(x) or the appropriate tuple
of empty_like tensors matching the eager op outputs), so FakeTensor/tracing can
succeed.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 459-467: The _fuse_adaln flag currently only checks hidden dims
via is_fused_adaln_supported_dim(video.dim)/audio.dim; update the guard to also
require bf16 dtype so non-bf16 configs don't take the fused path. Import torch
if needed and change the v_ok/a_ok conditions in transformer_ltx2 (where
_fuse_adaln is set) to require video is None or (video.dtype == torch.bfloat16
and is_fused_adaln_supported_dim(video.dim)) and similarly for audio, then set
self._fuse_adaln = v_ok and a_ok.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.py`:
- Around line 123-131: The test currently computes D_ref by calling
torch.ops.trtllm.fused_dit_gate_resid_rms_modulate and then F.linear, coupling
the quant-path validation to the bf16 fused op; change the baseline to use an
independent reference implementation by calling torch_ref(...) (the standalone
reference kernel) with the same inputs instead of out_bf16, and assign its
result to D_ref so the quant test validates against torch_ref rather than the
fused op; update the D_ref assignment in the test block that currently uses
out_bf16 to call torch_ref with x, attn, gate_table, gate_ts, scale_table,
scale_ts, shift_table, shift_ts, eps (or the appropriate torch_ref signature)
and then apply F.linear with W as before.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.py`:
- Around line 112-116: The test should fail only for the unsupported-dimension
error: change sf_scale from a one-element vector to a scalar tensor (e.g.,
torch.tensor(1.0, dtype=torch.float32, device=device)) so it cannot cause a
different failure, and tighten the assertion on
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant by using
pytest.raises(RuntimeError, match=r"unsupported.*dimension") (or the exact
unsupported-dimension message) to ensure the test only passes when that specific
error is raised.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.py`:
- Around line 138-144: The test currently computes D1_ref/D2_ref by calling the
implementation under test torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual
(referenced as fused_dit_resid_rms_shift_scale_dual), which can mask shared
regressions; replace that use with the independent reference implementation
torch_ref (call torch_ref(x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts,
s2_table, s2_ts, h2_table, h2_ts, eps) or the appropriate torch_ref signature)
to produce o1_bf16/o2_bf16 and then build D1_ref = F.linear(o1_bf16, W1) and
D2_ref = F.linear(o2_bf16, W2); keep all input variables (x_bf16, attn2,
s*_table, s*_ts, h*_table, h*_ts, eps, W1, W2) the same so the test decouples
quant validation from fused_dit_resid_rms_shift_scale_dual.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py`:
- Around line 165-168: The test currently broad-catches any RuntimeError when
calling torch.ops.trtllm.fused_dit_rmsnorm_shift_scale; tighten the assertion to
only accept the unsupported-dimension error by using pytest.raises with a match
regex (e.g. pytest.raises(RuntimeError, match=r"unsupported.*dimension") ) so
the test for fused_dit_rmsnorm_shift_scale validates the specific
unsupported-dimension contract rather than any runtime failure.

---

Outside diff comments:
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py`:
- Around line 18-402: This code introduces TensorRT-LLM-specific fused kernels
and NVFP4 handling (e.g. get_nvfp4_input_scale, apply_fused_adaln_modulate,
apply_fused_gate_resid_rms_modulate, apply_fused_resid_gate_rms_quant,
apply_fused_resid_rms_dual_shift_scale and Fp4QuantizedTensor usage) inside the
upstream-mirrored ltx2_core subtree; extract these TRT-LLM-specific helpers into
a new TensorRT-LLM owned adapter module, move all TensorRT-only ops/quant logic
and NVFP4 helpers there, update imports in this file to call into the adapter
(or provide thin shim wrappers delegating to the new module), and ensure the
original ltx2_core file retains only upstream-faithful logic so upstream syncs
remain clean.

---

Nitpick comments:
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 468-473: Remove the temporary benchmark-only environment override
that toggles self._fuse_adaln (the block importing os as _os_bench and checking
TLLM_DISABLE_FUSE_ADALN) from the transformer_ltx2 constructors; locate the same
pattern around the second constructor (the block at ~1297-1301) and delete it or
replace it with a proper, documented configuration parameter (e.g., a
constructor arg or a config object) that explicitly controls fuse_adaln instead
of relying on an undocumented process env var. Ensure the final behavior uses
the explicit config/constructor value to set self._fuse_adaln and remove any
leftover _os_bench import or env checks.

In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 479-484: The forward method lost an explicit type for
hidden_states—restore it by annotating the parameter as a union type (e.g.
hidden_states: torch.Tensor | Fp4QuantizedTensor) so static typing is preserved;
update the forward signature in the attention module (function forward) to use
that union for hidden_states while leaving encoder_hidden_states and freqs as-is
and ensure any imports or forward references for Fp4QuantizedTensor are added or
aliased if necessary.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 74ebb7eb-0878-4102-8b71-70ca7fff551d

📥 Commits

Reviewing files that changed from the base of the PR and between 09c21b6 and 4b65e7d.

📒 Files selected for processing (16)
  • cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu
  • cpp/tensorrt_llm/kernels/fusedDiTNormKernel.h
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/fusedDiTGateResidualRmsNormModulateOp.cpp
  • cpp/tensorrt_llm/thop/fusedDiTResidGateRmsNormOp.cpp
  • cpp/tensorrt_llm/thop/fusedDiTResidualRmsNormShiftScaleDualOp.cpp
  • cpp/tensorrt_llm/thop/fusedDiTRmsNormShiftScaleOp.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.py
  • tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py

Comment on lines +958 to +1046
template <bool HAS_RESIDUAL, bool HAS_GATE, bool HAS_MODULATE, int NUM_OUT, bool HAS_QUANT>
void launchFusedDiTNorm(AdaLNNormParams const& params, int hidden_dim, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(params.num_tokens >= 1, "num_tokens must be >= 1, got %d", params.num_tokens);
TLLM_CHECK_WITH_INFO(
params.tokens_per_batch >= 1, "tokens_per_batch must be >= 1, got %d", params.tokens_per_batch);
TLLM_CHECK_WITH_INFO(params.num_tokens % params.tokens_per_batch == 0,
"num_tokens (%d) must be divisible by tokens_per_batch (%d)", params.num_tokens, params.tokens_per_batch);

// Production tile selected via NCU sweep at D=4096, B=1, T=15360: 1-row/CTA, 256 threads/CTA
// gives 2 CTAs/SM (regs <= 80) and the highest per-element bench rate among {2r256, 1r256, 1r512}.
constexpr int ROWS_PER_BLOCK = 1;
constexpr int BLOCK_SIZE = 256;
constexpr int WARPS_PER_ROW = (BLOCK_SIZE / ROWS_PER_BLOCK) / 32;

cudaLaunchConfig_t cfg = {};
cfg.stream = stream;
cudaLaunchAttribute attrs[1] = {};
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
cfg.attrs = attrs;
cfg.numAttrs = 1;

// Optional pipelined dispatch (multi-row CTA + circular TMA stages + warp specialization).
// Enabled by env var TLLM_DIT_NORM_PIPELINED=1. Fires only when:
// - hidden_dim == 4096 (video path; audio D=2048 has too small a grid to amortize the
// warp-specialization overhead and bench-regresses)
// - variant is KA/KB/KC (KD's HAS_GATE && !HAS_MODULATE is HBM-bound and bench-regresses)
// - tokens_per_batch >= 4 && % 4 == 0 (R_CTA=4 rows share batchIdx for per-CTA mod cache)
// Otherwise falls through to the default path below.
static bool const pipelined_enabled = []()
{
char const* env = std::getenv("TLLM_DIT_NORM_PIPELINED");
return env != nullptr && env[0] != '\0' && env[0] != '0';
}();
constexpr bool kPipelinedVariantOK = HAS_MODULATE || !HAS_GATE;
if constexpr (kPipelinedVariantOK)
{
// NUM_STAGES chosen per HAS_QUANT: bf16 path is HBM-latency-bound so deeper pipeline
// (3 stages) helps; quant Phase 2 is compute-heavy and extra stages just add SMEM
// pressure without commensurate latency hiding.
constexpr int PIPE_NUM_STAGES = HAS_QUANT ? 2 : 3;
constexpr int PIPE_BLOCK_SIZE = 288;
constexpr int PIPE_D = 4096;
constexpr int PIPE_R_CTA = 4;
if (pipelined_enabled && hidden_dim == PIPE_D && (params.num_tokens % PIPE_R_CTA == 0)
&& (params.tokens_per_batch >= PIPE_R_CTA) && (params.tokens_per_batch % PIPE_R_CTA == 0))
{
constexpr int pipe_stage_elems = HAS_RESIDUAL ? (2 * PIPE_D) : PIPE_D;
size_t const pipe_smem_bytes
= static_cast<size_t>(PIPE_NUM_STAGES) * pipe_stage_elems * sizeof(__nv_bfloat16)
+ 2 * PIPE_NUM_STAGES * sizeof(uint64_t) + 8 * sizeof(float) + 16;
cudaLaunchConfig_t pipe_cfg = cfg;
pipe_cfg.gridDim = dim3((params.num_tokens + PIPE_R_CTA - 1) / PIPE_R_CTA);
pipe_cfg.blockDim = dim3(PIPE_BLOCK_SIZE);
pipe_cfg.dynamicSmemBytes = static_cast<int>(pipe_smem_bytes);
auto* pipe_kp = fusedDiTNormKernelPipelined<PIPE_D, PIPE_R_CTA, PIPE_NUM_STAGES, HAS_RESIDUAL, HAS_GATE,
HAS_MODULATE, NUM_OUT, HAS_QUANT>;
cudaFuncSetAttribute(
pipe_kp, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(pipe_smem_bytes));
cudaLaunchKernelEx(&pipe_cfg, pipe_kp, params);
return;
}
}
(void) pipelined_enabled;

#define LAUNCH(D_VAL) \
do \
{ \
constexpr bool USE_TMA_HOST = (D_VAL >= 4096) && !(HAS_QUANT && HAS_MODULATE && NUM_OUT == 1); \
constexpr bool USE_TMA_ATTN_HOST = USE_TMA_HOST && HAS_RESIDUAL && HAS_QUANT; \
int const attn_extra_bytes \
= USE_TMA_ATTN_HOST ? (ROWS_PER_BLOCK * D_VAL * static_cast<int>(sizeof(__nv_bfloat16))) : 0; \
cfg.gridDim = dim3((params.num_tokens + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); \
cfg.blockDim = dim3(BLOCK_SIZE); \
cfg.dynamicSmemBytes = ROWS_PER_BLOCK * D_VAL * static_cast<int>(sizeof(__nv_bfloat16)) + attn_extra_bytes \
+ ROWS_PER_BLOCK * WARPS_PER_ROW * static_cast<int>(sizeof(float)); \
cudaLaunchKernelEx(&cfg, \
fusedDiTNormKernel<D_VAL, ROWS_PER_BLOCK, BLOCK_SIZE, HAS_RESIDUAL, HAS_GATE, HAS_MODULATE, NUM_OUT, \
HAS_QUANT>, \
params); \
} while (0)

switch (hidden_dim)
{
case 2048: LAUNCH(2048); break;
case 4096: LAUNCH(4096); break;
default: TLLM_THROW("Unsupported hidden_dim for fusedDiTNorm: %d (only 2048, 4096)", hidden_dim);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add a runtime SM guard before launching quantized specializations.

HAS_QUANT kernels only emit FP4/SF stores in __CUDA_ARCH__ >= 1000 blocks, but the launcher currently dispatches them unconditionally. On lower-SM runtime targets this can return undefined outputs instead of failing fast.

Suggested fix
 template <bool HAS_RESIDUAL, bool HAS_GATE, bool HAS_MODULATE, int NUM_OUT, bool HAS_QUANT>
 void launchFusedDiTNorm(AdaLNNormParams const& params, int hidden_dim, cudaStream_t stream)
 {
+    if constexpr (HAS_QUANT)
+    {
+        int dev = 0;
+        int cc_major = 0;
+        cudaGetDevice(&dev);
+        cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, dev);
+        TLLM_CHECK_WITH_INFO(
+            cc_major >= 10,
+            "fusedDiTNorm quant variants require SM100+; got compute capability major %d",
+            cc_major);
+    }
+
     TLLM_CHECK_WITH_INFO(params.num_tokens >= 1, "num_tokens must be >= 1, got %d", params.num_tokens);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu` around lines 958 - 1046, The
launcher dispatches HAS_QUANT specializations unconditionally but those kernels
require runtime SM >= 10; add a runtime compute-capability guard at the top of
launchFusedDiTNorm that queries cudaGetDeviceProperties (or cudaGetDevice) and,
when HAS_QUANT is true and deviceProp.major < 10 (or equivalent SM check), fail
fast (TLLM_THROW or TLLM_CHECK_WITH_INFO) rather than launching; apply the same
check before dispatching both the pipelined path (fusedDiTNormKernelPipelined)
and the regular fusedDiTNormKernel launches so quantized specializations are
never launched on unsupported SMs.

Comment on lines +1154 to +1236
@torch.library.register_fake("trtllm::fused_dit_rmsnorm_shift_scale_quant")
def _(x, scale_table, scale_ts, shift_table, shift_ts, sf_scale, eps):
"""Fake/meta for fused KA + NVFP4 quant. SF layout is SWIZZLED 128x4.

Modulator is built inline from (table, ts) pairs:
scale[d] = scale_table[d] + scale_ts[b, d]
Folds the upstream broadcast-add Triton prep kernel into the C++ op.
"""
D = x.shape[-1]
M = 1
for d in x.shape[:-1]:
M *= d
_, scale_shape = fp4_utils.get_fp4_shape((M, D),
16,
is_swizzled_layout=True)
out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8)
out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8)
return out_fp4, out_sf

@torch.library.register_fake(
"trtllm::fused_dit_gate_resid_rms_modulate_quant")
def _(x, attn_out, gate_table, gate_ts, scale_table, scale_ts, shift_table,
shift_ts, sf_scale, eps):
"""Fake/meta for fused KC + NVFP4 quant. SF layout is SWIZZLED 128x4.

Modulators built inline from (table, ts) pairs; folds upstream
broadcast-add Triton prep kernel into the C++ op.
"""
D = x.shape[-1]
M = 1
for d in x.shape[:-1]:
M *= d
_, scale_shape = fp4_utils.get_fp4_shape((M, D),
16,
is_swizzled_layout=True)
out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8)
out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8)
return out_fp4, out_sf

@torch.library.register_fake("trtllm::fused_dit_resid_gate_rms_norm")
def _(x, attn_out, gate_table, gate_ts, eps):
"""Fake/meta for fused KD bf16 (residual_add + gate_mul + rms_norm).
Gate built inline from (table, ts) pair -- folds the upstream
broadcast-add Triton prep + `attn * gate` mul into Phase 0b."""
return torch.empty_like(x)

@torch.library.register_fake("trtllm::fused_dit_resid_gate_rms_norm_quant")
def _(x, attn_out, gate_table, gate_ts, sf_scale, eps):
"""Fake/meta for fused KD (residual_add + gate_mul + rms_norm + NVFP4 quant).
Gate built inline from (table, ts) pair. SF layout is SWIZZLED 128x4."""
D = x.shape[-1]
M = 1
for d in x.shape[:-1]:
M *= d
_, scale_shape = fp4_utils.get_fp4_shape((M, D),
16,
is_swizzled_layout=True)
out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8)
out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8)
return out_fp4, out_sf

@torch.library.register_fake(
"trtllm::fused_dit_resid_rms_shift_scale_dual_quant")
def _(x, attn2_out, scale_dir1_table, scale_dir1_ts, shift_dir1_table,
shift_dir1_ts, scale_dir2_table, scale_dir2_ts, shift_dir2_table,
shift_dir2_ts, sf_scale1, sf_scale2, eps):
"""Fake/meta for fused KB + dual NVFP4 quant. SF layout is SWIZZLED 128x4.

4 modulators built inline from (table, ts) pairs; folds upstream
broadcast-add Triton prep kernel into the C++ op.
"""
D = x.shape[-1]
M = 1
for d in x.shape[:-1]:
M *= d
_, scale_shape = fp4_utils.get_fp4_shape((M, D),
16,
is_swizzled_layout=True)
out1_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8)
out1_sf = x.new_empty((scale_shape, ), dtype=torch.uint8)
out2_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8)
out2_sf = x.new_empty((scale_shape, ), dtype=torch.uint8)
return out1_fp4, out1_sf, out2_fp4, out2_sf

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Register fake/meta kernels for the bf16 variants too.

These additions cover the quantized ops plus KD bf16, but apply_fused_adaln_modulate(), apply_fused_gate_resid_rms_modulate(), and apply_fused_resid_rms_dual_shift_scale() still dispatch to the bf16 op names trtllm::fused_dit_rmsnorm_shift_scale, trtllm::fused_dit_gate_resid_rms_modulate, and trtllm::fused_dit_resid_rms_shift_scale_dual when the downstream linear is not NVFP4-quantized. Without fake registrations for those schemas, torch.compile / FakeTensor tracing will fail on the bf16 path even though eager execution works.

Suggested fake registrations
+    `@torch.library.register_fake`("trtllm::fused_dit_rmsnorm_shift_scale")
+    def _(x, scale_table, scale_ts, shift_table, shift_ts, eps):
+        return torch.empty_like(x)
+
+    `@torch.library.register_fake`("trtllm::fused_dit_gate_resid_rms_modulate")
+    def _(x, attn_out, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps):
+        return torch.empty_like(x)
+
+    `@torch.library.register_fake`("trtllm::fused_dit_resid_rms_shift_scale_dual")
+    def _(
+        x,
+        attn2_out,
+        scale_dir1_table,
+        scale_dir1_ts,
+        shift_dir1_table,
+        shift_dir1_ts,
+        scale_dir2_table,
+        scale_dir2_ts,
+        shift_dir2_table,
+        shift_dir2_ts,
+        eps,
+    ):
+        return torch.empty_like(x), torch.empty_like(x)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py` around lines 1154 - 1236,
The fake/meta registrations for the quantized ops were added but the bf16 op
schemas are missing, causing FakeTensor tracing to fail when
apply_fused_adaln_modulate, apply_fused_gate_resid_rms_modulate, or
apply_fused_resid_rms_dual_shift_scale dispatch the bf16 paths; add
`@torch.library.register_fake` handlers for the bf16 op names
trtllm::fused_dit_rmsnorm_shift_scale,
trtllm::fused_dit_gate_resid_rms_modulate, and
trtllm::fused_dit_resid_rms_shift_scale_dual that mirror the non-quantized
signatures used by those functions and return tensors with the same
shapes/dtypes as the inputs (e.g., torch.empty_like(x) or the appropriate tuple
of empty_like tensors matching the eager op outputs), so FakeTensor/tracing can
succeed.

Comment on lines +459 to +467
# Whether to dispatch AdaLN modulation to the fused CUDA kernels. Resolved
# once at construction; call sites just consult the flag. The kernels are
# bf16 + hidden_dim in {2048, 4096}; non-matching cases raise at the C++
# boundary -- this flag is the only Python-side guard.
from .ltx2_core.utils_ltx2 import is_fused_adaln_supported_dim

v_ok = video is None or is_fused_adaln_supported_dim(video.dim)
a_ok = audio is None or is_fused_adaln_supported_dim(audio.dim)
self._fuse_adaln = v_ok and a_ok

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Gate _fuse_adaln on bf16 as well as hidden size.

All four new fused thop entrypoints hard-check torch.bfloat16, but _fuse_adaln only checks the hidden dims. A float16/float32 LTX-2 config with a supported inner dim will now take the fused path and fail on the first custom-op call instead of cleanly staying on the eager path.

Also applies to: 1288-1296

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines
459 - 467, The _fuse_adaln flag currently only checks hidden dims via
is_fused_adaln_supported_dim(video.dim)/audio.dim; update the guard to also
require bf16 dtype so non-bf16 configs don't take the fused path. Import torch
if needed and change the v_ok/a_ok conditions in transformer_ltx2 (where
_fuse_adaln is set) to require video is None or (video.dtype == torch.bfloat16
and is_fused_adaln_supported_dim(video.dim)) and similarly for audio, then set
self._fuse_adaln = v_ok and a_ok.

Comment on lines +123 to +131
# bf16 reference: compute KC's bf16 output, then F.linear with random W.
torch.manual_seed(13)
W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05

x_bf16 = x.clone()
out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate(
x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps
)
D_ref = F.linear(out_bf16, W)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use an independent reference in the quant test path.

Line 128 derives D_ref from the bf16 fused op, so quant validation is coupled to the same implementation family. That weakens detection of shared regressions at tokens_per_batch=126 (a shape not covered by the bf16 test). Use torch_ref(...) for the baseline here.

Suggested patch
-    x_bf16 = x.clone()
-    out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate(
-        x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps
-    )
+    out_bf16, _ = torch_ref(
+        x,
+        attn,
+        gate_table,
+        gate_ts,
+        scale_table,
+        scale_ts,
+        shift_table,
+        shift_ts,
+        tokens_per_batch,
+        eps,
+    )
     D_ref = F.linear(out_bf16, W)

As per coding guidelines, tests/** should provide sufficient and actionable coverage; this change prevents false confidence from coupled references.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# bf16 reference: compute KC's bf16 output, then F.linear with random W.
torch.manual_seed(13)
W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05
x_bf16 = x.clone()
out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate(
x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps
)
D_ref = F.linear(out_bf16, W)
# bf16 reference: compute KC's bf16 output, then F.linear with random W.
torch.manual_seed(13)
W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05
out_bf16, _ = torch_ref(
x,
attn,
gate_table,
gate_ts,
scale_table,
scale_ts,
shift_table,
shift_ts,
tokens_per_batch,
eps,
)
D_ref = F.linear(out_bf16, W)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.py`
around lines 123 - 131, The test currently computes D_ref by calling
torch.ops.trtllm.fused_dit_gate_resid_rms_modulate and then F.linear, coupling
the quant-path validation to the bf16 fused op; change the baseline to use an
independent reference implementation by calling torch_ref(...) (the standalone
reference kernel) with the same inputs instead of out_bf16, and assign its
result to D_ref so the quant test validates against torch_ref rather than the
fused op; update the D_ref assignment in the test block that currently uses
out_bf16 to call torch_ref with x, attn, gate_table, gate_ts, scale_table,
scale_ts, shift_table, shift_ts, eps (or the appropriate torch_ref signature)
and then apply F.linear with W as before.

Source: Coding guidelines

Comment on lines +112 to +116
sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant(
x, attn, gate_table, gate_ts, sf_scale, 1e-6
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Make the unsupported-dim test fail only for the intended reason.

Line 113 uses a broad RuntimeError assertion, and Line 112 uses a [1] tensor for sf_scale, which can introduce alternate failure causes. Use a scalar sf_scale and match the unsupported-dimension message.

Suggested patch
-    sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
-    with pytest.raises(RuntimeError):
+    sf_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+    with pytest.raises(RuntimeError, match=r"(unsupported|2048|4096)"):
         torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant(
             x, attn, gate_table, gate_ts, sf_scale, 1e-6
         )

As per coding guidelines, tests/** should keep coverage actionable and targeted to the contract being validated.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.py`
around lines 112 - 116, The test should fail only for the unsupported-dimension
error: change sf_scale from a one-element vector to a scalar tensor (e.g.,
torch.tensor(1.0, dtype=torch.float32, device=device)) so it cannot cause a
different failure, and tighten the assertion on
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant by using
pytest.raises(RuntimeError, match=r"unsupported.*dimension") (or the exact
unsupported-dimension message) to ensure the test only passes when that specific
error is raised.

Source: Coding guidelines

Comment on lines +138 to +144
# bf16 reference: KB bf16 outputs -> F.linear with random W1 / W2.
x_bf16 = x.clone()
o1_bf16, o2_bf16 = torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual(
x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts, s2_table, s2_ts, h2_table, h2_ts, eps
)
D1_ref = F.linear(o1_bf16, W1)
D2_ref = F.linear(o2_bf16, W2)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Decouple quant validation from the bf16 fused implementation.

Line 140 builds D1_ref/D2_ref from fused_dit_resid_rms_shift_scale_dual, which is part of the same implementation family being tested. This can mask shared regressions at tokens_per_batch=126. Build the quant reference from torch_ref(...) instead.

Suggested patch
-    # bf16 reference: KB bf16 outputs -> F.linear with random W1 / W2.
-    x_bf16 = x.clone()
-    o1_bf16, o2_bf16 = torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual(
-        x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts, s2_table, s2_ts, h2_table, h2_ts, eps
-    )
+    # Independent bf16 reference from torch_ref.
+    o1_bf16, o2_bf16, _ = torch_ref(
+        x,
+        attn2,
+        s1_table,
+        s1_ts,
+        h1_table,
+        h1_ts,
+        s2_table,
+        s2_ts,
+        h2_table,
+        h2_ts,
+        tokens_per_batch,
+        eps,
+    )
     D1_ref = F.linear(o1_bf16, W1)
     D2_ref = F.linear(o2_bf16, W2)

As per coding guidelines, for tests/** coverage should be sufficient and actionable; independent references are needed to catch real regressions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.py`
around lines 138 - 144, The test currently computes D1_ref/D2_ref by calling the
implementation under test torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual
(referenced as fused_dit_resid_rms_shift_scale_dual), which can mask shared
regressions; replace that use with the independent reference implementation
torch_ref (call torch_ref(x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts,
s2_table, s2_ts, h2_table, h2_ts, eps) or the appropriate torch_ref signature)
to produce o1_bf16/o2_bf16 and then build D1_ref = F.linear(o1_bf16, W1) and
D2_ref = F.linear(o2_bf16, W2); keep all input variables (x_bf16, attn2,
s*_table, s*_ts, h*_table, h*_ts, eps, W1, W2) the same so the test decouples
quant validation from fused_dit_resid_rms_shift_scale_dual.

Source: Coding guidelines

Comment on lines +165 to +168
with pytest.raises(RuntimeError):
torch.ops.trtllm.fused_dit_rmsnorm_shift_scale(
x, scale_table, scale_ts, shift_table, shift_ts, 1e-6
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Narrow the unsupported-dimension assertion to the intended contract.

Line 165 currently accepts any RuntimeError, so unrelated failures can make this test pass. Match the expected unsupported-dimension error text.

Suggested patch
-    with pytest.raises(RuntimeError):
+    with pytest.raises(RuntimeError, match=r"(unsupported|2048|4096)"):
         torch.ops.trtllm.fused_dit_rmsnorm_shift_scale(
             x, scale_table, scale_ts, shift_table, shift_ts, 1e-6
         )

As per coding guidelines, tests/** feedback should be actionable and coverage should clearly validate the intended behavior.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py`
around lines 165 - 168, The test currently broad-catches any RuntimeError when
calling torch.ops.trtllm.fused_dit_rmsnorm_shift_scale; tighten the assertion to
only accept the unsupported-dimension error by using pytest.raises with a match
regex (e.g. pytest.raises(RuntimeError, match=r"unsupported.*dimension") ) so
the test for fused_dit_rmsnorm_shift_scale validates the specific
unsupported-dimension contract rather than any runtime failure.

Source: Coding guidelines

… (table,ts) pair-form modulator

Four fused CUDA kernels cover all per-block AdaLN sites for the LTX-2
transformer (video + audio modalities, 48 blocks per stage). Each kernel
consumes (table_fp32[D], ts_bf16[B,T,D]) modulator pairs and combines
them inline in Phase 0b, instead of taking a pre-added bf16 modulator.

The combine semantics match PyTorch eager _get_all_ada_values: narrow
the fp32 Parameter slice to bf16 first, then bf16 hw add via
__hadd2(__float22bfloat162_rn(table), ts). This folds the broadcast-add
prep work (previously emitted by Inductor as a separate leaf POI kernel
per call site) into the consumer C++ op.

Kernels (file -> op family):
  - fusedDiTRmsNormShiftScaleKernel (KA)
      rms_norm(x) -> (1 + scale) * norm + shift
      2 modulator pairs (scale, shift). 2 rows / CTA, BLOCK_SIZE=256.
  - fusedDiTResidualRmsNormShiftScaleDualKernelA (KB)
      (x + attn2) -> rms_norm -> two parallel (1+s)*norm + h
      4 modulator pairs (2x scale,shift). 1 row / CTA, BLOCK_SIZE=512.
  - fusedDiTGateResidualRmsNormModulateKernel (KC)
      (x + attn * gate) -> rms_norm -> (1 + scale) * norm + shift
      3 modulator pairs (gate, scale, shift). 1 row / CTA, BLOCK_SIZE=512.
  - fusedDiTResidGateRmsNormQuantKernel (KD)
      (x + attn * gate) -> rms_norm [-> NVFP4 quant]
      1 modulator pair (gate). 1 row / CTA, BLOCK_SIZE=512.

Each kernel has a bf16-output variant and an NVFP4-output variant
selected via a <bool DO_QUANT> template parameter. The quant variant
directly emits packed FP4 + 128x4 SWIZZLED FP8 e4m3 scale factors that
the downstream NVFP4 Linear (qkv_proj, ff.up_proj, cross_attn.to_q/to_k)
consumes without re-quantizing.

Op signatures (Tensor schemas in thop):
  - KA: (x, scale_table, scale_ts, shift_table, shift_ts, [sf_scale,] eps)
  - KB: (x, attn2, *(table, ts) x4, [sf_scale1, sf_scale2,] eps)
  - KC: (x, attn, *(table, ts) x3, [sf_scale,] eps)
  - KD: (x, attn_out, gate_table, gate_ts, [sf_scale,] eps)

Python:
  - _get_ada_value_pair helper in transformer_ltx2.py returns a
    (table_slice, ts_slice) view without materializing the broadcast-add.
    Sibling to the existing _get_all_ada_values helper.
  - apply_fused_adaln_modulate / _gate_resid_rms_modulate /
    _resid_rms_dual_shift_scale / _resid_gate_rms_quant wrappers in
    utils_ltx2.py dispatch the bf16 / quant variants based on
    fp4_input_scale.
  - Linear forward (modules/linear.py) accepts Fp4QuantizedTensor input.
  - Attention forward (visual_gen/modules/attention.py) handles
    Fp4QuantizedTensor for qkv_proj input.
  - apply_fused_resid_rms_quant kept as backward-compat alias for
    apply_fused_resid_gate_rms_quant.

Fake op registrations in cpp_custom_ops.py updated to match the new
schemas.

Unit tests (under tests/unittest/_torch/thop/parallel_hw_agnostic/):
  KA: test_fused_dit_rmsnorm_shift_scale.py            28/28 pass
  KB: test_fused_dit_resid_rms_shift_scale_dual_a.py   12/12 pass
  KC: test_fused_dit_gate_resid_rms_modulate.py        12/12 pass
  KD: test_fused_dit_resid_gate_rms_norm_quant.py      13/13 pass
Tolerance: rtol=2e-2, atol=1e-2. Numerics match PyTorch eager
_get_all_ada_values byte-for-byte at the combined modulator boundary.

E2E (1 GPU, CUDA Graph, 121 frames @ 768x1280, 40 steps, B200, 3-run avg):
  fuse_OFF (Inductor eager) : 483.86 ms / step
  fuse_ON  (this commit)    : 481.67 ms / step  (-2.19 ms, -0.45%)

Inductor cache dump verifies the KA/KB/KC/KD modulator prep kernels are
eliminated; remaining POI kernels are unrelated GEMM bias adds and FFN
GELU+FP4 quant pointwise.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…by template

Adds compile-time hybrid TMA / cp.async gating to the unified fusedDiTNormKernel.
On Blackwell (sm_90+), the X tensor (and, for quant+residual templates, the attn
tensor) is loaded via a single cp.async.bulk.shared::cta.global instruction with
mbarrier completion, replacing the 32-issue cp.async loop. The original cp.async
path is preserved as the compile-time fallback for templates where TMA does not
win on paired bench.

Two constexpr gates pick the load path per template instantiation:

  USE_TMA       = (D >= 4096) && !(HAS_QUANT && HAS_MODULATE && NUM_OUT == 1)
  USE_TMA_ATTN  = USE_TMA && HAS_RESIDUAL && HAS_QUANT

Microbench summary (paired vs baseline, no regressions across 32 cells):

  Template         V2 d%      V1 d%
  KA-bf16         -9.0       -9.0     (TMA)
  KA-quant        -0.1       +1.0     (cp.async fallback)
  KB-bf16        -11.0      -11.5     (TMA)
  KB-quant       -19.7      -20.9     (TMA + TMA-attn)
  KC-bf16         -1.2       -0.8     (TMA, marginal)
  KC-quant        +0.1       +0.2     (cp.async fallback)
  KD-bf16         -1.4       -1.1     (TMA, marginal)
  KD-quant        -9.6      -11.5     (TMA + TMA-attn)

Production E2E (cfg=1, 40-step, paired 5 runs/side):

  per_step : 481.4 -> 473.2 ms (-8.2 ms, -1.71%)
  E2E      : 19.26 -> 18.93 s  (-330 ms, -1.71%)

Step-4 nsys (--cuda-graph-trace=node):

  total GPU: 377.85 -> 371.01 ms (-6.84 ms, -1.81%)
  launches : 3617 -> 3395
  KB-quant V1 (heaviest kernel): 8.50 ms / 42 calls / 202 us per call.

Implementation notes:

  - mbarrier in static SMEM (8 B, alignas(8)); compiler elides when USE_TMA=false.
  - When USE_TMA_ATTN, smem_attn slot follows smem_x; warp_sums offset shifts.
    Host launcher dyn_smem includes the extra slot via matching constexpr.
  - Phase 0c reads attn from smem_attn (USE_TMA_ATTN) or stays as direct LDG.
  - Wait site gated: mbarrier.try_wait.parity vs __pipeline_wait_prior(0).
  - No new fields in AdaLNNormParams; non-tensor cp.async.bulk variant means
    no cuTensorMapEncodeTiled host overhead.
  - All 8 template instantiations rebuilt; ptxas reports 0 spill stores/loads.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
… per-slot mod helper

cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu:
  Add fusedDiTNormKernelPipelined: multi-row CTA (R_CTA=4) with circular
  SMEM stages and 1 producer / 8 consumer warp specialization. Producer
  issues cp.async.bulk for X (and attn when HAS_RESIDUAL); consumers wait
  on full_bar, compute x_new + RMSNorm + modulate + write, then arrive
  empty_bar. NUM_STAGES=3 for bf16, NUM_STAGES=2 for quant. Inline PTX
  helpers (mbar_init/arrive/wait, cp_async_bulk, bar_sync_consumer) are
  hoisted out of the existing kernel for reuse. Gated by env var
  TLLM_DIT_NORM_PIPELINED on D=4096 + KA/KB/KC variants; KD falls through
  to the default path. All other shapes/variants unchanged.

cpp/tensorrt_llm/kernels/fusedDiTNormKernel.h:
  Trim NCU sweep narrative in the tile-selection comment.

tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py:
  Restore _get_ada_values(table, batch_size, timestep, indices: slice)
  signature so MSA (slice 0..3) / MLP (slice 3..6) / gate (slice 5..6)
  fetch independently. Add pair-form companion
  _get_ada_table_ts_pairs(..., indices: slice) returning one
  (table_slice, ts_slice) per slot. Collapse 22 per-slot call sites
  across MSA video / MSA audio / AV CA video / AV CA audio / MLP video /
  MLP audio into 6 slice-based calls. Drop dead pre-fetch of all 6
  combined-form values at the top of each run_*x branch.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…drop unused alias

- Remove backward-compat alias ``apply_fused_resid_rms_quant`` (zero callers).
- Strip references to renamed/dropped helpers (``_get_ada_value_pair``,
  ``_get_all_ada_values``) and the "previously pre-added bf16 modulator" /
  "Inductor leaf POI prep kernel" narrative.
- Condense each eager-fallback comment to a single line that points at the
  one autotuner-bug note in ``apply_fused_adaln_modulate``.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…strip K* names

- Auto-dispatch in launchFusedDiTNorm: replace TLLM_DIT_NORM_PIPELINED env-var
  gate with a pure shape/template predicate. One Python op entry point; C++
  decides between default and pipelined kernels per (hidden_dim, num_tokens,
  tokens_per_batch, variant flags).
- transformer_ltx2: drop defer_text_residual / defer_a2v_into_ka /
  defer_v2a_into_ka flags. Producers always park attn outputs in *_attn_raw
  slots; consumers naturally bifurcate on `is not None`; per-batch perturbation
  masks are pre-multiplied onto attn before deferring (mask is per-batch,
  gate is per-feature -- commutative). A single av_ca_runs flag gates a
  cleanup at the AV CA exit for single-modality models where the consumer
  block doesn't run.
- Add _get_av_ca_ada_table_ts_pairs as pair-form companion to
  _get_av_ca_ada_values; refactor the two KB-site callers to use it instead
  of `_get_ada_table_ts_pairs(table[:4], slice(0, 4))`.
- Strip all KA/KB/KC/KD dev-stage names from code comments and test files;
  describe each kernel by its operation instead.
- Add tokens_per_batch=128 cells to KA/KB/KC quant unit tests so the
  auto-dispatched pipelined kernel gets CI numerical coverage on the quant
  path (existing tokens_per_batch=126 cells hit the default tile).
- Rename v_ok / a_ok -> video_supports_fused_adaln / audio_supports_fused_adaln;
  hoist `import os` to file head; refresh the TLLM_DISABLE_FUSE_ADALN
  comment to describe its intentional A/B benchmarking role.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
@luyiyun1021 luyiyun1021 force-pushed the ltx2-adaln-fusion branch 3 times, most recently from 8facf8f to 1bca0a6 Compare June 10, 2026 09:19
…+ FFN callsite

Extend the fused-DiT kernel family with a gate-residual-only variant for the
FFN write-back site (`vx = vx + ff(vx_scaled) * gate`):

- Add HAS_NORM template flag to fusedDiTGateResidNormModulateKernel; the
  HAS_NORM=false instantiation skips the rms reduce + Phase 2 modulate paths
  and returns after Phase 0d's in-place x_new writeback.
- New thop op `fused_dit_gate_resid` (cpp/tensorrt_llm/thop/fusedDiTGateResidOp.cpp)
  that launches the HAS_NORM=false specialization.
- New Python wrapper `apply_fused_gate_resid` in utils_ltx2.py with a
  torch.compile-safe eager fallback for fuse=False.
- Wire FFN gate-residual callsites in transformer_ltx2.py (video + audio FFN)
  to the new wrapper.

Naming consistency:
- Rename `apply_fused_adaln_modulate` -> `apply_fused_rms_shift_scale` and
  `apply_fused_resid_gate` -> `apply_fused_gate_resid` so all wrapper names
  read in op-execution order (gate->resid->rms->shift_scale).

Cleanups:
- Strip the TLLM_DIT_NORM_FORCE_DEFAULT and TLLM_DISABLE_FUSE_ADALN bench-only
  env vars; auto-dispatch / shape-based fuse is the only path now.
- Drop the now-unused <cstdlib> include and `import os`.
- Move the combined-form `_get_av_ca_ada_values` (AV cross-attn modulators)
  into the eager `else` branch where its outputs are actually consumed. In the
  fused path its scale/shift are recomputed via pair-form and its gate output
  was already dead, so the unconditional call only emitted a redundant
  torch.compile prep kernel (~9.7us/step) -- now eliminated.

E2E (LTX-2 single-stage, audio+video, guidance_scale=1.0, 768x1280x121, 40
steps, B200, 5-run avg): denoise per-step 343.7ms -> 329.6ms (-14.1ms, -4.09%),
denoise loop 13.746s -> 13.184s; full E2E (incl VAE) 18.955s -> 18.282s
(-3.55%). Per-kernel quant microbench: KA/KB/KC/KD 1.55-1.85x vs torch.compile
(fused inline FP4). gate_resid (bf16) ties with compile at memory-bound peak.

77/77 unit tests pass (tests/unittest/_torch/thop/parallel_hw_agnostic/).

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
@luyiyun1021 luyiyun1021 changed the title [TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation + Quant kernels [TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation(ShiftScale) + Quant kernels Jun 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant