[TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation(ShiftScale) + Quant kernels#15102
[TRTLLM-13265][feat] Fuse LTX-2 Gate + Residual + Norm + AdaLN modulation(ShiftScale) + Quant kernels#15102luyiyun1021 wants to merge 6 commits into
Conversation
📝 WalkthroughWalkthroughAdds fused CUDA kernels and a templated launcher for DiT-style RMSNorm (BF16 with optional residual/gate/modulate) plus NVFP4 FP4+SF output, exposes them as multiple Torch ops, integrates fused dispatch into LTX-2 model code and NVFP4 linear/attention paths, and adds fake registrations and unit tests. ChangesFused DiT RMSNorm kernels and model integration
🎯 4 (Complex) | ⏱️ ~60 minutes Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py (1)
18-402: 🛠️ Refactor suggestion | 🟠 Major | 🏗️ Heavy liftMove these TRT-LLM-specific fused helpers out of
ltx2_core.This file sits in the upstream-mirrored LTX-2 subtree, but the new surface here is TensorRT-LLM-specific custom-op dispatch, NVFP4 handling, and fusion policy. Keeping that logic in
ltx2_corewill make future upstream syncs much harder; please move it into a TensorRT-LLM-owned adapter module and keep this subtree close to upstream. Based on learnings: files undertensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/are ported directly from the upstream Lightricks LTX-2 repository and should remain faithful to upstream.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py` around lines 18 - 402, This code introduces TensorRT-LLM-specific fused kernels and NVFP4 handling (e.g. get_nvfp4_input_scale, apply_fused_adaln_modulate, apply_fused_gate_resid_rms_modulate, apply_fused_resid_gate_rms_quant, apply_fused_resid_rms_dual_shift_scale and Fp4QuantizedTensor usage) inside the upstream-mirrored ltx2_core subtree; extract these TRT-LLM-specific helpers into a new TensorRT-LLM owned adapter module, move all TensorRT-only ops/quant logic and NVFP4 helpers there, update imports in this file to call into the adapter (or provide thin shim wrappers delegating to the new module), and ensure the original ltx2_core file retains only upstream-faithful logic so upstream syncs remain clean.Source: Learnings
🧹 Nitpick comments (2)
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)
468-473: ⚡ Quick winDrop the benchmark-only env override before merge.
TLLM_DISABLE_FUSE_ADALNchanges execution behavior through an undocumented process env var, and the comment already marks it as temporary. Leaving it in both constructors bakes a hidden debug switch into production code; either remove it or promote it to a documented config/benchmark knob.Also applies to: 1297-1301
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines 468 - 473, Remove the temporary benchmark-only environment override that toggles self._fuse_adaln (the block importing os as _os_bench and checking TLLM_DISABLE_FUSE_ADALN) from the transformer_ltx2 constructors; locate the same pattern around the second constructor (the block at ~1297-1301) and delete it or replace it with a proper, documented configuration parameter (e.g., a constructor arg or a config object) that explicitly controls fuse_adaln instead of relying on an undocumented process env var. Ensure the final behavior uses the explicit config/constructor value to set self._fuse_adaln and remove any leftover _os_bench import or env checks.tensorrt_llm/_torch/visual_gen/modules/attention.py (1)
479-484: ⚡ Quick winRestore an explicit type annotation for
hidden_states.The signature dropped type information. Please annotate it as
torch.Tensor | Fp4QuantizedTensor(or equivalent) to keep static typing intact.Suggested fix
def forward( self, - hidden_states, + hidden_states: torch.Tensor | Fp4QuantizedTensor, encoder_hidden_states: Optional[torch.Tensor] = None, freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor:As per coding guidelines, "Always annotate functions."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/modules/attention.py` around lines 479 - 484, The forward method lost an explicit type for hidden_states—restore it by annotating the parameter as a union type (e.g. hidden_states: torch.Tensor | Fp4QuantizedTensor) so static typing is preserved; update the forward signature in the attention module (function forward) to use that union for hidden_states while leaving encoder_hidden_states and freqs as-is and ensure any imports or forward references for Fp4QuantizedTensor are added or aliased if necessary.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu`:
- Around line 958-1046: The launcher dispatches HAS_QUANT specializations
unconditionally but those kernels require runtime SM >= 10; add a runtime
compute-capability guard at the top of launchFusedDiTNorm that queries
cudaGetDeviceProperties (or cudaGetDevice) and, when HAS_QUANT is true and
deviceProp.major < 10 (or equivalent SM check), fail fast (TLLM_THROW or
TLLM_CHECK_WITH_INFO) rather than launching; apply the same check before
dispatching both the pipelined path (fusedDiTNormKernelPipelined) and the
regular fusedDiTNormKernel launches so quantized specializations are never
launched on unsupported SMs.
In `@tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py`:
- Around line 1154-1236: The fake/meta registrations for the quantized ops were
added but the bf16 op schemas are missing, causing FakeTensor tracing to fail
when apply_fused_adaln_modulate, apply_fused_gate_resid_rms_modulate, or
apply_fused_resid_rms_dual_shift_scale dispatch the bf16 paths; add
`@torch.library.register_fake` handlers for the bf16 op names
trtllm::fused_dit_rmsnorm_shift_scale,
trtllm::fused_dit_gate_resid_rms_modulate, and
trtllm::fused_dit_resid_rms_shift_scale_dual that mirror the non-quantized
signatures used by those functions and return tensors with the same
shapes/dtypes as the inputs (e.g., torch.empty_like(x) or the appropriate tuple
of empty_like tensors matching the eager op outputs), so FakeTensor/tracing can
succeed.
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 459-467: The _fuse_adaln flag currently only checks hidden dims
via is_fused_adaln_supported_dim(video.dim)/audio.dim; update the guard to also
require bf16 dtype so non-bf16 configs don't take the fused path. Import torch
if needed and change the v_ok/a_ok conditions in transformer_ltx2 (where
_fuse_adaln is set) to require video is None or (video.dtype == torch.bfloat16
and is_fused_adaln_supported_dim(video.dim)) and similarly for audio, then set
self._fuse_adaln = v_ok and a_ok.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.py`:
- Around line 123-131: The test currently computes D_ref by calling
torch.ops.trtllm.fused_dit_gate_resid_rms_modulate and then F.linear, coupling
the quant-path validation to the bf16 fused op; change the baseline to use an
independent reference implementation by calling torch_ref(...) (the standalone
reference kernel) with the same inputs instead of out_bf16, and assign its
result to D_ref so the quant test validates against torch_ref rather than the
fused op; update the D_ref assignment in the test block that currently uses
out_bf16 to call torch_ref with x, attn, gate_table, gate_ts, scale_table,
scale_ts, shift_table, shift_ts, eps (or the appropriate torch_ref signature)
and then apply F.linear with W as before.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.py`:
- Around line 112-116: The test should fail only for the unsupported-dimension
error: change sf_scale from a one-element vector to a scalar tensor (e.g.,
torch.tensor(1.0, dtype=torch.float32, device=device)) so it cannot cause a
different failure, and tighten the assertion on
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant by using
pytest.raises(RuntimeError, match=r"unsupported.*dimension") (or the exact
unsupported-dimension message) to ensure the test only passes when that specific
error is raised.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.py`:
- Around line 138-144: The test currently computes D1_ref/D2_ref by calling the
implementation under test torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual
(referenced as fused_dit_resid_rms_shift_scale_dual), which can mask shared
regressions; replace that use with the independent reference implementation
torch_ref (call torch_ref(x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts,
s2_table, s2_ts, h2_table, h2_ts, eps) or the appropriate torch_ref signature)
to produce o1_bf16/o2_bf16 and then build D1_ref = F.linear(o1_bf16, W1) and
D2_ref = F.linear(o2_bf16, W2); keep all input variables (x_bf16, attn2,
s*_table, s*_ts, h*_table, h*_ts, eps, W1, W2) the same so the test decouples
quant validation from fused_dit_resid_rms_shift_scale_dual.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py`:
- Around line 165-168: The test currently broad-catches any RuntimeError when
calling torch.ops.trtllm.fused_dit_rmsnorm_shift_scale; tighten the assertion to
only accept the unsupported-dimension error by using pytest.raises with a match
regex (e.g. pytest.raises(RuntimeError, match=r"unsupported.*dimension") ) so
the test for fused_dit_rmsnorm_shift_scale validates the specific
unsupported-dimension contract rather than any runtime failure.
---
Outside diff comments:
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.py`:
- Around line 18-402: This code introduces TensorRT-LLM-specific fused kernels
and NVFP4 handling (e.g. get_nvfp4_input_scale, apply_fused_adaln_modulate,
apply_fused_gate_resid_rms_modulate, apply_fused_resid_gate_rms_quant,
apply_fused_resid_rms_dual_shift_scale and Fp4QuantizedTensor usage) inside the
upstream-mirrored ltx2_core subtree; extract these TRT-LLM-specific helpers into
a new TensorRT-LLM owned adapter module, move all TensorRT-only ops/quant logic
and NVFP4 helpers there, update imports in this file to call into the adapter
(or provide thin shim wrappers delegating to the new module), and ensure the
original ltx2_core file retains only upstream-faithful logic so upstream syncs
remain clean.
---
Nitpick comments:
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 468-473: Remove the temporary benchmark-only environment override
that toggles self._fuse_adaln (the block importing os as _os_bench and checking
TLLM_DISABLE_FUSE_ADALN) from the transformer_ltx2 constructors; locate the same
pattern around the second constructor (the block at ~1297-1301) and delete it or
replace it with a proper, documented configuration parameter (e.g., a
constructor arg or a config object) that explicitly controls fuse_adaln instead
of relying on an undocumented process env var. Ensure the final behavior uses
the explicit config/constructor value to set self._fuse_adaln and remove any
leftover _os_bench import or env checks.
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 479-484: The forward method lost an explicit type for
hidden_states—restore it by annotating the parameter as a union type (e.g.
hidden_states: torch.Tensor | Fp4QuantizedTensor) so static typing is preserved;
update the forward signature in the attention module (function forward) to use
that union for hidden_states while leaving encoder_hidden_states and freqs as-is
and ensure any imports or forward references for Fp4QuantizedTensor are added or
aliased if necessary.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 74ebb7eb-0878-4102-8b71-70ca7fff551d
📒 Files selected for processing (16)
cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cucpp/tensorrt_llm/kernels/fusedDiTNormKernel.hcpp/tensorrt_llm/thop/CMakeLists.txtcpp/tensorrt_llm/thop/fusedDiTGateResidualRmsNormModulateOp.cppcpp/tensorrt_llm/thop/fusedDiTResidGateRmsNormOp.cppcpp/tensorrt_llm/thop/fusedDiTResidualRmsNormShiftScaleDualOp.cppcpp/tensorrt_llm/thop/fusedDiTRmsNormShiftScaleOp.cpptensorrt_llm/_torch/custom_ops/cpp_custom_ops.pytensorrt_llm/_torch/modules/linear.pytensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/utils_ltx2.pytensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.pytensorrt_llm/_torch/visual_gen/modules/attention.pytests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.pytests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.pytests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.pytests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py
| template <bool HAS_RESIDUAL, bool HAS_GATE, bool HAS_MODULATE, int NUM_OUT, bool HAS_QUANT> | ||
| void launchFusedDiTNorm(AdaLNNormParams const& params, int hidden_dim, cudaStream_t stream) | ||
| { | ||
| TLLM_CHECK_WITH_INFO(params.num_tokens >= 1, "num_tokens must be >= 1, got %d", params.num_tokens); | ||
| TLLM_CHECK_WITH_INFO( | ||
| params.tokens_per_batch >= 1, "tokens_per_batch must be >= 1, got %d", params.tokens_per_batch); | ||
| TLLM_CHECK_WITH_INFO(params.num_tokens % params.tokens_per_batch == 0, | ||
| "num_tokens (%d) must be divisible by tokens_per_batch (%d)", params.num_tokens, params.tokens_per_batch); | ||
|
|
||
| // Production tile selected via NCU sweep at D=4096, B=1, T=15360: 1-row/CTA, 256 threads/CTA | ||
| // gives 2 CTAs/SM (regs <= 80) and the highest per-element bench rate among {2r256, 1r256, 1r512}. | ||
| constexpr int ROWS_PER_BLOCK = 1; | ||
| constexpr int BLOCK_SIZE = 256; | ||
| constexpr int WARPS_PER_ROW = (BLOCK_SIZE / ROWS_PER_BLOCK) / 32; | ||
|
|
||
| cudaLaunchConfig_t cfg = {}; | ||
| cfg.stream = stream; | ||
| cudaLaunchAttribute attrs[1] = {}; | ||
| attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||
| attrs[0].val.programmaticStreamSerializationAllowed = 1; | ||
| cfg.attrs = attrs; | ||
| cfg.numAttrs = 1; | ||
|
|
||
| // Optional pipelined dispatch (multi-row CTA + circular TMA stages + warp specialization). | ||
| // Enabled by env var TLLM_DIT_NORM_PIPELINED=1. Fires only when: | ||
| // - hidden_dim == 4096 (video path; audio D=2048 has too small a grid to amortize the | ||
| // warp-specialization overhead and bench-regresses) | ||
| // - variant is KA/KB/KC (KD's HAS_GATE && !HAS_MODULATE is HBM-bound and bench-regresses) | ||
| // - tokens_per_batch >= 4 && % 4 == 0 (R_CTA=4 rows share batchIdx for per-CTA mod cache) | ||
| // Otherwise falls through to the default path below. | ||
| static bool const pipelined_enabled = []() | ||
| { | ||
| char const* env = std::getenv("TLLM_DIT_NORM_PIPELINED"); | ||
| return env != nullptr && env[0] != '\0' && env[0] != '0'; | ||
| }(); | ||
| constexpr bool kPipelinedVariantOK = HAS_MODULATE || !HAS_GATE; | ||
| if constexpr (kPipelinedVariantOK) | ||
| { | ||
| // NUM_STAGES chosen per HAS_QUANT: bf16 path is HBM-latency-bound so deeper pipeline | ||
| // (3 stages) helps; quant Phase 2 is compute-heavy and extra stages just add SMEM | ||
| // pressure without commensurate latency hiding. | ||
| constexpr int PIPE_NUM_STAGES = HAS_QUANT ? 2 : 3; | ||
| constexpr int PIPE_BLOCK_SIZE = 288; | ||
| constexpr int PIPE_D = 4096; | ||
| constexpr int PIPE_R_CTA = 4; | ||
| if (pipelined_enabled && hidden_dim == PIPE_D && (params.num_tokens % PIPE_R_CTA == 0) | ||
| && (params.tokens_per_batch >= PIPE_R_CTA) && (params.tokens_per_batch % PIPE_R_CTA == 0)) | ||
| { | ||
| constexpr int pipe_stage_elems = HAS_RESIDUAL ? (2 * PIPE_D) : PIPE_D; | ||
| size_t const pipe_smem_bytes | ||
| = static_cast<size_t>(PIPE_NUM_STAGES) * pipe_stage_elems * sizeof(__nv_bfloat16) | ||
| + 2 * PIPE_NUM_STAGES * sizeof(uint64_t) + 8 * sizeof(float) + 16; | ||
| cudaLaunchConfig_t pipe_cfg = cfg; | ||
| pipe_cfg.gridDim = dim3((params.num_tokens + PIPE_R_CTA - 1) / PIPE_R_CTA); | ||
| pipe_cfg.blockDim = dim3(PIPE_BLOCK_SIZE); | ||
| pipe_cfg.dynamicSmemBytes = static_cast<int>(pipe_smem_bytes); | ||
| auto* pipe_kp = fusedDiTNormKernelPipelined<PIPE_D, PIPE_R_CTA, PIPE_NUM_STAGES, HAS_RESIDUAL, HAS_GATE, | ||
| HAS_MODULATE, NUM_OUT, HAS_QUANT>; | ||
| cudaFuncSetAttribute( | ||
| pipe_kp, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(pipe_smem_bytes)); | ||
| cudaLaunchKernelEx(&pipe_cfg, pipe_kp, params); | ||
| return; | ||
| } | ||
| } | ||
| (void) pipelined_enabled; | ||
|
|
||
| #define LAUNCH(D_VAL) \ | ||
| do \ | ||
| { \ | ||
| constexpr bool USE_TMA_HOST = (D_VAL >= 4096) && !(HAS_QUANT && HAS_MODULATE && NUM_OUT == 1); \ | ||
| constexpr bool USE_TMA_ATTN_HOST = USE_TMA_HOST && HAS_RESIDUAL && HAS_QUANT; \ | ||
| int const attn_extra_bytes \ | ||
| = USE_TMA_ATTN_HOST ? (ROWS_PER_BLOCK * D_VAL * static_cast<int>(sizeof(__nv_bfloat16))) : 0; \ | ||
| cfg.gridDim = dim3((params.num_tokens + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK); \ | ||
| cfg.blockDim = dim3(BLOCK_SIZE); \ | ||
| cfg.dynamicSmemBytes = ROWS_PER_BLOCK * D_VAL * static_cast<int>(sizeof(__nv_bfloat16)) + attn_extra_bytes \ | ||
| + ROWS_PER_BLOCK * WARPS_PER_ROW * static_cast<int>(sizeof(float)); \ | ||
| cudaLaunchKernelEx(&cfg, \ | ||
| fusedDiTNormKernel<D_VAL, ROWS_PER_BLOCK, BLOCK_SIZE, HAS_RESIDUAL, HAS_GATE, HAS_MODULATE, NUM_OUT, \ | ||
| HAS_QUANT>, \ | ||
| params); \ | ||
| } while (0) | ||
|
|
||
| switch (hidden_dim) | ||
| { | ||
| case 2048: LAUNCH(2048); break; | ||
| case 4096: LAUNCH(4096); break; | ||
| default: TLLM_THROW("Unsupported hidden_dim for fusedDiTNorm: %d (only 2048, 4096)", hidden_dim); | ||
| } |
There was a problem hiding this comment.
Add a runtime SM guard before launching quantized specializations.
HAS_QUANT kernels only emit FP4/SF stores in __CUDA_ARCH__ >= 1000 blocks, but the launcher currently dispatches them unconditionally. On lower-SM runtime targets this can return undefined outputs instead of failing fast.
Suggested fix
template <bool HAS_RESIDUAL, bool HAS_GATE, bool HAS_MODULATE, int NUM_OUT, bool HAS_QUANT>
void launchFusedDiTNorm(AdaLNNormParams const& params, int hidden_dim, cudaStream_t stream)
{
+ if constexpr (HAS_QUANT)
+ {
+ int dev = 0;
+ int cc_major = 0;
+ cudaGetDevice(&dev);
+ cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, dev);
+ TLLM_CHECK_WITH_INFO(
+ cc_major >= 10,
+ "fusedDiTNorm quant variants require SM100+; got compute capability major %d",
+ cc_major);
+ }
+
TLLM_CHECK_WITH_INFO(params.num_tokens >= 1, "num_tokens must be >= 1, got %d", params.num_tokens);🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu` around lines 958 - 1046, The
launcher dispatches HAS_QUANT specializations unconditionally but those kernels
require runtime SM >= 10; add a runtime compute-capability guard at the top of
launchFusedDiTNorm that queries cudaGetDeviceProperties (or cudaGetDevice) and,
when HAS_QUANT is true and deviceProp.major < 10 (or equivalent SM check), fail
fast (TLLM_THROW or TLLM_CHECK_WITH_INFO) rather than launching; apply the same
check before dispatching both the pipelined path (fusedDiTNormKernelPipelined)
and the regular fusedDiTNormKernel launches so quantized specializations are
never launched on unsupported SMs.
| @torch.library.register_fake("trtllm::fused_dit_rmsnorm_shift_scale_quant") | ||
| def _(x, scale_table, scale_ts, shift_table, shift_ts, sf_scale, eps): | ||
| """Fake/meta for fused KA + NVFP4 quant. SF layout is SWIZZLED 128x4. | ||
|
|
||
| Modulator is built inline from (table, ts) pairs: | ||
| scale[d] = scale_table[d] + scale_ts[b, d] | ||
| Folds the upstream broadcast-add Triton prep kernel into the C++ op. | ||
| """ | ||
| D = x.shape[-1] | ||
| M = 1 | ||
| for d in x.shape[:-1]: | ||
| M *= d | ||
| _, scale_shape = fp4_utils.get_fp4_shape((M, D), | ||
| 16, | ||
| is_swizzled_layout=True) | ||
| out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8) | ||
| out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8) | ||
| return out_fp4, out_sf | ||
|
|
||
| @torch.library.register_fake( | ||
| "trtllm::fused_dit_gate_resid_rms_modulate_quant") | ||
| def _(x, attn_out, gate_table, gate_ts, scale_table, scale_ts, shift_table, | ||
| shift_ts, sf_scale, eps): | ||
| """Fake/meta for fused KC + NVFP4 quant. SF layout is SWIZZLED 128x4. | ||
|
|
||
| Modulators built inline from (table, ts) pairs; folds upstream | ||
| broadcast-add Triton prep kernel into the C++ op. | ||
| """ | ||
| D = x.shape[-1] | ||
| M = 1 | ||
| for d in x.shape[:-1]: | ||
| M *= d | ||
| _, scale_shape = fp4_utils.get_fp4_shape((M, D), | ||
| 16, | ||
| is_swizzled_layout=True) | ||
| out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8) | ||
| out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8) | ||
| return out_fp4, out_sf | ||
|
|
||
| @torch.library.register_fake("trtllm::fused_dit_resid_gate_rms_norm") | ||
| def _(x, attn_out, gate_table, gate_ts, eps): | ||
| """Fake/meta for fused KD bf16 (residual_add + gate_mul + rms_norm). | ||
| Gate built inline from (table, ts) pair -- folds the upstream | ||
| broadcast-add Triton prep + `attn * gate` mul into Phase 0b.""" | ||
| return torch.empty_like(x) | ||
|
|
||
| @torch.library.register_fake("trtllm::fused_dit_resid_gate_rms_norm_quant") | ||
| def _(x, attn_out, gate_table, gate_ts, sf_scale, eps): | ||
| """Fake/meta for fused KD (residual_add + gate_mul + rms_norm + NVFP4 quant). | ||
| Gate built inline from (table, ts) pair. SF layout is SWIZZLED 128x4.""" | ||
| D = x.shape[-1] | ||
| M = 1 | ||
| for d in x.shape[:-1]: | ||
| M *= d | ||
| _, scale_shape = fp4_utils.get_fp4_shape((M, D), | ||
| 16, | ||
| is_swizzled_layout=True) | ||
| out_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8) | ||
| out_sf = x.new_empty((scale_shape, ), dtype=torch.uint8) | ||
| return out_fp4, out_sf | ||
|
|
||
| @torch.library.register_fake( | ||
| "trtllm::fused_dit_resid_rms_shift_scale_dual_quant") | ||
| def _(x, attn2_out, scale_dir1_table, scale_dir1_ts, shift_dir1_table, | ||
| shift_dir1_ts, scale_dir2_table, scale_dir2_ts, shift_dir2_table, | ||
| shift_dir2_ts, sf_scale1, sf_scale2, eps): | ||
| """Fake/meta for fused KB + dual NVFP4 quant. SF layout is SWIZZLED 128x4. | ||
|
|
||
| 4 modulators built inline from (table, ts) pairs; folds upstream | ||
| broadcast-add Triton prep kernel into the C++ op. | ||
| """ | ||
| D = x.shape[-1] | ||
| M = 1 | ||
| for d in x.shape[:-1]: | ||
| M *= d | ||
| _, scale_shape = fp4_utils.get_fp4_shape((M, D), | ||
| 16, | ||
| is_swizzled_layout=True) | ||
| out1_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8) | ||
| out1_sf = x.new_empty((scale_shape, ), dtype=torch.uint8) | ||
| out2_fp4 = x.new_empty((M, D // 2), dtype=torch.uint8) | ||
| out2_sf = x.new_empty((scale_shape, ), dtype=torch.uint8) | ||
| return out1_fp4, out1_sf, out2_fp4, out2_sf |
There was a problem hiding this comment.
Register fake/meta kernels for the bf16 variants too.
These additions cover the quantized ops plus KD bf16, but apply_fused_adaln_modulate(), apply_fused_gate_resid_rms_modulate(), and apply_fused_resid_rms_dual_shift_scale() still dispatch to the bf16 op names trtllm::fused_dit_rmsnorm_shift_scale, trtllm::fused_dit_gate_resid_rms_modulate, and trtllm::fused_dit_resid_rms_shift_scale_dual when the downstream linear is not NVFP4-quantized. Without fake registrations for those schemas, torch.compile / FakeTensor tracing will fail on the bf16 path even though eager execution works.
Suggested fake registrations
+ `@torch.library.register_fake`("trtllm::fused_dit_rmsnorm_shift_scale")
+ def _(x, scale_table, scale_ts, shift_table, shift_ts, eps):
+ return torch.empty_like(x)
+
+ `@torch.library.register_fake`("trtllm::fused_dit_gate_resid_rms_modulate")
+ def _(x, attn_out, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps):
+ return torch.empty_like(x)
+
+ `@torch.library.register_fake`("trtllm::fused_dit_resid_rms_shift_scale_dual")
+ def _(
+ x,
+ attn2_out,
+ scale_dir1_table,
+ scale_dir1_ts,
+ shift_dir1_table,
+ shift_dir1_ts,
+ scale_dir2_table,
+ scale_dir2_ts,
+ shift_dir2_table,
+ shift_dir2_ts,
+ eps,
+ ):
+ return torch.empty_like(x), torch.empty_like(x)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py` around lines 1154 - 1236,
The fake/meta registrations for the quantized ops were added but the bf16 op
schemas are missing, causing FakeTensor tracing to fail when
apply_fused_adaln_modulate, apply_fused_gate_resid_rms_modulate, or
apply_fused_resid_rms_dual_shift_scale dispatch the bf16 paths; add
`@torch.library.register_fake` handlers for the bf16 op names
trtllm::fused_dit_rmsnorm_shift_scale,
trtllm::fused_dit_gate_resid_rms_modulate, and
trtllm::fused_dit_resid_rms_shift_scale_dual that mirror the non-quantized
signatures used by those functions and return tensors with the same
shapes/dtypes as the inputs (e.g., torch.empty_like(x) or the appropriate tuple
of empty_like tensors matching the eager op outputs), so FakeTensor/tracing can
succeed.
| # Whether to dispatch AdaLN modulation to the fused CUDA kernels. Resolved | ||
| # once at construction; call sites just consult the flag. The kernels are | ||
| # bf16 + hidden_dim in {2048, 4096}; non-matching cases raise at the C++ | ||
| # boundary -- this flag is the only Python-side guard. | ||
| from .ltx2_core.utils_ltx2 import is_fused_adaln_supported_dim | ||
|
|
||
| v_ok = video is None or is_fused_adaln_supported_dim(video.dim) | ||
| a_ok = audio is None or is_fused_adaln_supported_dim(audio.dim) | ||
| self._fuse_adaln = v_ok and a_ok |
There was a problem hiding this comment.
Gate _fuse_adaln on bf16 as well as hidden size.
All four new fused thop entrypoints hard-check torch.bfloat16, but _fuse_adaln only checks the hidden dims. A float16/float32 LTX-2 config with a supported inner dim will now take the fused path and fail on the first custom-op call instead of cleanly staying on the eager path.
Also applies to: 1288-1296
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines
459 - 467, The _fuse_adaln flag currently only checks hidden dims via
is_fused_adaln_supported_dim(video.dim)/audio.dim; update the guard to also
require bf16 dtype so non-bf16 configs don't take the fused path. Import torch
if needed and change the v_ok/a_ok conditions in transformer_ltx2 (where
_fuse_adaln is set) to require video is None or (video.dtype == torch.bfloat16
and is_fused_adaln_supported_dim(video.dim)) and similarly for audio, then set
self._fuse_adaln = v_ok and a_ok.
| # bf16 reference: compute KC's bf16 output, then F.linear with random W. | ||
| torch.manual_seed(13) | ||
| W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05 | ||
|
|
||
| x_bf16 = x.clone() | ||
| out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate( | ||
| x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps | ||
| ) | ||
| D_ref = F.linear(out_bf16, W) |
There was a problem hiding this comment.
Use an independent reference in the quant test path.
Line 128 derives D_ref from the bf16 fused op, so quant validation is coupled to the same implementation family. That weakens detection of shared regressions at tokens_per_batch=126 (a shape not covered by the bf16 test). Use torch_ref(...) for the baseline here.
Suggested patch
- x_bf16 = x.clone()
- out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate(
- x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps
- )
+ out_bf16, _ = torch_ref(
+ x,
+ attn,
+ gate_table,
+ gate_ts,
+ scale_table,
+ scale_ts,
+ shift_table,
+ shift_ts,
+ tokens_per_batch,
+ eps,
+ )
D_ref = F.linear(out_bf16, W)As per coding guidelines, tests/** should provide sufficient and actionable coverage; this change prevents false confidence from coupled references.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # bf16 reference: compute KC's bf16 output, then F.linear with random W. | |
| torch.manual_seed(13) | |
| W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05 | |
| x_bf16 = x.clone() | |
| out_bf16 = torch.ops.trtllm.fused_dit_gate_resid_rms_modulate( | |
| x_bf16, attn, gate_table, gate_ts, scale_table, scale_ts, shift_table, shift_ts, eps | |
| ) | |
| D_ref = F.linear(out_bf16, W) | |
| # bf16 reference: compute KC's bf16 output, then F.linear with random W. | |
| torch.manual_seed(13) | |
| W = torch.randn(out_dim, hidden_dim, dtype=torch.bfloat16, device="cuda") * 0.05 | |
| out_bf16, _ = torch_ref( | |
| x, | |
| attn, | |
| gate_table, | |
| gate_ts, | |
| scale_table, | |
| scale_ts, | |
| shift_table, | |
| shift_ts, | |
| tokens_per_batch, | |
| eps, | |
| ) | |
| D_ref = F.linear(out_bf16, W) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_gate_resid_rms_modulate.py`
around lines 123 - 131, The test currently computes D_ref by calling
torch.ops.trtllm.fused_dit_gate_resid_rms_modulate and then F.linear, coupling
the quant-path validation to the bf16 fused op; change the baseline to use an
independent reference implementation by calling torch_ref(...) (the standalone
reference kernel) with the same inputs instead of out_bf16, and assign its
result to D_ref so the quant test validates against torch_ref rather than the
fused op; update the D_ref assignment in the test block that currently uses
out_bf16 to call torch_ref with x, attn, gate_table, gate_ts, scale_table,
scale_ts, shift_table, shift_ts, eps (or the appropriate torch_ref signature)
and then apply F.linear with W as before.
Source: Coding guidelines
| sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device) | ||
| with pytest.raises(RuntimeError): | ||
| torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant( | ||
| x, attn, gate_table, gate_ts, sf_scale, 1e-6 | ||
| ) |
There was a problem hiding this comment.
Make the unsupported-dim test fail only for the intended reason.
Line 113 uses a broad RuntimeError assertion, and Line 112 uses a [1] tensor for sf_scale, which can introduce alternate failure causes. Use a scalar sf_scale and match the unsupported-dimension message.
Suggested patch
- sf_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
- with pytest.raises(RuntimeError):
+ sf_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
+ with pytest.raises(RuntimeError, match=r"(unsupported|2048|4096)"):
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant(
x, attn, gate_table, gate_ts, sf_scale, 1e-6
)As per coding guidelines, tests/** should keep coverage actionable and targeted to the contract being validated.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_gate_rms_norm_quant.py`
around lines 112 - 116, The test should fail only for the unsupported-dimension
error: change sf_scale from a one-element vector to a scalar tensor (e.g.,
torch.tensor(1.0, dtype=torch.float32, device=device)) so it cannot cause a
different failure, and tighten the assertion on
torch.ops.trtllm.fused_dit_resid_gate_rms_norm_quant by using
pytest.raises(RuntimeError, match=r"unsupported.*dimension") (or the exact
unsupported-dimension message) to ensure the test only passes when that specific
error is raised.
Source: Coding guidelines
| # bf16 reference: KB bf16 outputs -> F.linear with random W1 / W2. | ||
| x_bf16 = x.clone() | ||
| o1_bf16, o2_bf16 = torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual( | ||
| x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts, s2_table, s2_ts, h2_table, h2_ts, eps | ||
| ) | ||
| D1_ref = F.linear(o1_bf16, W1) | ||
| D2_ref = F.linear(o2_bf16, W2) |
There was a problem hiding this comment.
Decouple quant validation from the bf16 fused implementation.
Line 140 builds D1_ref/D2_ref from fused_dit_resid_rms_shift_scale_dual, which is part of the same implementation family being tested. This can mask shared regressions at tokens_per_batch=126. Build the quant reference from torch_ref(...) instead.
Suggested patch
- # bf16 reference: KB bf16 outputs -> F.linear with random W1 / W2.
- x_bf16 = x.clone()
- o1_bf16, o2_bf16 = torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual(
- x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts, s2_table, s2_ts, h2_table, h2_ts, eps
- )
+ # Independent bf16 reference from torch_ref.
+ o1_bf16, o2_bf16, _ = torch_ref(
+ x,
+ attn2,
+ s1_table,
+ s1_ts,
+ h1_table,
+ h1_ts,
+ s2_table,
+ s2_ts,
+ h2_table,
+ h2_ts,
+ tokens_per_batch,
+ eps,
+ )
D1_ref = F.linear(o1_bf16, W1)
D2_ref = F.linear(o2_bf16, W2)As per coding guidelines, for tests/** coverage should be sufficient and actionable; independent references are needed to catch real regressions.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_resid_rms_shift_scale_dual.py`
around lines 138 - 144, The test currently computes D1_ref/D2_ref by calling the
implementation under test torch.ops.trtllm.fused_dit_resid_rms_shift_scale_dual
(referenced as fused_dit_resid_rms_shift_scale_dual), which can mask shared
regressions; replace that use with the independent reference implementation
torch_ref (call torch_ref(x_bf16, attn2, s1_table, s1_ts, h1_table, h1_ts,
s2_table, s2_ts, h2_table, h2_ts, eps) or the appropriate torch_ref signature)
to produce o1_bf16/o2_bf16 and then build D1_ref = F.linear(o1_bf16, W1) and
D2_ref = F.linear(o2_bf16, W2); keep all input variables (x_bf16, attn2,
s*_table, s*_ts, h*_table, h*_ts, eps, W1, W2) the same so the test decouples
quant validation from fused_dit_resid_rms_shift_scale_dual.
Source: Coding guidelines
| with pytest.raises(RuntimeError): | ||
| torch.ops.trtllm.fused_dit_rmsnorm_shift_scale( | ||
| x, scale_table, scale_ts, shift_table, shift_ts, 1e-6 | ||
| ) |
There was a problem hiding this comment.
Narrow the unsupported-dimension assertion to the intended contract.
Line 165 currently accepts any RuntimeError, so unrelated failures can make this test pass. Match the expected unsupported-dimension error text.
Suggested patch
- with pytest.raises(RuntimeError):
+ with pytest.raises(RuntimeError, match=r"(unsupported|2048|4096)"):
torch.ops.trtllm.fused_dit_rmsnorm_shift_scale(
x, scale_table, scale_ts, shift_table, shift_ts, 1e-6
)As per coding guidelines, tests/** feedback should be actionable and coverage should clearly validate the intended behavior.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_rmsnorm_shift_scale.py`
around lines 165 - 168, The test currently broad-catches any RuntimeError when
calling torch.ops.trtllm.fused_dit_rmsnorm_shift_scale; tighten the assertion to
only accept the unsupported-dimension error by using pytest.raises with a match
regex (e.g. pytest.raises(RuntimeError, match=r"unsupported.*dimension") ) so
the test for fused_dit_rmsnorm_shift_scale validates the specific
unsupported-dimension contract rather than any runtime failure.
Source: Coding guidelines
… (table,ts) pair-form modulator
Four fused CUDA kernels cover all per-block AdaLN sites for the LTX-2
transformer (video + audio modalities, 48 blocks per stage). Each kernel
consumes (table_fp32[D], ts_bf16[B,T,D]) modulator pairs and combines
them inline in Phase 0b, instead of taking a pre-added bf16 modulator.
The combine semantics match PyTorch eager _get_all_ada_values: narrow
the fp32 Parameter slice to bf16 first, then bf16 hw add via
__hadd2(__float22bfloat162_rn(table), ts). This folds the broadcast-add
prep work (previously emitted by Inductor as a separate leaf POI kernel
per call site) into the consumer C++ op.
Kernels (file -> op family):
- fusedDiTRmsNormShiftScaleKernel (KA)
rms_norm(x) -> (1 + scale) * norm + shift
2 modulator pairs (scale, shift). 2 rows / CTA, BLOCK_SIZE=256.
- fusedDiTResidualRmsNormShiftScaleDualKernelA (KB)
(x + attn2) -> rms_norm -> two parallel (1+s)*norm + h
4 modulator pairs (2x scale,shift). 1 row / CTA, BLOCK_SIZE=512.
- fusedDiTGateResidualRmsNormModulateKernel (KC)
(x + attn * gate) -> rms_norm -> (1 + scale) * norm + shift
3 modulator pairs (gate, scale, shift). 1 row / CTA, BLOCK_SIZE=512.
- fusedDiTResidGateRmsNormQuantKernel (KD)
(x + attn * gate) -> rms_norm [-> NVFP4 quant]
1 modulator pair (gate). 1 row / CTA, BLOCK_SIZE=512.
Each kernel has a bf16-output variant and an NVFP4-output variant
selected via a <bool DO_QUANT> template parameter. The quant variant
directly emits packed FP4 + 128x4 SWIZZLED FP8 e4m3 scale factors that
the downstream NVFP4 Linear (qkv_proj, ff.up_proj, cross_attn.to_q/to_k)
consumes without re-quantizing.
Op signatures (Tensor schemas in thop):
- KA: (x, scale_table, scale_ts, shift_table, shift_ts, [sf_scale,] eps)
- KB: (x, attn2, *(table, ts) x4, [sf_scale1, sf_scale2,] eps)
- KC: (x, attn, *(table, ts) x3, [sf_scale,] eps)
- KD: (x, attn_out, gate_table, gate_ts, [sf_scale,] eps)
Python:
- _get_ada_value_pair helper in transformer_ltx2.py returns a
(table_slice, ts_slice) view without materializing the broadcast-add.
Sibling to the existing _get_all_ada_values helper.
- apply_fused_adaln_modulate / _gate_resid_rms_modulate /
_resid_rms_dual_shift_scale / _resid_gate_rms_quant wrappers in
utils_ltx2.py dispatch the bf16 / quant variants based on
fp4_input_scale.
- Linear forward (modules/linear.py) accepts Fp4QuantizedTensor input.
- Attention forward (visual_gen/modules/attention.py) handles
Fp4QuantizedTensor for qkv_proj input.
- apply_fused_resid_rms_quant kept as backward-compat alias for
apply_fused_resid_gate_rms_quant.
Fake op registrations in cpp_custom_ops.py updated to match the new
schemas.
Unit tests (under tests/unittest/_torch/thop/parallel_hw_agnostic/):
KA: test_fused_dit_rmsnorm_shift_scale.py 28/28 pass
KB: test_fused_dit_resid_rms_shift_scale_dual_a.py 12/12 pass
KC: test_fused_dit_gate_resid_rms_modulate.py 12/12 pass
KD: test_fused_dit_resid_gate_rms_norm_quant.py 13/13 pass
Tolerance: rtol=2e-2, atol=1e-2. Numerics match PyTorch eager
_get_all_ada_values byte-for-byte at the combined modulator boundary.
E2E (1 GPU, CUDA Graph, 121 frames @ 768x1280, 40 steps, B200, 3-run avg):
fuse_OFF (Inductor eager) : 483.86 ms / step
fuse_ON (this commit) : 481.67 ms / step (-2.19 ms, -0.45%)
Inductor cache dump verifies the KA/KB/KC/KD modulator prep kernels are
eliminated; remaining POI kernels are unrelated GEMM bias adds and FFN
GELU+FP4 quant pointwise.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…by template
Adds compile-time hybrid TMA / cp.async gating to the unified fusedDiTNormKernel.
On Blackwell (sm_90+), the X tensor (and, for quant+residual templates, the attn
tensor) is loaded via a single cp.async.bulk.shared::cta.global instruction with
mbarrier completion, replacing the 32-issue cp.async loop. The original cp.async
path is preserved as the compile-time fallback for templates where TMA does not
win on paired bench.
Two constexpr gates pick the load path per template instantiation:
USE_TMA = (D >= 4096) && !(HAS_QUANT && HAS_MODULATE && NUM_OUT == 1)
USE_TMA_ATTN = USE_TMA && HAS_RESIDUAL && HAS_QUANT
Microbench summary (paired vs baseline, no regressions across 32 cells):
Template V2 d% V1 d%
KA-bf16 -9.0 -9.0 (TMA)
KA-quant -0.1 +1.0 (cp.async fallback)
KB-bf16 -11.0 -11.5 (TMA)
KB-quant -19.7 -20.9 (TMA + TMA-attn)
KC-bf16 -1.2 -0.8 (TMA, marginal)
KC-quant +0.1 +0.2 (cp.async fallback)
KD-bf16 -1.4 -1.1 (TMA, marginal)
KD-quant -9.6 -11.5 (TMA + TMA-attn)
Production E2E (cfg=1, 40-step, paired 5 runs/side):
per_step : 481.4 -> 473.2 ms (-8.2 ms, -1.71%)
E2E : 19.26 -> 18.93 s (-330 ms, -1.71%)
Step-4 nsys (--cuda-graph-trace=node):
total GPU: 377.85 -> 371.01 ms (-6.84 ms, -1.81%)
launches : 3617 -> 3395
KB-quant V1 (heaviest kernel): 8.50 ms / 42 calls / 202 us per call.
Implementation notes:
- mbarrier in static SMEM (8 B, alignas(8)); compiler elides when USE_TMA=false.
- When USE_TMA_ATTN, smem_attn slot follows smem_x; warp_sums offset shifts.
Host launcher dyn_smem includes the extra slot via matching constexpr.
- Phase 0c reads attn from smem_attn (USE_TMA_ATTN) or stays as direct LDG.
- Wait site gated: mbarrier.try_wait.parity vs __pipeline_wait_prior(0).
- No new fields in AdaLNNormParams; non-tensor cp.async.bulk variant means
no cuTensorMapEncodeTiled host overhead.
- All 8 template instantiations rebuilt; ptxas reports 0 spill stores/loads.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
… per-slot mod helper cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu: Add fusedDiTNormKernelPipelined: multi-row CTA (R_CTA=4) with circular SMEM stages and 1 producer / 8 consumer warp specialization. Producer issues cp.async.bulk for X (and attn when HAS_RESIDUAL); consumers wait on full_bar, compute x_new + RMSNorm + modulate + write, then arrive empty_bar. NUM_STAGES=3 for bf16, NUM_STAGES=2 for quant. Inline PTX helpers (mbar_init/arrive/wait, cp_async_bulk, bar_sync_consumer) are hoisted out of the existing kernel for reuse. Gated by env var TLLM_DIT_NORM_PIPELINED on D=4096 + KA/KB/KC variants; KD falls through to the default path. All other shapes/variants unchanged. cpp/tensorrt_llm/kernels/fusedDiTNormKernel.h: Trim NCU sweep narrative in the tile-selection comment. tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py: Restore _get_ada_values(table, batch_size, timestep, indices: slice) signature so MSA (slice 0..3) / MLP (slice 3..6) / gate (slice 5..6) fetch independently. Add pair-form companion _get_ada_table_ts_pairs(..., indices: slice) returning one (table_slice, ts_slice) per slot. Collapse 22 per-slot call sites across MSA video / MSA audio / AV CA video / AV CA audio / MLP video / MLP audio into 6 slice-based calls. Drop dead pre-fetch of all 6 combined-form values at the top of each run_*x branch. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…drop unused alias - Remove backward-compat alias ``apply_fused_resid_rms_quant`` (zero callers). - Strip references to renamed/dropped helpers (``_get_ada_value_pair``, ``_get_all_ada_values``) and the "previously pre-added bf16 modulator" / "Inductor leaf POI prep kernel" narrative. - Condense each eager-fallback comment to a single line that points at the one autotuner-bug note in ``apply_fused_adaln_modulate``. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…strip K* names - Auto-dispatch in launchFusedDiTNorm: replace TLLM_DIT_NORM_PIPELINED env-var gate with a pure shape/template predicate. One Python op entry point; C++ decides between default and pipelined kernels per (hidden_dim, num_tokens, tokens_per_batch, variant flags). - transformer_ltx2: drop defer_text_residual / defer_a2v_into_ka / defer_v2a_into_ka flags. Producers always park attn outputs in *_attn_raw slots; consumers naturally bifurcate on `is not None`; per-batch perturbation masks are pre-multiplied onto attn before deferring (mask is per-batch, gate is per-feature -- commutative). A single av_ca_runs flag gates a cleanup at the AV CA exit for single-modality models where the consumer block doesn't run. - Add _get_av_ca_ada_table_ts_pairs as pair-form companion to _get_av_ca_ada_values; refactor the two KB-site callers to use it instead of `_get_ada_table_ts_pairs(table[:4], slice(0, 4))`. - Strip all KA/KB/KC/KD dev-stage names from code comments and test files; describe each kernel by its operation instead. - Add tokens_per_batch=128 cells to KA/KB/KC quant unit tests so the auto-dispatched pipelined kernel gets CI numerical coverage on the quant path (existing tokens_per_batch=126 cells hit the default tile). - Rename v_ok / a_ok -> video_supports_fused_adaln / audio_supports_fused_adaln; hoist `import os` to file head; refresh the TLLM_DISABLE_FUSE_ADALN comment to describe its intentional A/B benchmarking role. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
8facf8f to
1bca0a6
Compare
…+ FFN callsite Extend the fused-DiT kernel family with a gate-residual-only variant for the FFN write-back site (`vx = vx + ff(vx_scaled) * gate`): - Add HAS_NORM template flag to fusedDiTGateResidNormModulateKernel; the HAS_NORM=false instantiation skips the rms reduce + Phase 2 modulate paths and returns after Phase 0d's in-place x_new writeback. - New thop op `fused_dit_gate_resid` (cpp/tensorrt_llm/thop/fusedDiTGateResidOp.cpp) that launches the HAS_NORM=false specialization. - New Python wrapper `apply_fused_gate_resid` in utils_ltx2.py with a torch.compile-safe eager fallback for fuse=False. - Wire FFN gate-residual callsites in transformer_ltx2.py (video + audio FFN) to the new wrapper. Naming consistency: - Rename `apply_fused_adaln_modulate` -> `apply_fused_rms_shift_scale` and `apply_fused_resid_gate` -> `apply_fused_gate_resid` so all wrapper names read in op-execution order (gate->resid->rms->shift_scale). Cleanups: - Strip the TLLM_DIT_NORM_FORCE_DEFAULT and TLLM_DISABLE_FUSE_ADALN bench-only env vars; auto-dispatch / shape-based fuse is the only path now. - Drop the now-unused <cstdlib> include and `import os`. - Move the combined-form `_get_av_ca_ada_values` (AV cross-attn modulators) into the eager `else` branch where its outputs are actually consumed. In the fused path its scale/shift are recomputed via pair-form and its gate output was already dead, so the unconditional call only emitted a redundant torch.compile prep kernel (~9.7us/step) -- now eliminated. E2E (LTX-2 single-stage, audio+video, guidance_scale=1.0, 768x1280x121, 40 steps, B200, 5-run avg): denoise per-step 343.7ms -> 329.6ms (-14.1ms, -4.09%), denoise loop 13.746s -> 13.184s; full E2E (incl VAE) 18.955s -> 18.282s (-3.55%). Per-kernel quant microbench: KA/KB/KC/KD 1.55-1.85x vs torch.compile (fused inline FP4). gate_resid (bf16) ties with compile at memory-bound peak. 77/77 unit tests pass (tests/unittest/_torch/thop/parallel_hw_agnostic/). Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
1bca0a6 to
f2b20a5
Compare
Description
Fuses LTX-2 NVFP4 transformer block's AdaLN modulation patterns into a single templated CUDA kernel. Replaces torch.compile-generated Triton prep kernels plus standalone bf16-to-fp4 quantize kernels with one fused launch per modulation site. The kernel is shape-polymorphic across video and audio tokens and composes six behaviors via templates: residual add, gate weighting, RMSNorm, AdaLN shift_scale, single vs dual shift_scale output, and inline fp4 quantization with per-block scale factors.
Patterns Replaced
Each block modulation pattern becomes one cpp op (the modulating ones also have a
_quantvariant that emits FP4 + per-block SF inline):fused_dit_rmsnorm_shift_scale[_quant]fused_dit_resid_rmsnorm_shift_scale_dual[_quant]fused_dit_gate_resid_rmsnorm_shift_scale[_quant]fused_dit_gate_resid_rmsnorm[_quant]fused_dit_gate_residvx + ff(vx_scaled)·gate)Design
One templated kernel
fusedDiTGateResidNormShiftScaleKernel<D, ROWS_PER_BLOCK, BLOCK_SIZE, HAS_RESIDUAL, HAS_GATE, HAS_NORM, HAS_SHIFT_SCALE, NUM_OUT, HAS_QUANT>(cpp/tensorrt_llm/kernels/fusedDiTNormKernel.cu). Six compile-time flags compose the block-prefix pipeline; 9 explicitlaunchFusedDiTNorminstantiations cover every call site:Modulators are passed as
(table, ts)pair-form and composed inline in-kernel — fp32 table narrowed to bf16 then one__hadd2— so no separate prep kernel runs. FP4 variants emit packed FP4 + 128×4 swizzled SF inline (no standalone quantize launch).Tile + load tactics (auto-dispatched by shape — no env var, no caller knob)
launchFusedDiTNormpicks one of two tiles and one of two X-load mechanisms from(D, variant, token alignment)at launch:Pipeline depth (pipelined only): bf16 = 3 stages (HBM-latency-bound, deeper hides more), fp4 = 2 stages (Phase-2 quant is compute-heavy; extra stages only add SMEM pressure).
attn load (default tile): also TMA-bulk when
HAS_RESIDUAL && HAS_QUANT(KD-quant) — its load overlaps the heavy quant Phase 2; bf16 keeps per-thread LDG (already overlaps the TMA-X load on a separate LSU pipe).Auto-dispatch predicate for the pipelined tile (all must hold):
D==4096, variant hasHAS_SHIFT_SCALE || !HAS_GATE(KA/KB/KC; KD and gate_resid fail it), andnum_tokens % 4 == 0 && tokens_per_batch ≥ 4 && tokens_per_batch % 4 == 0. Default tile is 1-row/256t, chosen by NCU sweep (2 CTAs/SM at ≤80 regs, best per-element rate among {2r256, 1r256, 1r512}).Pipelined vs default mechanics:
__launch_bounds__(288,2))__syncthreads()mbarrier.try_wait.parity+ consumer-onlybar.sync 1, 256Perf (CUDA-Graph, true SM time, video D=4096): pipelined is −10…−25% vs default on KA/KB/KC bf16, −11…−15% on fp4. KD and audio always use the default tile.
Per-kernel Microbenchmark
Pure GPU time via CUDA Graph capture + replay. V1 = (B=1, T=15360, D=4096), V2 = (B=2, T=15360, D=4096).
cpp= this PR's fused kernel (auto-selected variant per the dispatch predicate above);compile=torch.compileof an equivalent eager expression that ends intorch.ops.trtllm.tunable_fp4_quantizewithis_sf_swizzled_layout=True(matches the production Linear path). bf16 outputs verified by cosine ≥ 0.99 against an fp32 reference; FP4 outputs verified at the unit-test layer. Bench harness:tmp/feature/ltx2-adaln-refactor/bench_kabcd_vs_compile.py.End-to-end (LTX-2 single-stage, 1 GPU, audio+video, guidance_scale=1.0, 768×1280×121 frames, 40 denoise steps)
5-run average, warmup excluded. B200. Denoise per-step is the pipeline's own
Denoising doneloop timing (40 steps, excludes VAE decode + text encode); full E2E is the wholegenerate()wall including VAE + text encode.The fused AdaLN kernels touch only the denoise loop, so the -4.09% denoise gain is the kernel-level signal; the full-E2E -3.55% is diluted by the fixed VAE decode + text encode cost shared by both arms.
Per-step kernel A/B (nsys, step 10, ON vs OFF, CUDA-Graph node trace)
Net change in the kernels this PR touches (AdaLN modulation + inline FP4 quant). GEMM, attention (cuDNN SDPA), and QK/split-norm kernels are excluded — this PR does not touch them. torch.compile triton-numbering suffixes (
_view_N) are stripped before aggregation so renumbered-but-identical kernels match.quantize_with_block_size(FP4 quant now inline)The dominant win is the FP4 quantize collapsing from 840 to 336 launches: the KA/KB/KC/KD quant variants now emit packed FP4 + per-block SF inline inside the fused kernel, eliminating 504 standalone
quantize_with_block_sizelaunches plus their bf16 round-trips through HBM. The ~480 added fused kernels are roughly time-neutral against the ~480 triton rms/shift_scale prep kernels they replace.Test Coverage
Four unit tests under
tests/unittest/_torch/thop/parallel_hw_agnostic/exercise each op against an eager PyTorch reference. Thehidden_dim ∈ {2048, 4096} × batch_size ∈ {1, 2} × tokens_per_batchparametrize sweep covers both the default tile and the pipelined variant naturally (no env-var injection):test_fused_dit_rmsnorm_shift_scale.py(KA bf16 + quant; bf16 tpb ∈ {1, 16, 512}, quant tpb ∈ {126, 128})test_fused_dit_resid_rmsnorm_shift_scale_dual.py(KB bf16 + quant; bf16 tpb ∈ {16, 512}, quant tpb ∈ {126, 128})test_fused_dit_gate_resid_rmsnorm_shift_scale.py(KC bf16 + quant; bf16 tpb ∈ {16, 512}, quant tpb ∈ {126, 128})test_fused_dit_gate_resid_rmsnorm_quant.py(KD bf16 + quant; KD always uses the default tile)The
tpb=126cases exercise the default tile (126 % 4 = 2);tpb={16, 128, 512}exercise the pipelined variant on D=4096 KA/KB/KC.PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.Summary by CodeRabbit
New Features
Tests