7 changes: 7 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,13 @@
NVIDIA Model Optimizer Changelog
================================

0.44 (2026-xx-xx)
^^^^^^^^^^^^^^^^^

**New Features**

- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization.

0.43 (2026-03-xx)
^^^^^^^^^^^^^^^^^

46 changes: 46 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -641,6 +641,9 @@ def _process_quantized_modules(
        ):
            sub_module.unpack_weight()
        if get_quantization_format(sub_module) != QUANTIZATION_NONE:
            # Skip QuantMoELinear - it's handled separately in _reconstruct_step3p5_moe_linear
            if type(sub_module).__name__ == "QuantMoELinear":
                continue
            if is_quantlinear(sub_module):
                try:
                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
@@ -670,6 +673,46 @@ def _process_quantized_modules(
                _export_quantized_weight(sub_module, dtype, weight_name)


def _reconstruct_step3p5_moe_linear(model: nn.Module) -> None:
    """Reconstruct QuantMoELinear per-expert weights back to the original 3D MoELinear format.

    After _process_quantized_modules, each expert's nn.Linear inside QuantMoELinear has:

    - weight: fp4-quantized tensor [out_features, in_features]
    - weight_scale, weight_scale_2: per-block / global scales
    - input_scale: activation scale (if calibrated)

    This stacks them back into the original MoELinear layout so the exported state_dict
    uses the original key names (e.g. moe.up_proj.weight with shape [N, out, in]).

    Note: QuantMoELinear is the dynamically generated class name (Quant + MoELinear),
    not _QuantMoELinear, which is the implementation class.
    """
    for name, module in model.named_modules():
        # Match QuantMoELinear (dynamically generated name), not _QuantMoELinear (implementation class)
        if type(module).__name__ != "QuantMoELinear":
            continue

        n = module.num_experts
        experts = module.experts

        # Reconstruct 3D weight: [num_experts, out_features, in_features]
        module.weight = nn.Parameter(
            torch.stack([experts[i].weight.data for i in range(n)]),
            requires_grad=False,
        )

        # Stack per-expert scales back under the original attribute names
        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(experts[0], attr):
                module.register_buffer(
                    attr,
                    torch.stack([getattr(experts[i], attr) for i in range(n)]),
                )
Comment on lines +705 to +710

Contributor

⚠️ Potential issue | 🟠 Major

Handle untouched experts before stacking input_scale.

This assumes every Step3p5 expert exported the same scale buffers. Unlike the other MoE export paths earlier in this file, QuantMoELinear.experts never get a set_expert_quantizer_amax() fallback, so any calibration run that leaves some experts untouched can make input_scale missing on only a subset of experts. That turns export into either an AttributeError here or a silently missing stacked input_scale when expert 0 was never hit. Please backfill/validate buffer presence before this stack.
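The reviewer's backfill suggestion can be sketched without torch. In this minimal illustration, numpy arrays stand in for tensor buffers, and `Expert` and `stack_expert_buffers` are hypothetical names, not part of modelopt; missing buffers are copied from the first calibrated expert so the stack always receives one entry per expert:

```python
import numpy as np

class Expert:
    """Stand-in for a per-expert module; only some experts carry input_scale."""
    def __init__(self, input_scale=None):
        if input_scale is not None:
            self.input_scale = np.asarray(input_scale, dtype=np.float32)

def stack_expert_buffers(experts, attr):
    """Collect `attr` from every expert, backfilling missing buffers.

    Missing buffers are copied from the first calibrated expert so that
    np.stack always receives len(experts) entries; raises a clear error
    if no expert carries the buffer at all.
    """
    calibrated = [getattr(e, attr) for e in experts if hasattr(e, attr)]
    if not calibrated:
        raise ValueError(f"no expert carries '{attr}'")
    fallback = calibrated[0]
    return np.stack([getattr(e, attr, fallback) for e in experts])

# Expert 0 was never routed any calibration tokens, so it has no input_scale
experts = [Expert(), Expert(input_scale=0.5), Expert(input_scale=0.25)]
stacked = stack_expert_buffers(experts, "input_scale")
```

Whether to backfill (as here, or with zeros) or raise is a policy choice; the essential part is validating presence before `stack` rather than letting expert 0 decide via `hasattr`.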


        # Remove the expanded experts; the reconstructed 3D tensors replace them
        del module.experts


def _export_transformers_checkpoint(
    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -791,6 +834,9 @@ def _export_transformers_checkpoint(
    # Process all quantized modules and export weights
    _process_quantized_modules(model, dtype, is_modelopt_qlora)

    # Reconstruct Step3p5 MoELinear: per-expert _QuantLinear weights → original 3D format
    _reconstruct_step3p5_moe_linear(model)

    if accelerator is not None:
        # Gather state_dict from all ranks
        quantized_state_dict = accelerator.get_state_dict(model)
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -641,6 +641,13 @@
    "*mlp*input_quantizer": _nvfp4_quantizer,
    "*block_sparse_moe*weight_quantizer": _nvfp4_quantizer,
    "*block_sparse_moe*input_quantizer": _nvfp4_quantizer,
    # Step3p5 MoE experts: MoELinear lives at *.moe.{up,gate,down}_proj
    "*moe*weight_quantizer": _nvfp4_quantizer,
    "*moe*input_quantizer": _nvfp4_quantizer,
Comment on lines +645 to +646

Contributor

Is this wildcard pattern *moe* too broad? This matches ANY module with "moe" anywhere in its path.

    # Disable *moe.gate.* for the router
    "*moe.gate.*": {"enable": False},
    # Disable share_expert (dense MLP alongside MoE, not in MLP-only quant scope)
    "*share_expert*": {"enable": False},
Collaborator

should we move it to the experts_only cfg?

    **_default_disabled_quantizer_cfg,
}

64 changes: 64 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -1468,10 +1468,74 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
        is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers
    )


class _QuantMoELinear(QuantModule):
    """Quantization wrapper for Step3p5 MoELinear modules (fused expert weights).

    MoELinear has weight shape [num_experts, out_features, in_features] with
    forward(x, expert_id). We expand it into per-expert nn.Linear modules so
    each expert gets its own weight_quantizer and input_quantizer, calibrated
    only on tokens actually routed to that expert.

    On export, _reconstruct_step3p5_moe_linear() stacks the per-expert quantized
    weights and scales back into the original 3D format.
    """

    def _setup(self):
        from accelerate import init_empty_weights

        dtype, device = self.weight.dtype, self.weight.device

        with init_empty_weights():
            experts = nn.ModuleList(
                [
                    nn.Linear(self.in_features, self.out_features, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        for i in range(self.num_experts):
            experts[i].to_empty(device=device)
            with torch.no_grad():
                experts[i].weight.data = self.weight[i].detach().to(dtype=dtype, device=device)

        delattr(self, "weight")
        self.experts = experts

    def forward(self, x, expert_id):
        # experts[expert_id] is a _QuantLinear after quantization wrapping,
        # providing per-expert input_quantizer and weight_quantizer.
        # Cast the input to match the expert weight dtype before the linear op,
        # then cast the output to float32 to match the original MoELinear forward.
        expert = self.experts[expert_id]
        x = x.to(expert.weight.dtype)
        return expert(x).float()
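The expand-then-restack round trip performed by `_setup` and, on export, `_reconstruct_step3p5_moe_linear` can be illustrated without torch. This numpy sketch only mirrors the tensor layout (no quantization, no nn.Module machinery); all names below are local to the example:

```python
import numpy as np

num_experts, out_features, in_features = 4, 8, 16
rng = np.random.default_rng(0)

# Fused MoELinear weight: [num_experts, out_features, in_features]
fused = rng.standard_normal((num_experts, out_features, in_features)).astype(np.float32)

# Expand: one 2D [out, in] weight per expert, as _setup does with nn.Linear modules
per_expert = [fused[i].copy() for i in range(num_experts)]

# Per-expert forward for a token routed to expert 2: y = W_e @ x
x = rng.standard_normal(in_features).astype(np.float32)
y = per_expert[2] @ x

# Reconstruct: stack the per-expert weights back into the original 3D layout,
# as the export path does with torch.stack
restacked = np.stack(per_expert)
```

The round trip is exact for the weights themselves; in the real flow the stacked tensors are the fp4-quantized weights and their scales, so only the layout (not the values) matches the pre-quantization original.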


def register_step3p5_moe_on_the_fly(model):
    """Register Step3p5 MoELinear for quantization.

    Step3p5 uses a custom MoELinear class (loaded via trust_remote_code) with
    weight shape [num_experts, out_features, in_features] and forward(x, expert_id).
    We detect it by model class name, then grab the type from the first MoE layer.
    """
    if type(model).__name__ not in ("Step3p5ForCausalLM", "Step3p5Model"):
        return
    for module in model.modules():
        if type(module).__name__ == "Step3p5MoEMLP":
            moe_linear_type = type(module.up_proj)
            if QuantModuleRegistry.get(moe_linear_type) is None:
                QuantModuleRegistry.register({moe_linear_type: f"hf.{moe_linear_type.__name__}"})(
                    _QuantMoELinear
                )
            break


CUSTOM_MODEL_PLUGINS.update(
    [
        register_falcon_linears_on_the_fly,
        register_dbrx_moe_on_the_fly,
        register_step3p5_moe_on_the_fly,
        register_sparse_moe_on_the_fly,
        register_hf_attentions_on_the_fly,
        convert_hf_parallel_linears_on_the_fly,
18 changes: 18 additions & 0 deletions modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml
Expand Up @@ -47,6 +47,24 @@ ptq_cfg:
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*moe*weight_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*moe*input_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*moe.gate.*':
      enable: false
    '*share_expert*':
      enable: false
    default:
      enable: false
    '*block_sparse_moe.gate*':
84 changes: 84 additions & 0 deletions modelopt_recipes/models/Step3.5-Flash-nvfp4-mlp-only.yaml
@@ -0,0 +1,84 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

metadata:
Collaborator

@shengliangxu should we go with /models/Step3.5-Flash/nvfp4-mlp-only.yaml instead?

  recipe_type: ptq
  description: NVFP4 static weight and dynamic activation for all linear layers (W4A4), FP8 KV cache, max calibration.
Contributor

⚠️ Potential issue | 🟡 Minor

Align the recipe description with the enabled scope.

Only *moe* / *mlp* quantizers plus KV BMM are enabled here; default: false leaves the rest of the linear layers disabled. Calling this "all linear layers" will set the wrong expectation for users picking the recipe.

📝 Suggested wording

-  description: NVFP4 static weight and dynamic activation for all linear layers (W4A4), FP8 KV cache, max calibration.
+  description: NVFP4 W4A4 for MoE/MLP projections, FP8 KV cache, max calibration.

Contributor

⚠️ Potential issue | 🟡 Minor

Align the description with the actual quantization scope.

The description says "for all linear layers," but default.enable: false plus explicit allowlist patterns means only selected groups are quantized. Please tighten wording to avoid misconfiguration during usage.

Also applies to: 54-55


ptq_cfg:
  algorithm: max
  quant_cfg:
    '*moe*weight_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*moe*input_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*mlp*weight_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*mlp*input_quantizer':
      block_sizes:
        -1: 16
        type: dynamic
        scale_bits: e4m3
      num_bits: e2m1
      enable: true
    '*share_expert.*':
      enable: false
    '*moe.gate.*':
      enable: false
    default:
      enable: false
    '*linear_attn.conv1d*':
      enable: false
    '*lm_head*':
      enable: false
    '*mixer.conv1d*':
      enable: false
    '*output_layer*':
      enable: false
    '*proj_out.*':
      enable: false
    '*router*':
      enable: false
    output.*:
      enable: false
    nn.BatchNorm1d:
      '*':
        enable: false
    nn.BatchNorm2d:
      '*':
        enable: false
    nn.BatchNorm3d:
      '*':
        enable: false
    nn.LeakyReLU:
      '*':
        enable: false
    '*[kv]_bmm_quantizer':
      num_bits: e4m3
      enable: true
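The `block_sizes: {-1: 16}` entries above mean blockwise quantization along the last weight dimension in 16-element blocks, with e2m1 (FP4) values and e4m3 (FP8) scales. A simplified numpy sketch of the per-block absmax scale computation (everything stays in float32 here; real NVFP4 stores the values in FP4 and the scales in FP8, and the function name is local to this example):

```python
import numpy as np

BLOCK = 16     # block_sizes: {-1: 16} -> 16-element blocks along the last dim
FP4_MAX = 6.0  # largest magnitude representable in e2m1 (FP4)

def blockwise_scales(w, block=BLOCK):
    """Per-block absmax scales over the last dimension of a 2D weight.

    Simplified sketch: scales are chosen so the largest value in each block
    maps to the FP4 maximum; real NVFP4 additionally quantizes these scales
    to FP8 (e4m3).
    """
    out, cin = w.shape
    blocks = w.reshape(out, cin // block, block)
    amax = np.abs(blocks).max(axis=-1)  # [out, cin // block]
    return amax / FP4_MAX

# Two output rows of 16 inputs each -> exactly one block per row
w = np.arange(32, dtype=np.float32).reshape(2, 16) - 8.0
scales = blockwise_scales(w)
```

With `type: dynamic` on the input quantizer, the analogous activation scales are recomputed per forward pass rather than frozen at calibration time.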