def _reconstruct_step3p5_moe_linear(model: nn.Module) -> None:
    """Fold the per-expert weights of every ``QuantMoELinear`` back into fused 3D tensors.

    After ``_process_quantized_modules`` each expert ``nn.Linear`` held in
    ``QuantMoELinear.experts`` carries:

    - ``weight``: quantized tensor ``[out_features, in_features]``
    - ``weight_scale`` / ``weight_scale_2``: per-block / global scales
    - ``input_scale``: activation scale (only if calibrated)

    The per-expert tensors are stacked along a new leading dimension and stored on the
    wrapper under the original attribute names, so the exported ``state_dict`` uses the
    original keys (e.g. ``moe.up_proj.weight`` with shape ``[N, out, in]``). The expanded
    ``experts`` submodule is removed afterwards.

    Modules are matched by the dynamically generated class name ``QuantMoELinear``
    (``Quant`` + ``MoELinear``), not by ``_QuantMoELinear``, the implementation class.

    Args:
        model: Model whose ``QuantMoELinear`` submodules are rewritten in place.

    Returns:
        None. ``model`` is modified in place.
    """
    import warnings

    scale_attrs = ("weight_scale", "weight_scale_2", "input_scale")

    # Snapshot the traversal up front: the loop registers parameters/buffers and deletes the
    # ``experts`` submodule, so the tree must not be mutated underneath a live generator.
    for module in list(model.modules()):
        if type(module).__name__ != "QuantMoELinear":
            continue

        experts = [module.experts[i] for i in range(module.num_experts)]

        # Rebuild the fused weight: [num_experts, out_features, in_features].
        module.weight = nn.Parameter(
            torch.stack([expert.weight.data for expert in experts]),
            requires_grad=False,
        )

        # Stack each scale only when *every* expert provides a real tensor for it. An expert
        # that was never routed during calibration may lack ``input_scale`` entirely, or have
        # it registered as ``None``; either way ``torch.stack`` cannot consume it.
        for attr in scale_attrs:
            values = [getattr(expert, attr, None) for expert in experts]
            num_missing = sum(value is None for value in values)
            if num_missing == 0:
                module.register_buffer(attr, torch.stack(values))
            elif num_missing < len(values):
                # Partially available scales cannot be stacked; make the drop visible
                # instead of silently exporting a checkpoint without this attribute.
                warnings.warn(
                    f"{attr} is missing on {num_missing}/{len(values)} experts of a "
                    f"QuantMoELinear module; it will not be exported for this module.",
                    stacklevel=2,
                )

        # The reconstructed 3D tensors replace the expanded per-expert modules.
        del module.experts
class _QuantMoELinear(QuantModule):
    """Quantized stand-in for Step3p5 ``MoELinear`` (fused expert weights).

    The wrapped module stores one weight of shape
    ``[num_experts, out_features, in_features]`` and is called as
    ``forward(x, expert_id)``. During setup the fused weight is split into one
    ``nn.Linear`` per expert, so that every expert can own its own
    ``weight_quantizer`` / ``input_quantizer`` and is calibrated only on the
    tokens routed to it.

    At export time ``_reconstruct_step3p5_moe_linear()`` stacks the per-expert
    quantized weights and scales back into the original 3D layout.
    """

    def _setup(self):
        """Replace the fused ``weight`` with a ``ModuleList`` of per-expert linears."""
        from accelerate import init_empty_weights

        fused_weight = self.weight
        target_dtype = fused_weight.dtype
        target_device = fused_weight.device

        # Create the per-expert layers under ``init_empty_weights`` first.
        with init_empty_weights():
            per_expert = nn.ModuleList(
                [
                    nn.Linear(self.in_features, self.out_features, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        # Move each layer onto the fused weight's device, then point it at its own slice.
        for idx, linear in enumerate(per_expert):
            linear.to_empty(device=target_device)
            with torch.no_grad():
                expert_slice = fused_weight[idx].detach()
                linear.weight.data = expert_slice.to(dtype=target_dtype, device=target_device)

        # The fused parameter is superseded by the per-expert modules.
        del self.weight
        self.experts = per_expert

    def forward(self, x, expert_id):
        """Run ``x`` through the linear layer of expert ``expert_id``.

        After quantization wrapping the selected expert is a ``_QuantLinear``
        with its own input and weight quantizers. The input is cast to the
        expert's weight dtype, and the result is returned as float32 to match
        the original ``MoELinear`` forward behavior.
        """
        selected = self.experts[expert_id]
        cast_input = x.to(selected.weight.dtype)
        output = selected(cast_input)
        return output.float()


def register_step3p5_moe_on_the_fly(model):
    """Register the Step3p5 ``MoELinear`` type with the quantization registry.

    Step3p5 ships a custom ``MoELinear`` class (loaded via ``trust_remote_code``)
    with weight shape ``[num_experts, out_features, in_features]`` and
    ``forward(x, expert_id)``. The model is recognised by its class name; the
    concrete ``MoELinear`` type is then read from the ``up_proj`` attribute of the
    first ``Step3p5MoEMLP`` submodule and mapped to ``_QuantMoELinear`` unless it
    is already registered.
    """
    supported_models = ("Step3p5ForCausalLM", "Step3p5Model")
    if type(model).__name__ in supported_models:
        # Only the first MoE layer is inspected.
        first_moe = next(
            (sub for sub in model.modules() if type(sub).__name__ == "Step3p5MoEMLP"),
            None,
        )
        if first_moe is None:
            return

        linear_cls = type(first_moe.up_proj)
        if QuantModuleRegistry.get(linear_cls) is None:
            registry_key = f"hf.{linear_cls.__name__}"
            QuantModuleRegistry.register({linear_cls: registry_key})(_QuantMoELinear)
+ptq_cfg: + algorithm: max + quant_cfg: + '*moe*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + enable: true + '*moe*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + enable: true + '*mlp*weight_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + enable: true + '*mlp*input_quantizer': + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + enable: true + '*share_expert*': + enable: false + '*moe.gate.*': + enable: false + default: + enable: false + '*linear_attn.conv1d*': + enable: false + '*lm_head*': + enable: false + '*mixer.conv1d*': + enable: false + '*output_layer*': + enable: false + '*proj_out.*': + enable: false + '*router*': + enable: false + output.*: + enable: false + nn.BatchNorm1d: + '*': + enable: false + nn.BatchNorm2d: + '*': + enable: false + nn.BatchNorm3d: + '*': + enable: false + nn.LeakyReLU: + '*': + enable: false + '*[kv]_bmm_quantizer': + num_bits: e4m3 + enable: true