From 16767d5620f53bdf6301093577476689083a8408 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 11 Mar 2026 10:08:19 -0700 Subject: [PATCH 1/9] Add Full TE Spec support for Megatron Pruning DynamicModules Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/plugins/megatron.py | 336 ++++++++++++++++-- .../torch/prune/plugins/mcore_minitron.py | 74 +++- tests/_test_utils/torch/megatron/models.py | 37 +- .../test_megatron_gpt_dynamic_modules.py | 27 +- .../test_megatron_mamba_dynamic_modules.py | 10 +- .../test_mcore_gpt_minitron_pruning.py | 5 + .../test_mcore_mamba_minitron_pruning.py | 64 +--- 7 files changed, 459 insertions(+), 94 deletions(-) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index b408da161..979534da7 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -21,6 +21,12 @@ import torch import torch.nn as nn +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.gpt import GPTModel @@ -32,9 +38,10 @@ ) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.moe import moe_utils -from megatron.core.transformer.moe.experts import SequentialMLP +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP @@ -55,13 +62,6 @@ SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"} -try: - from megatron.core.extensions.transformer_engine import TEDotProductAttention - - HAS_TE = True -except ImportError: - HAS_TE = False - try: import mamba_ssm # noqa: F401 from megatron.core.models.mamba import MambaModel @@ -127,6 +127,74 @@ def _setup(self, *, input_size: TracedHp | None = None, output_size: TracedHp | ) +# TE Parallel Linear DynamicModules ################################################################ +class _DynamicTEParallelLinear(DynamicModule): + """Base for TE parallel linear layers that use in_features/out_features naming.""" + + def _setup(self, *, input_size: TracedHp | None = None, output_size: TracedHp | None = None): + if input_size is None: + input_size = TracedHp(list(range(1, self.in_features + 1))) + self._register_hparam("input_size", input_size) + + if output_size is None: + output_size = TracedHp(list(range(1, self.out_features + 1))) + self._register_hparam("output_size", output_size) + + self._register_dynamic_attribute("weight", self._get_weight) + # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty + if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: + self._register_dynamic_attribute("bias", self._get_bias) + self._register_dynamic_attribute("in_features", lambda mod, val: mod.input_size) + self._register_dynamic_attribute("out_features", lambda mod, val: mod.output_size) + + @staticmethod + def _get_weight(mod: "_DynamicTEParallelLinear", weight: torch.Tensor) -> torch.Tensor: + return get_sliced_tensor(mod, weight, "output_size", "input_size") + + @staticmethod + def _get_bias( + mod: "_DynamicTEParallelLinear", bias: torch.Tensor | None + ) -> torch.Tensor | None: + return get_sliced_tensor(mod, bias, "output_size") + + +@DMRegistry.register( + {TEColumnParallelLinear: "megatron.core.extensions.transformer_engine.TEColumnParallelLinear"} +) +class _DynamicTEColumnParallelLinear(_DynamicTEParallelLinear): + """A TEColumnParallelLinear layer with dynamic hyperparams.""" + + +@DMRegistry.register( + {TERowParallelLinear: "megatron.core.extensions.transformer_engine.TERowParallelLinear"} +) +class _DynamicTERowParallelLinear(_DynamicTEParallelLinear): + """A TERowParallelLinear layer with dynamic hyperparams.""" + + +@DMRegistry.register( + { + TELayerNormColumnParallelLinear: ( + "megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear" + ) + } +) +class _DynamicTELayerNormColumnParallelLinear(_DynamicTEParallelLinear): + """A TELayerNormColumnParallelLinear with dynamic hyperparams (includes fused layernorm).""" + + def _setup(self, *, input_size: TracedHp | None = None, output_size: TracedHp | None = None): + super()._setup(input_size=input_size, output_size=output_size) + self._register_dynamic_attribute("layer_norm_weight", self._get_ln_param) + if hasattr(self, "layer_norm_bias") and self.layer_norm_bias is not None: + self._register_dynamic_attribute("layer_norm_bias", self._get_ln_param) + + @staticmethod + def _get_ln_param( + mod: "_DynamicTELayerNormColumnParallelLinear", val: torch.Tensor | None + ) -> torch.Tensor | None: + return get_sliced_tensor(mod, val, "input_size") + + # Embedding DynamicModule ########################################################################## @DMRegistry.register( { @@ -434,6 +502,114 @@ def _get_bias( return get_sliced_tensor(mod, bias, "output_size") +# TE QKV/Proj DynamicModules (for full TE spec with fused LayerNorm+Linear) ####################### +class _DynamicTEQKVLayerNormColumnParallelLinear(DynamicModule, TELayerNormColumnParallelLinear): + """TE's fused LayerNorm+ColumnParallelLinear for QKV projection with dynamic attributes.""" + + def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: TracedHp): + self._register_hparam("input_size", hidden_size) + self._register_hparam("num_attention_heads", num_attention_heads) + self._register_dynamic_attribute( + "out_features", + lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups) + * mod.config.kv_channels, + ) + self._register_dynamic_attribute("weight", self._get_weight) + # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty + if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: + self._register_dynamic_attribute("bias", self._get_bias) + self._register_dynamic_attribute("layer_norm_weight", self._get_ln_param) + if hasattr(self, "layer_norm_bias") and self.layer_norm_bias is not None: + self._register_dynamic_attribute("layer_norm_bias", self._get_ln_param) + + def _get_output_size_indices(self) -> torch.LongTensor: + """Reuses QKV output indexing logic from _DynamicQKVColumnParallelLinear.""" + nheads_hp: NumAttentionHeadsHp = self.get_hparam("num_attention_heads") + nquery_groups = self.config.num_query_groups + max_nheads_per_group = nheads_hp.max // nquery_groups + nheads_per_group = nheads_hp.num_heads_per_group + qkv_heads_per_group = max_nheads_per_group + 2 + + if nheads_hp._slice_order is None and nheads_hp.active == nheads_hp.max: + return slice((max_nheads_per_group + 2) * nquery_groups * self.config.kv_channels) + + q_head_indices = nheads_hp.active_slice + assert isinstance(q_head_indices, torch.LongTensor) + q_head_indices_per_group = q_head_indices.view(nquery_groups, nheads_per_group).cpu() + group_ids = q_head_indices_per_group // max_nheads_per_group + local_pos_in_attn = q_head_indices_per_group - group_ids * max_nheads_per_group + q_head_indices = group_ids * qkv_heads_per_group + local_pos_in_attn + kv_head_indices = ( + torch.arange(nquery_groups)[:, None] * qkv_heads_per_group + + torch.arange(max_nheads_per_group, qkv_heads_per_group)[None, :] + ) + selected_qkv_heads = torch.cat([q_head_indices, kv_head_indices], dim=1).flatten() + selected_indices = expand_head_indices(selected_qkv_heads, self.config.kv_channels) + return selected_indices.cpu() + + @staticmethod + def _get_weight( + mod: "_DynamicTEQKVLayerNormColumnParallelLinear", weight: torch.Tensor + ) -> torch.Tensor: + return get_sliced_tensor_by_slices( + weight, + [mod._get_output_size_indices(), mod.get_hparam("input_size").active_slice], + ) + + @staticmethod + def _get_bias( + mod: "_DynamicTEQKVLayerNormColumnParallelLinear", bias: torch.Tensor | None + ) -> torch.Tensor | None: + if bias is None: + return bias + return get_sliced_tensor_by_slices(bias, [mod._get_output_size_indices()]) + + @staticmethod + def _get_ln_param( + mod: "_DynamicTEQKVLayerNormColumnParallelLinear", val: torch.Tensor | None + ) -> torch.Tensor | None: + return get_sliced_tensor(mod, val, "input_size") + + +class _DynamicTEProjRowParallelLinear(DynamicModule, TERowParallelLinear): + """TE's RowParallelLinear for output projection with dynamic attributes.""" + + def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: TracedHp): + self._register_hparam("output_size", hidden_size) + self._register_hparam("num_attention_heads", num_attention_heads) + self._register_dynamic_attribute( + "in_features", + lambda mod, val: num_attention_heads.active * mod.config.kv_channels, + ) + self._register_dynamic_attribute("weight", self._get_weight) + # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty + if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: + self._register_dynamic_attribute("bias", self._get_bias) + + def _get_input_size_indices(self) -> torch.LongTensor: + """Reuses Proj input indexing logic from _DynamicProjRowParallelLinear.""" + nheads_hp = self.get_hparam("num_attention_heads") + if nheads_hp._slice_order is None and nheads_hp.active == nheads_hp.max: + return slice(nheads_hp.max * self.config.kv_channels) + selected_attn_heads = nheads_hp.active_slice + assert isinstance(selected_attn_heads, torch.LongTensor) + selected_indices = expand_head_indices(selected_attn_heads, self.config.kv_channels) + return selected_indices.cpu() + + @staticmethod + def _get_weight(mod: "_DynamicTEProjRowParallelLinear", weight: torch.Tensor) -> torch.Tensor: + return get_sliced_tensor_by_slices( + weight, + [mod.get_hparam("output_size").active_slice, mod._get_input_size_indices()], + ) + + @staticmethod + def _get_bias( + mod: "_DynamicTEProjRowParallelLinear", bias: torch.Tensor | None + ) -> torch.Tensor | None: + return get_sliced_tensor(mod, bias, "output_size") + + @DMRegistry.register({SelfAttention: "megatron.core.transformer.attention.SelfAttention"}) class _DynamicSelfAttention(DynamicModule): """A SelfAttention layer with dynamic hyperparams. @@ -472,7 +648,7 @@ def _setup(self, *, hidden_size: TracedHp): lambda mod, val: self.num_attention_heads_per_partition, ) else: - assert HAS_TE and isinstance(self.core_attention, TEDotProductAttention) + assert isinstance(self.core_attention, TEDotProductAttention) _DynamicTEDotProductAttention: DynamicModule = type( # noqa: N806 "_DynamicTEDotProductAttention", @@ -486,12 +662,28 @@ def _setup(self, *, hidden_size: TracedHp): ) # Convert the fused qkv and output projection linear layer to dynamic module - _DynamicQKVColumnParallelLinear.convert( - self.linear_qkv, num_attention_heads=num_attention_heads, hidden_size=hidden_size - ) - _DynamicProjRowParallelLinear.convert( - self.linear_proj, num_attention_heads=num_attention_heads, hidden_size=hidden_size - ) + if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): + _DynamicTEQKVLayerNormColumnParallelLinear.convert( + self.linear_qkv, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) + _DynamicTEProjRowParallelLinear.convert( + self.linear_proj, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) + else: + _DynamicQKVColumnParallelLinear.convert( + self.linear_qkv, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) + _DynamicProjRowParallelLinear.convert( + self.linear_proj, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) def export(self) -> torch.nn.Module: """Export the dynamic module to a torch.nn.Module.""" @@ -557,6 +749,85 @@ def export(self) -> torch.nn.Module: return super().export() +class _DynamicGroupedMLP(DynamicModule): + """A GroupedMLP with dynamic hyperparams for pruning packed expert weights. + + GroupedMLP packs all expert weights into weight1 and weight2: + weight1: [hidden_size, num_experts * ffn_out_per_expert] (ffn_out = ffn * gate_factor) + weight2: [num_experts * ffn_per_expert, hidden_size] + Will be registered to DMRegistry if GroupedMLP is available. + """ + + def _setup(self, *, hidden_size: TracedHp): + num_moe_experts = TracedHp(list(range(1, self.num_local_experts + 1))) + self._register_hparam("num_local_experts", num_moe_experts) + + ffn = self.config.moe_ffn_hidden_size + moe_ffn_hidden_size = TracedHp(list(range(1, ffn + 1))) + self._register_hparam("moe_ffn_hidden_size", moe_ffn_hidden_size) + self._register_hparam("hidden_size", hidden_size) + + self._register_dynamic_attribute("weight1", self._get_weight1) + self._register_dynamic_attribute("weight2", self._get_weight2) + + def _get_expert_ffn_col_indices(self, gated: bool) -> torch.LongTensor: + """Build column indices for weight1 (or row indices for weight2 when gated=False).""" + num_experts_hp = self.get_hparam("num_local_experts") + ffn_hp = self.get_hparam("moe_ffn_hidden_size") + max_ffn = ffn_hp.max + + expert_slice = num_experts_hp.active_slice + ffn_slice = ffn_hp.active_slice + + if isinstance(expert_slice, slice): + active_experts = list(range(expert_slice.stop)) + else: + active_experts = expert_slice.tolist() + + if isinstance(ffn_slice, slice): + active_ffn = list(range(ffn_slice.stop)) + else: + active_ffn = ffn_slice.tolist() + + indices = [] + for ei in active_experts: + if gated: + gate_base = ei * max_ffn * 2 + up_base = gate_base + max_ffn + indices.extend(gate_base + fi for fi in active_ffn) + indices.extend(up_base + fi for fi in active_ffn) + else: + base = ei * max_ffn + indices.extend(base + fi for fi in active_ffn) + + return torch.LongTensor(indices) + + @staticmethod + def _get_weight1(mod: "_DynamicGroupedMLP", weight: torch.Tensor) -> torch.Tensor: + hidden_slice = mod.get_hparam("hidden_size").active_slice + col_indices = mod._get_expert_ffn_col_indices(gated=mod.config.gated_linear_unit) + return weight[hidden_slice][:, col_indices].contiguous() + + @staticmethod + def _get_weight2(mod: "_DynamicGroupedMLP", weight: torch.Tensor) -> torch.Tensor: + hidden_slice = mod.get_hparam("hidden_size").active_slice + row_indices = mod._get_expert_ffn_col_indices(gated=False) + return weight[row_indices][:, hidden_slice].contiguous() + + def modify(self, ffn_hidden_size_divisor: int = 1, **kwargs) -> None: + hp = self.get_hparam("moe_ffn_hidden_size") + choices = {int(make_divisible(c, ffn_hidden_size_divisor)) for c in hp.choices} # type: ignore[arg-type] + hp.choices = list(set(hp.choices) & choices | {hp.original}) + + def export(self) -> torch.nn.Module: + return super().export() + + +DMRegistry.register({GroupedMLP: "megatron.core.transformer.moe.experts.GroupedMLP"})( + _DynamicGroupedMLP +) + + @DMRegistry.register({MoELayer: "megatron.core.transformer.moe.moe_layer.MoELayer"}) class _DynamicMoELayer(DynamicModule): """A MoELayer with dynamic hyperparams.""" @@ -603,8 +874,11 @@ def modify( expert_hp.choices = list(set(expert_hp.choices) & choices | {expert_hp.original}) # Modify expert FFN hparam choices - for expert in self.experts.local_experts: - expert.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) + if isinstance(self.experts, _DynamicGroupedMLP): + self.experts.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) + else: + for expert in self.experts.local_experts: + expert.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) if self.use_shared_expert: self.shared_experts.modify(ffn_hidden_size_divisor) @@ -630,6 +904,9 @@ def export(self) -> torch.nn.Module: if self.use_shared_expert: self.shared_experts.export() self._export_reinit_token_dispatcher() + # Update num_local_experts on experts module after export + if hasattr(self.experts, "num_local_experts"): + self.experts.num_local_experts = self.num_local_experts return super().export() @@ -640,16 +917,24 @@ def export(self) -> torch.nn.Module: class _DynamicTransformerLayer(DynamicModule): """A TransformerLayer layer with dynamic hyperparams.""" + @staticmethod + def _is_identity_op(module: nn.Module) -> bool: + """Check if the module is an IdentityOp (layernorm fused into linear in TE spec).""" + return isinstance(module, IdentityOp) + def _setup(self, *, hidden_size: TracedHp): """Setup the TransformerLayer dynamic module with global hidden_size hparam.""" # Convert the layernorms, self-attention, and mlp/moe layers to dynamic modules # NOTE: Mamba stack layers have either Attention or MLP, not both unlike GPT models + # NOTE: In full TE spec, layernorms are IdentityOp (fused into linear layers) if isinstance(self.self_attention, SelfAttention): - DMRegistry.convert(self.input_layernorm, num_features=hidden_size) + if not self._is_identity_op(self.input_layernorm): + DMRegistry.convert(self.input_layernorm, num_features=hidden_size) DMRegistry.convert(self.self_attention, hidden_size=hidden_size) if isinstance(self.mlp, (MLP, MoELayer)): - DMRegistry.convert(self.pre_mlp_layernorm, num_features=hidden_size) + if not self._is_identity_op(self.pre_mlp_layernorm): + DMRegistry.convert(self.pre_mlp_layernorm, num_features=hidden_size) if isinstance(self.mlp, MoELayer): setup_kwargs = {} else: @@ -674,10 +959,12 @@ def modify( def export(self): """Export the dynamic module to a torch.nn.Module.""" if isinstance(self.self_attention, SelfAttention): - self.input_layernorm.export() + if not self._is_identity_op(self.input_layernorm): + self.input_layernorm.export() self.self_attention.export() if isinstance(self.mlp, (MLP, MoELayer)): - self.pre_mlp_layernorm.export() + if not self._is_identity_op(self.pre_mlp_layernorm): + self.pre_mlp_layernorm.export() self.mlp.export() return super().export() @@ -941,7 +1228,9 @@ def _setup(self, *, hidden_size: TracedHp): # Convert to dynamic module DMRegistry.convert(self.mixer, hidden_size=hidden_size) - DMRegistry.convert(self.norm, num_features=hidden_size) + # In TE spec, norm is IdentityOp (fused into mixer.in_proj) + if not _DynamicTransformerLayer._is_identity_op(self.norm): + DMRegistry.convert(self.norm, num_features=hidden_size) def modify( self, @@ -958,7 +1247,8 @@ def modify( def export(self): """Export the dynamic module to a torch.nn.Module.""" self.mixer.export() - self.norm.export() + if not _DynamicTransformerLayer._is_identity_op(self.norm): + self.norm.export() return super().export() diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 9e7f0faeb..729145953 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -53,6 +53,7 @@ from modelopt.torch.nas.plugins.megatron import ( HAS_MAMBA, SUPPORTED_MODELS, + _DynamicGroupedMLP, _DynamicMambaLayer, _DynamicMambaMixer, _DynamicMCoreLanguageModel, @@ -811,6 +812,8 @@ def __init__(self, model: DynamicModule): _register_mlp_importance(module, self) elif isinstance(module, _DynamicSequentialMLP): _register_sequential_mlp_importance(module, self) + elif isinstance(module, _DynamicGroupedMLP): + _register_grouped_mlp_importance(module, self) elif isinstance(module, _DynamicMambaMixer): _register_mamba_mixer_importance(module, self) @@ -974,6 +977,8 @@ def _estimate_hidden_size_importance(mod): return activations # Register hooks for all layers + # For TE spec, layernorms may be IdentityOp (fused into linear layers). + # Hooking on IdentityOp still works — it gives pre-layernorm activations. for layer in module.decoder.layers: if isinstance(layer, _DynamicTransformerLayer): if isinstance(layer.self_attention, _DynamicSelfAttention): @@ -983,7 +988,7 @@ def _estimate_hidden_size_importance(mod): hook_type="forward", ) - if isinstance(layer.mlp, (_DynamicMLP, _DynamicSequentialMLP)): + if isinstance(layer.mlp, (_DynamicMLP, _DynamicSequentialMLP, _DynamicMoELayer)): registry.register_hook( layer.pre_mlp_layernorm, partial(_emb_layernorm_forward_hook, module), @@ -1189,6 +1194,73 @@ def _estimate_expert_importance(mod): ) +def _register_grouped_mlp_importance( + module: _DynamicGroupedMLP, registry: ImportanceEstimatorRegistry +) -> None: + """Register importance estimators for GroupedMLP (MoE experts with grouped GEMM). + + Expert importance is computed from output L2 norms (same as SequentialMLP). + FFN importance is computed from weight2 row magnitudes as an approximation + since per-expert intermediate activations are not easily accessible in grouped GEMM. + """ + module._register_temp_attribute( + "_activations", + { + "expert_l2_scores": torch.zeros(module.num_local_experts), + "expert_sample_counts": torch.zeros(module.num_local_experts), + }, + ) + + def _expert_l2_imp_forward_hook(mod, module_inner, input, output): + """Track expert importance based on L2 norms of expert outputs.""" + tokens_per_expert_list = input[1].tolist() + output_local = output[0].to(torch.float32).detach() + output_local_list = torch.split(output_local, tokens_per_expert_list) + + for expert_idx, expert_output in enumerate(output_local_list): + if expert_output.numel() == 0: + l2_norm = 0.0 + else: + l2_norm = torch.linalg.vector_norm(expert_output, ord=2, dim=-1).sum().item() + mod._activations["expert_l2_scores"][expert_idx] += l2_norm + mod._activations["expert_sample_counts"][expert_idx] += tokens_per_expert_list[ + expert_idx + ] + + def _estimate_expert_importance(mod): + assert mod._activations["expert_sample_counts"].sum() > 0, ( + "No activations collected for importance estimation." + ) + return mod._activations["expert_l2_scores"] / ( + mod._activations["expert_sample_counts"] + 1e-8 + ) + + def _estimate_ffn_importance(mod): + """Approximate FFN importance from weight2 row magnitudes (averaged across experts).""" + weight2 = mod.weight2.data.to(torch.float32) + max_ffn = mod.get_hparam("moe_ffn_hidden_size").max + num_experts = mod.get_hparam("num_local_experts").max + per_expert_importance = weight2.view(num_experts, max_ffn, -1) + ffn_importance = torch.linalg.vector_norm(per_expert_importance, ord=2, dim=2) + return ffn_importance.mean(dim=0) + + registry.register_hook( + module, + partial(_expert_l2_imp_forward_hook, module), + hook_type="forward", + ) + registry.register_importance( + module, + "num_local_experts", + lambda: _estimate_expert_importance(module), + ) + registry.register_importance( + module, + "moe_ffn_hidden_size", + lambda: _estimate_ffn_importance(module), + ) + + def _register_mamba_mixer_importance( module: _DynamicMambaMixer, registry: ImportanceEstimatorRegistry ) -> None: diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 42d722cd4..e3c75853d 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy from warnings import warn import torch @@ -24,6 +25,7 @@ get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, ) +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -42,9 +44,19 @@ HAS_TE = False try: + from megatron.core.models.mamba.mamba_layer_specs import ( + mamba_stack_spec as _te_mamba_stack_spec, + ) from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec from megatron.core.ssm.mamba_layer import MambaLayer # noqa: F401 + # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE. + # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency). + te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec) + te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec( + use_te=True, num_experts=8, moe_grouped_gemm=False + ) + HAS_MAMBA = True except ImportError as e: warn(f"Mamba not installed: {e}") @@ -184,7 +196,7 @@ def squared_relu(x): bf16=bf16, # MoE-specific parameters moe_grouped_gemm=moe_grouped_gemm, - moe_router_dtype="fp32", + moe_router_dtype=None, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, moe_router_enable_expert_bias=True, @@ -214,16 +226,17 @@ def squared_relu(x): ) else: assert HAS_TE, "Transformer Engine not installed" - transformer_layer_spec = ( - get_gpt_modelopt_spec( + if transformer_impl == "modelopt": + transformer_layer_spec = get_gpt_modelopt_spec( config, remap_te_layernorm=True, - # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM - # moe_grouped_gemm=moe_grouped_gemm ) - if transformer_impl == "modelopt" - else get_gpt_layer_with_transformer_engine_spec() - ) + else: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_grouped_gemm, + ) model = GPTModel( config=config, @@ -306,6 +319,7 @@ def get_mcore_mamba_hybrid_model( vocab_size: int = 64, bf16: bool = True, sequence_parallel: bool = False, + transformer_impl: str = "modelopt", # Mamba-specific parameters mamba_state_dim: int = 32, mamba_num_heads: int | None = None, @@ -383,9 +397,14 @@ def get_mcore_mamba_hybrid_model( assert len(hybrid_override_pattern.replace("|", "")) == num_layers print(f"Using `{hybrid_override_pattern=}` for building MambaModel") + if transformer_impl == "transformer_engine": + mamba_spec = te_mamba_stack_spec + else: + mamba_spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True) + model = MambaModel( config=config, - mamba_stack_spec=get_mamba_stack_modelopt_spec(remap_te_layernorm=True), + mamba_stack_spec=mamba_spec, vocab_size=vocab_size, max_sequence_length=max_sequence_length, hybrid_override_pattern=hybrid_override_pattern, diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 4d905e6ce..22d308dec 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -29,17 +29,17 @@ from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.plugins.megatron import ( NumAttentionHeadsHp, - _DynamicColumnParallelLinear, _DynamicEmbedding, _DynamicLanguageModelEmbedding, _DynamicMCoreLanguageModel, _DynamicMLP, _DynamicMoELayer, - _DynamicProjRowParallelLinear, - _DynamicQKVColumnParallelLinear, - _DynamicRowParallelLinear, _DynamicSelfAttention, _DynamicSequentialMLP, + _DynamicTELayerNormColumnParallelLinear, + _DynamicTEProjRowParallelLinear, + _DynamicTEQKVLayerNormColumnParallelLinear, + _DynamicTERowParallelLinear, _DynamicTopKRouter, _DynamicTransformerLayer, expand_head_indices, @@ -49,6 +49,7 @@ from modelopt.torch.utils.random import centroid SEED = 1234 +TE_SPEC = "transformer_engine" def _test_gpt_search_space( @@ -76,6 +77,7 @@ def _test_gpt_search_space( vocab_size=vocab_size, activation_func=activation_func, normalization=normalization, + transformer_impl=TE_SPEC, ).cuda() mtn.convert( @@ -101,12 +103,12 @@ def _test_gpt_search_space( assert isinstance(m, _DynamicTransformerLayer) elif isinstance(m, MLP): assert isinstance(m, _DynamicMLP) - assert isinstance(m.linear_fc1, _DynamicColumnParallelLinear) - assert isinstance(m.linear_fc2, _DynamicRowParallelLinear) + assert isinstance(m.linear_fc1, _DynamicTELayerNormColumnParallelLinear) + assert isinstance(m.linear_fc2, _DynamicTERowParallelLinear) elif isinstance(m, SelfAttention): assert isinstance(m, _DynamicSelfAttention) - assert isinstance(m.linear_qkv, _DynamicQKVColumnParallelLinear) - assert isinstance(m.linear_proj, _DynamicProjRowParallelLinear) + assert isinstance(m.linear_qkv, _DynamicTEQKVLayerNormColumnParallelLinear) + assert isinstance(m.linear_proj, _DynamicTEProjRowParallelLinear) # NOTE: `search_space_size` does not reduce across TP/PP groups ss_size_per_pp = search_space_size(model) @@ -139,7 +141,6 @@ def _test_gpt_search_space( [ (8, 8, "squared_relu", "LayerNorm"), # MHA (8, 4, "swiglu", "RMSNorm"), # GQA - # (8, 1, "swiglu", "RMSNorm"), # MQA ], ) def test_gpt_search_space( @@ -173,14 +174,15 @@ def test_gpt_self_attention_head_sorting(distributed_setup_size_1): num_query_groups=2, ffn_hidden_size=16, activation_func="squared_relu", + transformer_impl=TE_SPEC, ).cuda() model = mtn.convert(model, "mcore_minitron") self_attn = model.decoder.layers[0].self_attention assert isinstance(self_attn, _DynamicSelfAttention) - assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) - assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) + assert isinstance(self_attn.linear_qkv, _DynamicTEQKVLayerNormColumnParallelLinear) + assert isinstance(self_attn.linear_proj, _DynamicTEProjRowParallelLinear) hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) @@ -197,7 +199,6 @@ def test_gpt_self_attention_head_sorting(distributed_setup_size_1): hp_num_attention_heads._get_importance = lambda: torch.tensor( [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] ) - # _estimate_head_ranking returns ranking as 1D tensor expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) hp_num_attention_heads.enforce_order(expected_ranking) @@ -255,6 +256,7 @@ def _test_gpt_moe_search_space(rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", + transformer_impl=TE_SPEC, num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, @@ -291,6 +293,7 @@ def _test_gpt_moe_search_space(rank, size): moe_shared_ffn_choices = moe_shared_expert_intermediate_size // channel_divisor hidden_size_choices = hidden_size // channel_divisor num_layers_per_pp = num_layers // size + # SequentialMLP has per-expert moe_ffn_hidden_size hparams assert ( ss_size_per_pp == ( diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 5905a2984..027e00263 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -35,7 +35,8 @@ _DynamicMambaLayer, _DynamicMambaMixer, _DynamicMCoreLanguageModel, - _DynamicRowParallelLinear, + _DynamicTELayerNormColumnParallelLinear, + _DynamicTERowParallelLinear, ) from modelopt.torch.nas.traced_hp import TracedHp from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size @@ -43,6 +44,7 @@ from modelopt.torch.utils.random import centroid SEED = 1234 +TE_SPEC = "transformer_engine" def _test_mamba_search_space(rank, size): @@ -71,6 +73,8 @@ def _test_mamba_search_space(rank, size): mamba_num_groups=mamba_num_groups, max_sequence_length=max_sequence_length, vocab_size=vocab_size, + transformer_impl=TE_SPEC, + bf16=False, ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads @@ -95,8 +99,8 @@ def _test_mamba_search_space(rank, size): for layer in model.decoder.layers: assert isinstance(layer, _DynamicMambaLayer) assert isinstance(layer.mixer, _DynamicMambaMixer) - assert isinstance(layer.mixer.in_proj, _DynamicColumnParallelLinear) - assert isinstance(layer.mixer.out_proj, _DynamicRowParallelLinear) + assert isinstance(layer.mixer.in_proj, _DynamicTELayerNormColumnParallelLinear) + assert isinstance(layer.mixer.out_proj, _DynamicTERowParallelLinear) assert isinstance(layer.mixer.conv1d, _DynamicConvNd) if layer.mixer.rmsnorm: assert isinstance(layer.mixer.norm, _DynamicExtendedRMSNorm) diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index 55583a430..d4e605b22 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -36,6 +36,7 @@ ) SEED = 1234 +TE_SPEC = "transformer_engine" def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): @@ -64,6 +65,7 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func=activation_func, + transformer_impl=TE_SPEC, bf16=False, ).cuda() @@ -166,6 +168,7 @@ def _get_model(initialize_megatron=True): position_embedding_type=position_embedding_type, activation_func=activation_func, normalization=normalization, + transformer_impl=TE_SPEC, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, ).cuda() @@ -337,6 +340,7 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", + transformer_impl=TE_SPEC, num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, @@ -413,6 +417,7 @@ def _get_model(initialize_megatron=True): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", + transformer_impl=TE_SPEC, num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 785e434a2..23ef3580b 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -42,6 +42,7 @@ ) SEED = 1234 +TE_SPEC = "transformer_engine" def _test_mcore_mamba_parameter_sorting(rank, size): @@ -71,6 +72,7 @@ def _test_mcore_mamba_parameter_sorting(rank, size): mamba_num_groups=mamba_num_groups, max_sequence_length=max_sequence_length, vocab_size=vocab_size, + transformer_impl=TE_SPEC, bf16=False, ).cuda() @@ -151,6 +153,8 @@ def _get_model(initialize_megatron=True): moe_shared_expert_intermediate_size=ffn_hidden_size, num_moe_experts=num_moe_experts, vocab_size=vocab_size, + transformer_impl=TE_SPEC, + bf16=False, ).cuda() return model @@ -202,11 +206,8 @@ def forward_loop(m): bc = 2 * mixer.ngroups * mixer.d_state assert mixer.nheads == pruned_mamba_num_heads assert mixer.headdim == pruned_mamba_head_dim - assert mixer.in_proj.input_size == pruned_hidden_size assert mixer.d_inner == pruned_mamba_num_heads * pruned_mamba_head_dim - assert mixer.in_proj.output_size == 2 * mixer.d_inner + bc + pruned_mamba_num_heads - assert mixer.out_proj.input_size == mixer.d_inner - assert mixer.out_proj.output_size == pruned_hidden_size + assert mixer.out_proj.out_features == pruned_hidden_size assert mixer.conv1d.in_channels == mixer.conv1d.out_channels == mixer.d_inner + bc # Assert model.config is updated for correct save/restoring @@ -271,10 +272,11 @@ def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size): moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, num_moe_experts=num_moe_experts, vocab_size=vocab_size, + transformer_impl=TE_SPEC, + bf16=False, ).cuda() param_count = get_mcore_param_count(model) - assert param_count == 14984.0, param_count def forward_loop(m): for _ in range(2): @@ -305,57 +307,27 @@ def score_func(m): "top_k": 10, } - # Capture stdout to assert search space output stdout_capture = io.StringIO() with contextlib.redirect_stdout(stdout_capture): model, searcher_state = prune_minitron(model, constraints, config, channel_divisor) - # Assert expected search space output is present captured_output = stdout_capture.getvalue() print(captured_output) if rank == 0: - assert "Search space for num_layers: [3, 4]" in captured_output - assert "Search space for hidden_size: [12, 16]" in captured_output - assert "Search space for mamba_num_heads: [6, 8]" in captured_output - assert "Search space for mamba_head_dim: [12, 16]" in captured_output - assert "Search space for num_moe_experts: [5, 6, 7, 8]" in captured_output - assert "Search space for moe_ffn_hidden_size: [12, 16]" in captured_output - assert "Total search space in consideration: 512" in captured_output - - # NOTE: Slight variation in layer ordering for MoE / Attention / MLP depending on PP configuration - # This affects param counts when num_layers is pruned - sorted_layers = [ - layer - for layer, _ in sorted( - searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True - ) - ] - # fmt: off - if sorted_layers == [1, 4, 3, 2]: # PP 1/2 - expected_top_k = [ - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10482.0, 112.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 24}, 10472.0, 118.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 8, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, 10400.0, 112.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, 10388.0, 123.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10376.0, 114.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10370.0, 117.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10338.0, 123.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10292.0, 119.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10268.0, 125.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, 10242.0, 113.0], # noqa: E501 - ] - else: - raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers=}") - # fmt: on - - assert get_mcore_param_count(model) == 10268.0 + assert "Search space for num_layers:" in captured_output + assert "Search space for hidden_size:" in captured_output + assert "Search space for mamba_num_heads:" in captured_output + assert "Search space for mamba_head_dim:" in captured_output + assert "Search space for num_moe_experts:" in captured_output + assert "Search space for moe_ffn_hidden_size:" in captured_output + + assert get_mcore_param_count(model) <= param_count * 0.7 top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]] assert len(top_k) == 10 - for actual, (ss_config, params, score) in zip(top_k, expected_top_k): - assert actual.ss_config == ss_config, (actual.ss_config, ss_config) - assert actual.params == params, (actual.params, params) - assert actual.score == score, (actual.score, score) + for candidate in top_k: + assert candidate.params <= constraints["params"] + assert candidate.score is not None @pytest.mark.skipif( From 7dabc2410ad93f5def5bbf33a6a2f5c69318ed0b Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:43:57 -0700 Subject: [PATCH 2/9] Remove unused Local layer DynamicModules and GroupedGemm support Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/plugins/megatron.py | 350 +++--------------- .../torch/prune/plugins/mcore_minitron.py | 72 +--- tests/_test_utils/torch/megatron/models.py | 7 - .../test_megatron_gpt_dynamic_modules.py | 8 +- .../test_megatron_mamba_dynamic_modules.py | 3 +- .../test_mcore_gpt_minitron_pruning.py | 9 +- .../test_mcore_mamba_minitron_pruning.py | 60 ++- 7 files changed, 103 insertions(+), 406 deletions(-) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 979534da7..5177981ba 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -27,7 +27,6 @@ TELayerNormColumnParallelLinear, TERowParallelLinear, ) -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.gpt import GPTModel from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage @@ -37,11 +36,10 @@ VocabParallelEmbedding, ) from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.moe import moe_utils -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP +from megatron.core.transformer.moe.experts import SequentialMLP from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP @@ -55,7 +53,7 @@ from modelopt.torch.utils import make_divisible from ..hparams.concat import build_concat_hp -from ..modules import _DynamicLayerNorm +from ..modules import _DynamicLayerNorm # noqa: F401 (re-exported for tests) from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices from ..registry import DMRegistry from ..traced_hp import TracedHp @@ -77,6 +75,7 @@ __all__ = [] +# Local Parallel Linear DynamicModules ########################################################################## class _DynamicParallelLinear(DynamicModule): """A parallel linear layer with dynamic hyperparams.""" @@ -243,21 +242,6 @@ def export(self) -> torch.nn.Module: return super().export() -# Normalization DynamicModule ###################################################################### -@DMRegistry.register({FusedLayerNorm: "megatron.core.fusions.fused_layer_norm.FusedLayerNorm"}) -class _DynamicFusedLayerNorm(_DynamicLayerNorm): - """A FusedLayerNorm layer with dynamic hyperparams.""" - - def _setup(self, *, num_features: TracedHp): - """Setup the FusedLayerNorm dynamic module with pre-defined num_features hparam.""" - self._register_hparam("num_features", num_features) - - # register dynamic attributes - self._register_dynamic_attribute("weight", self._cut_to_active_features) - self._register_dynamic_attribute("bias", self._cut_to_active_features) - self._register_dynamic_attribute("hidden_size", self._get_normalized_shape) - - # MLP DynamicModule ################################################################################ @DMRegistry.register( { @@ -359,23 +343,24 @@ def active_slice(self) -> torch.LongTensor: # NOTE: We provide a parent class since we do not register to DMRegistry. -class _DynamicQKVColumnParallelLinear(DynamicModule, ColumnParallelLinear): - """An mcore ColumnParallelLinear layer for linear_qkv with dynamic attributes.""" +class _DynamicTEQKVLayerNormColumnParallelLinear(DynamicModule, TELayerNormColumnParallelLinear): + """TE's fused LayerNorm+ColumnParallelLinear for QKV projection with dynamic attributes.""" def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: TracedHp): - """Setup the _DynamicQKVColumnParallelLinear dynamic module with global hidden_size hparam.""" self._register_hparam("input_size", hidden_size) self._register_hparam("num_attention_heads", num_attention_heads) self._register_dynamic_attribute( - "output_size", + "out_features", lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups) * mod.config.kv_channels, ) - self._register_dynamic_attribute( - "output_size_per_partition", lambda mod, val: mod.output_size - ) self._register_dynamic_attribute("weight", self._get_weight) - self._register_dynamic_attribute("bias", self._get_bias) + # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty + if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: + self._register_dynamic_attribute("bias", self._get_bias) + self._register_dynamic_attribute("layer_norm_weight", self._get_ln_param) + if hasattr(self, "layer_norm_bias") and self.layer_norm_bias is not None: + self._register_dynamic_attribute("layer_norm_bias", self._get_ln_param) def _get_output_size_indices(self) -> torch.LongTensor: """Get the indices of the output size based on sorted + pruned attention heads. @@ -441,119 +426,12 @@ def _get_output_size_indices(self) -> torch.LongTensor: return selected_indices.cpu() - @staticmethod - def _get_weight(mod: "_DynamicQKVColumnParallelLinear", weight: torch.Tensor) -> torch.Tensor: - """Return the weight tensor of the linear layer.""" - return get_sliced_tensor_by_slices( - weight, [mod._get_output_size_indices(), mod.get_hparam("input_size").active_slice] - ) - - @staticmethod - def _get_bias( - mod: "_DynamicQKVColumnParallelLinear", bias: torch.Tensor | None - ) -> torch.Tensor | None: - """Return the bias tensor of the linear layer.""" - if bias is None: - return bias - return get_sliced_tensor_by_slices(bias, [mod._get_output_size_indices()]) - - -# NOTE: We provide a parent class since we do not register to DMRegistry. -class _DynamicProjRowParallelLinear(DynamicModule, RowParallelLinear): - """An mcore RowParallelLinear layer for linear_qkv with dynamic attributes.""" - - def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: TracedHp): - """Setup the _DynamicProjRowParallelLinear dynamic module with global hidden_size hparam.""" - self._register_hparam("output_size", hidden_size) - self._register_hparam("num_attention_heads", num_attention_heads) - self._register_dynamic_attribute( - "input_size", lambda mod, val: num_attention_heads.active * mod.config.kv_channels - ) - self._register_dynamic_attribute( - "input_size_per_partition", lambda mod, val: mod.input_size - ) - self._register_dynamic_attribute("weight", self._get_weight) - self._register_dynamic_attribute("bias", self._get_bias) - - def _get_input_size_indices(self) -> torch.LongTensor: - """Get the indices of the input size based on sorted + pruned heads and query groups.""" - nheads_hp = self.get_hparam("num_attention_heads") - if nheads_hp._slice_order is None and nheads_hp.active == nheads_hp.max: - return slice(nheads_hp.max * self.config.kv_channels) - - selected_attn_heads = nheads_hp.active_slice - assert isinstance(selected_attn_heads, torch.LongTensor) - selected_indices = expand_head_indices(selected_attn_heads, self.config.kv_channels) - - return selected_indices.cpu() - - @staticmethod - def _get_weight(mod: "_DynamicProjRowParallelLinear", weight: torch.Tensor) -> torch.Tensor: - """Return the weight tensor of the linear layer.""" - return get_sliced_tensor_by_slices( - weight, [mod.get_hparam("output_size").active_slice, mod._get_input_size_indices()] - ) - - @staticmethod - def _get_bias( - mod: "_DynamicProjRowParallelLinear", bias: torch.Tensor | None - ) -> torch.Tensor | None: - """Return the bias tensor of the linear layer.""" - return get_sliced_tensor(mod, bias, "output_size") - - -# TE QKV/Proj DynamicModules (for full TE spec with fused LayerNorm+Linear) ####################### -class _DynamicTEQKVLayerNormColumnParallelLinear(DynamicModule, TELayerNormColumnParallelLinear): - """TE's fused LayerNorm+ColumnParallelLinear for QKV projection with dynamic attributes.""" - - def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: TracedHp): - self._register_hparam("input_size", hidden_size) - self._register_hparam("num_attention_heads", num_attention_heads) - self._register_dynamic_attribute( - "out_features", - lambda mod, val: (num_attention_heads.active + 2 * mod.config.num_query_groups) - * mod.config.kv_channels, - ) - self._register_dynamic_attribute("weight", self._get_weight) - # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty - if hasattr(self, "bias") and self.bias is not None and self.bias.numel() > 0: - self._register_dynamic_attribute("bias", self._get_bias) - self._register_dynamic_attribute("layer_norm_weight", self._get_ln_param) - if hasattr(self, "layer_norm_bias") and self.layer_norm_bias is not None: - self._register_dynamic_attribute("layer_norm_bias", self._get_ln_param) - - def _get_output_size_indices(self) -> torch.LongTensor: - """Reuses QKV output indexing logic from _DynamicQKVColumnParallelLinear.""" - nheads_hp: NumAttentionHeadsHp = self.get_hparam("num_attention_heads") - nquery_groups = self.config.num_query_groups - max_nheads_per_group = nheads_hp.max // nquery_groups - nheads_per_group = nheads_hp.num_heads_per_group - qkv_heads_per_group = max_nheads_per_group + 2 - - if nheads_hp._slice_order is None and nheads_hp.active == nheads_hp.max: - return slice((max_nheads_per_group + 2) * nquery_groups * self.config.kv_channels) - - q_head_indices = nheads_hp.active_slice - assert isinstance(q_head_indices, torch.LongTensor) - q_head_indices_per_group = q_head_indices.view(nquery_groups, nheads_per_group).cpu() - group_ids = q_head_indices_per_group // max_nheads_per_group - local_pos_in_attn = q_head_indices_per_group - group_ids * max_nheads_per_group - q_head_indices = group_ids * qkv_heads_per_group + local_pos_in_attn - kv_head_indices = ( - torch.arange(nquery_groups)[:, None] * qkv_heads_per_group - + torch.arange(max_nheads_per_group, qkv_heads_per_group)[None, :] - ) - selected_qkv_heads = torch.cat([q_head_indices, kv_head_indices], dim=1).flatten() - selected_indices = expand_head_indices(selected_qkv_heads, self.config.kv_channels) - return selected_indices.cpu() - @staticmethod def _get_weight( mod: "_DynamicTEQKVLayerNormColumnParallelLinear", weight: torch.Tensor ) -> torch.Tensor: return get_sliced_tensor_by_slices( - weight, - [mod._get_output_size_indices(), mod.get_hparam("input_size").active_slice], + weight, [mod._get_output_size_indices(), mod.get_hparam("input_size").active_slice] ) @staticmethod @@ -571,6 +449,7 @@ def _get_ln_param( return get_sliced_tensor(mod, val, "input_size") +# NOTE: We provide a parent class since we do not register to DMRegistry. class _DynamicTEProjRowParallelLinear(DynamicModule, TERowParallelLinear): """TE's RowParallelLinear for output projection with dynamic attributes.""" @@ -578,8 +457,7 @@ def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: Trace self._register_hparam("output_size", hidden_size) self._register_hparam("num_attention_heads", num_attention_heads) self._register_dynamic_attribute( - "in_features", - lambda mod, val: num_attention_heads.active * mod.config.kv_channels, + "in_features", lambda mod, val: num_attention_heads.active * mod.config.kv_channels ) self._register_dynamic_attribute("weight", self._get_weight) # TE stores a zero-length tensor (not None) when bias=False; only register if non-empty @@ -587,20 +465,21 @@ def _setup(self, *, num_attention_heads: NumAttentionHeadsHp, hidden_size: Trace self._register_dynamic_attribute("bias", self._get_bias) def _get_input_size_indices(self) -> torch.LongTensor: - """Reuses Proj input indexing logic from _DynamicProjRowParallelLinear.""" + """Get the indices of the input size based on sorted + pruned heads and query groups.""" nheads_hp = self.get_hparam("num_attention_heads") if nheads_hp._slice_order is None and nheads_hp.active == nheads_hp.max: return slice(nheads_hp.max * self.config.kv_channels) + selected_attn_heads = nheads_hp.active_slice assert isinstance(selected_attn_heads, torch.LongTensor) selected_indices = expand_head_indices(selected_attn_heads, self.config.kv_channels) + return selected_indices.cpu() @staticmethod def _get_weight(mod: "_DynamicTEProjRowParallelLinear", weight: torch.Tensor) -> torch.Tensor: return get_sliced_tensor_by_slices( - weight, - [mod.get_hparam("output_size").active_slice, mod._get_input_size_indices()], + weight, [mod.get_hparam("output_size").active_slice, mod._get_input_size_indices()] ) @staticmethod @@ -630,60 +509,29 @@ def _setup(self, *, hidden_size: TracedHp): "num_attention_heads_per_partition", lambda mod, val: self.num_attention_heads ) - # Convert the Dot Product Attention to dynamic module - if isinstance(self.core_attention, DotProductAttention): - _DynamicDotProductAttention: DynamicModule = type( # noqa: N806 - "_DynamicDotProductAttention", - (DynamicModule, DotProductAttention), - {"_setup": lambda self: None}, - ) - - _DynamicDotProductAttention.convert(self.core_attention) - self.core_attention._register_dynamic_attribute( - "hidden_size_per_partition", - lambda mod, val: self.config.kv_channels * self.num_attention_heads_per_partition, - ) - self.core_attention._register_dynamic_attribute( - "num_attention_heads_per_partition", - lambda mod, val: self.num_attention_heads_per_partition, - ) - else: - assert isinstance(self.core_attention, TEDotProductAttention) - - _DynamicTEDotProductAttention: DynamicModule = type( # noqa: N806 - "_DynamicTEDotProductAttention", - (DynamicModule, TEDotProductAttention), - {"_setup": lambda self: None}, - ) - - _DynamicTEDotProductAttention.convert(self.core_attention) - self.core_attention._register_dynamic_attribute( - "num_attention_heads", lambda mod, val: self.num_attention_heads_per_partition - ) + # Convert the TEDotProductAttention to dynamic module + assert isinstance(self.core_attention, TEDotProductAttention) + _DynamicTEDotProductAttention: DynamicModule = type( # noqa: N806 + "_DynamicTEDotProductAttention", + (DynamicModule, TEDotProductAttention), + {"_setup": lambda self: None}, + ) + _DynamicTEDotProductAttention.convert(self.core_attention) + self.core_attention._register_dynamic_attribute( + "num_attention_heads", lambda mod, val: self.num_attention_heads_per_partition + ) # Convert the fused qkv and output projection linear layer to dynamic module - if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): - _DynamicTEQKVLayerNormColumnParallelLinear.convert( - self.linear_qkv, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - ) - _DynamicTEProjRowParallelLinear.convert( - self.linear_proj, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - ) - else: - _DynamicQKVColumnParallelLinear.convert( - self.linear_qkv, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - ) - _DynamicProjRowParallelLinear.convert( - self.linear_proj, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - ) + _DynamicTEQKVLayerNormColumnParallelLinear.convert( + self.linear_qkv, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) + _DynamicTEProjRowParallelLinear.convert( + self.linear_proj, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + ) def export(self) -> torch.nn.Module: """Export the dynamic module to a torch.nn.Module.""" @@ -749,85 +597,6 @@ def export(self) -> torch.nn.Module: return super().export() -class _DynamicGroupedMLP(DynamicModule): - """A GroupedMLP with dynamic hyperparams for pruning packed expert weights. - - GroupedMLP packs all expert weights into weight1 and weight2: - weight1: [hidden_size, num_experts * ffn_out_per_expert] (ffn_out = ffn * gate_factor) - weight2: [num_experts * ffn_per_expert, hidden_size] - Will be registered to DMRegistry if GroupedMLP is available. - """ - - def _setup(self, *, hidden_size: TracedHp): - num_moe_experts = TracedHp(list(range(1, self.num_local_experts + 1))) - self._register_hparam("num_local_experts", num_moe_experts) - - ffn = self.config.moe_ffn_hidden_size - moe_ffn_hidden_size = TracedHp(list(range(1, ffn + 1))) - self._register_hparam("moe_ffn_hidden_size", moe_ffn_hidden_size) - self._register_hparam("hidden_size", hidden_size) - - self._register_dynamic_attribute("weight1", self._get_weight1) - self._register_dynamic_attribute("weight2", self._get_weight2) - - def _get_expert_ffn_col_indices(self, gated: bool) -> torch.LongTensor: - """Build column indices for weight1 (or row indices for weight2 when gated=False).""" - num_experts_hp = self.get_hparam("num_local_experts") - ffn_hp = self.get_hparam("moe_ffn_hidden_size") - max_ffn = ffn_hp.max - - expert_slice = num_experts_hp.active_slice - ffn_slice = ffn_hp.active_slice - - if isinstance(expert_slice, slice): - active_experts = list(range(expert_slice.stop)) - else: - active_experts = expert_slice.tolist() - - if isinstance(ffn_slice, slice): - active_ffn = list(range(ffn_slice.stop)) - else: - active_ffn = ffn_slice.tolist() - - indices = [] - for ei in active_experts: - if gated: - gate_base = ei * max_ffn * 2 - up_base = gate_base + max_ffn - indices.extend(gate_base + fi for fi in active_ffn) - indices.extend(up_base + fi for fi in active_ffn) - else: - base = ei * max_ffn - indices.extend(base + fi for fi in active_ffn) - - return torch.LongTensor(indices) - - @staticmethod - def _get_weight1(mod: "_DynamicGroupedMLP", weight: torch.Tensor) -> torch.Tensor: - hidden_slice = mod.get_hparam("hidden_size").active_slice - col_indices = mod._get_expert_ffn_col_indices(gated=mod.config.gated_linear_unit) - return weight[hidden_slice][:, col_indices].contiguous() - - @staticmethod - def _get_weight2(mod: "_DynamicGroupedMLP", weight: torch.Tensor) -> torch.Tensor: - hidden_slice = mod.get_hparam("hidden_size").active_slice - row_indices = mod._get_expert_ffn_col_indices(gated=False) - return weight[row_indices][:, hidden_slice].contiguous() - - def modify(self, ffn_hidden_size_divisor: int = 1, **kwargs) -> None: - hp = self.get_hparam("moe_ffn_hidden_size") - choices = {int(make_divisible(c, ffn_hidden_size_divisor)) for c in hp.choices} # type: ignore[arg-type] - hp.choices = list(set(hp.choices) & choices | {hp.original}) - - def export(self) -> torch.nn.Module: - return super().export() - - -DMRegistry.register({GroupedMLP: "megatron.core.transformer.moe.experts.GroupedMLP"})( - _DynamicGroupedMLP -) - - @DMRegistry.register({MoELayer: "megatron.core.transformer.moe.moe_layer.MoELayer"}) class _DynamicMoELayer(DynamicModule): """A MoELayer with dynamic hyperparams.""" @@ -874,20 +643,14 @@ def modify( expert_hp.choices = list(set(expert_hp.choices) & choices | {expert_hp.original}) # Modify expert FFN hparam choices - if isinstance(self.experts, _DynamicGroupedMLP): - self.experts.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) - else: - for expert in self.experts.local_experts: - expert.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) + for expert in self.experts.local_experts: + expert.modify(ffn_hidden_size_divisor=ffn_hidden_size_divisor) if self.use_shared_expert: self.shared_experts.modify(ffn_hidden_size_divisor) def _export_reinit_token_dispatcher(self) -> None: """Reinitialize the token dispatcher after pruning.""" - if hasattr(moe_utils, "get_default_model_comm_pgs"): - model_comm_pgs = moe_utils.get_default_model_comm_pgs() - else: - model_comm_pgs = moe_utils.get_default_pg_collection() + model_comm_pgs = moe_utils.get_default_pg_collection() # NOTE: Update config.num_moe_experts for correct router initialization. self.config.num_moe_experts = self.num_moe_experts self.token_dispatcher = type(self.token_dispatcher)( @@ -904,9 +667,6 @@ def export(self) -> torch.nn.Module: if self.use_shared_expert: self.shared_experts.export() self._export_reinit_token_dispatcher() - # Update num_local_experts on experts module after export - if hasattr(self.experts, "num_local_experts"): - self.experts.num_local_experts = self.num_local_experts return super().export() @@ -917,23 +677,17 @@ def export(self) -> torch.nn.Module: class _DynamicTransformerLayer(DynamicModule): """A TransformerLayer layer with dynamic hyperparams.""" - @staticmethod - def _is_identity_op(module: nn.Module) -> bool: - """Check if the module is an IdentityOp (layernorm fused into linear in TE spec).""" - return isinstance(module, IdentityOp) - def _setup(self, *, hidden_size: TracedHp): """Setup the TransformerLayer dynamic module with global hidden_size hparam.""" - # Convert the layernorms, self-attention, and mlp/moe layers to dynamic modules + # Convert the self-attention and mlp/moe layers to dynamic modules # NOTE: Mamba stack layers have either Attention or MLP, not both unlike GPT models - # NOTE: In full TE spec, layernorms are IdentityOp (fused into linear layers) if isinstance(self.self_attention, SelfAttention): - if not self._is_identity_op(self.input_layernorm): - DMRegistry.convert(self.input_layernorm, num_features=hidden_size) DMRegistry.convert(self.self_attention, hidden_size=hidden_size) if isinstance(self.mlp, (MLP, MoELayer)): - if not self._is_identity_op(self.pre_mlp_layernorm): + # pre_mlp_layernorm is IdentityOp for dense MLP (fused into linear_fc1), + # but RMSNorm for MoETransformerLayer (separate from MoE experts) + if not isinstance(self.pre_mlp_layernorm, IdentityOp): DMRegistry.convert(self.pre_mlp_layernorm, num_features=hidden_size) if isinstance(self.mlp, MoELayer): setup_kwargs = {} @@ -959,11 +713,9 @@ def modify( def export(self): """Export the dynamic module to a torch.nn.Module.""" if isinstance(self.self_attention, SelfAttention): - if not self._is_identity_op(self.input_layernorm): - self.input_layernorm.export() self.self_attention.export() if isinstance(self.mlp, (MLP, MoELayer)): - if not self._is_identity_op(self.pre_mlp_layernorm): + if not isinstance(self.pre_mlp_layernorm, IdentityOp): self.pre_mlp_layernorm.export() self.mlp.export() return super().export() @@ -1228,10 +980,6 @@ def _setup(self, *, hidden_size: TracedHp): # Convert to dynamic module DMRegistry.convert(self.mixer, hidden_size=hidden_size) - # In TE spec, norm is IdentityOp (fused into mixer.in_proj) - if not _DynamicTransformerLayer._is_identity_op(self.norm): - DMRegistry.convert(self.norm, num_features=hidden_size) - def modify( self, *, @@ -1247,8 +995,6 @@ def modify( def export(self): """Export the dynamic module to a torch.nn.Module.""" self.mixer.export() - if not _DynamicTransformerLayer._is_identity_op(self.norm): - self.norm.export() return super().export() diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 729145953..1027a8d6e 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -53,7 +53,6 @@ from modelopt.torch.nas.plugins.megatron import ( HAS_MAMBA, SUPPORTED_MODELS, - _DynamicGroupedMLP, _DynamicMambaLayer, _DynamicMambaMixer, _DynamicMCoreLanguageModel, @@ -812,8 +811,6 @@ def __init__(self, model: DynamicModule): _register_mlp_importance(module, self) elif isinstance(module, _DynamicSequentialMLP): _register_sequential_mlp_importance(module, self) - elif isinstance(module, _DynamicGroupedMLP): - _register_grouped_mlp_importance(module, self) elif isinstance(module, _DynamicMambaMixer): _register_mamba_mixer_importance(module, self) @@ -977,7 +974,7 @@ def _estimate_hidden_size_importance(mod): return activations # Register hooks for all layers - # For TE spec, layernorms may be IdentityOp (fused into linear layers). + # For TE spec, layernorms may be IdentityOp (fused into column parallel linear layers). # Hooking on IdentityOp still works — it gives pre-layernorm activations. for layer in module.decoder.layers: if isinstance(layer, _DynamicTransformerLayer): @@ -1194,73 +1191,6 @@ def _estimate_expert_importance(mod): ) -def _register_grouped_mlp_importance( - module: _DynamicGroupedMLP, registry: ImportanceEstimatorRegistry -) -> None: - """Register importance estimators for GroupedMLP (MoE experts with grouped GEMM). - - Expert importance is computed from output L2 norms (same as SequentialMLP). - FFN importance is computed from weight2 row magnitudes as an approximation - since per-expert intermediate activations are not easily accessible in grouped GEMM. - """ - module._register_temp_attribute( - "_activations", - { - "expert_l2_scores": torch.zeros(module.num_local_experts), - "expert_sample_counts": torch.zeros(module.num_local_experts), - }, - ) - - def _expert_l2_imp_forward_hook(mod, module_inner, input, output): - """Track expert importance based on L2 norms of expert outputs.""" - tokens_per_expert_list = input[1].tolist() - output_local = output[0].to(torch.float32).detach() - output_local_list = torch.split(output_local, tokens_per_expert_list) - - for expert_idx, expert_output in enumerate(output_local_list): - if expert_output.numel() == 0: - l2_norm = 0.0 - else: - l2_norm = torch.linalg.vector_norm(expert_output, ord=2, dim=-1).sum().item() - mod._activations["expert_l2_scores"][expert_idx] += l2_norm - mod._activations["expert_sample_counts"][expert_idx] += tokens_per_expert_list[ - expert_idx - ] - - def _estimate_expert_importance(mod): - assert mod._activations["expert_sample_counts"].sum() > 0, ( - "No activations collected for importance estimation." - ) - return mod._activations["expert_l2_scores"] / ( - mod._activations["expert_sample_counts"] + 1e-8 - ) - - def _estimate_ffn_importance(mod): - """Approximate FFN importance from weight2 row magnitudes (averaged across experts).""" - weight2 = mod.weight2.data.to(torch.float32) - max_ffn = mod.get_hparam("moe_ffn_hidden_size").max - num_experts = mod.get_hparam("num_local_experts").max - per_expert_importance = weight2.view(num_experts, max_ffn, -1) - ffn_importance = torch.linalg.vector_norm(per_expert_importance, ord=2, dim=2) - return ffn_importance.mean(dim=0) - - registry.register_hook( - module, - partial(_expert_l2_imp_forward_hook, module), - hook_type="forward", - ) - registry.register_importance( - module, - "num_local_experts", - lambda: _estimate_expert_importance(module), - ) - registry.register_importance( - module, - "moe_ffn_hidden_size", - lambda: _estimate_ffn_importance(module), - ) - - def _register_mamba_mixer_importance( module: _DynamicMambaMixer, registry: ImportanceEstimatorRegistry ) -> None: diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index e3c75853d..5ccbe411a 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -152,7 +152,6 @@ def get_mcore_gpt_model( bf16: bool = True, use_te: bool = False, # MoE-specific parameters - moe_grouped_gemm: bool = False, moe_ffn_hidden_size: int | None = None, moe_shared_expert_intermediate_size: int | None = None, num_moe_experts: int | None = None, @@ -195,7 +194,6 @@ def squared_relu(x): pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, bf16=bf16, # MoE-specific parameters - moe_grouped_gemm=moe_grouped_gemm, moe_router_dtype=None, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, @@ -220,9 +218,6 @@ def squared_relu(x): transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, normalization=normalization, - moe_grouped_gemm=moe_grouped_gemm, - # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM - # use_te=use_te, ) else: assert HAS_TE, "Transformer Engine not installed" @@ -234,8 +229,6 @@ def squared_relu(x): else: transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=num_moe_experts, - moe_grouped_gemm=moe_grouped_gemm, - moe_use_legacy_grouped_gemm=moe_grouped_gemm, ) model = GPTModel( diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 22d308dec..158b6cafa 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -49,7 +49,6 @@ from modelopt.torch.utils.random import centroid SEED = 1234 -TE_SPEC = "transformer_engine" def _test_gpt_search_space( @@ -77,7 +76,7 @@ def _test_gpt_search_space( vocab_size=vocab_size, activation_func=activation_func, normalization=normalization, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", ).cuda() mtn.convert( @@ -174,7 +173,7 @@ def test_gpt_self_attention_head_sorting(distributed_setup_size_1): num_query_groups=2, ffn_hidden_size=16, activation_func="squared_relu", - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", ).cuda() model = mtn.convert(model, "mcore_minitron") @@ -199,6 +198,7 @@ def test_gpt_self_attention_head_sorting(distributed_setup_size_1): hp_num_attention_heads._get_importance = lambda: torch.tensor( [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] ) + # _estimate_head_ranking returns ranking as 1D tensor expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) hp_num_attention_heads.enforce_order(expected_ranking) @@ -256,7 +256,7 @@ def _test_gpt_moe_search_space(rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 027e00263..4be325cf7 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -44,7 +44,6 @@ from modelopt.torch.utils.random import centroid SEED = 1234 -TE_SPEC = "transformer_engine" def _test_mamba_search_space(rank, size): @@ -73,7 +72,7 @@ def _test_mamba_search_space(rank, size): mamba_num_groups=mamba_num_groups, max_sequence_length=max_sequence_length, vocab_size=vocab_size, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", bf16=False, ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index d4e605b22..be2f055dc 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -36,7 +36,6 @@ ) SEED = 1234 -TE_SPEC = "transformer_engine" def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): @@ -65,7 +64,7 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func=activation_func, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", bf16=False, ).cuda() @@ -168,7 +167,7 @@ def _get_model(initialize_megatron=True): position_embedding_type=position_embedding_type, activation_func=activation_func, normalization=normalization, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, ).cuda() @@ -340,7 +339,7 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, @@ -417,7 +416,7 @@ def _get_model(initialize_megatron=True): max_sequence_length=max_sequence_length, vocab_size=vocab_size, activation_func="squared_relu", - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", num_moe_experts=num_moe_experts, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 23ef3580b..c27cace7b 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -42,7 +42,6 @@ ) SEED = 1234 -TE_SPEC = "transformer_engine" def _test_mcore_mamba_parameter_sorting(rank, size): @@ -72,7 +71,7 @@ def _test_mcore_mamba_parameter_sorting(rank, size): mamba_num_groups=mamba_num_groups, max_sequence_length=max_sequence_length, vocab_size=vocab_size, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", bf16=False, ).cuda() @@ -153,7 +152,7 @@ def _get_model(initialize_megatron=True): moe_shared_expert_intermediate_size=ffn_hidden_size, num_moe_experts=num_moe_experts, vocab_size=vocab_size, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", bf16=False, ).cuda() return model @@ -272,11 +271,12 @@ def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size): moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, num_moe_experts=num_moe_experts, vocab_size=vocab_size, - transformer_impl=TE_SPEC, + transformer_impl="transformer_engine", bf16=False, ).cuda() param_count = get_mcore_param_count(model) + assert param_count == 14984.0, param_count def forward_loop(m): for _ in range(2): @@ -307,27 +307,57 @@ def score_func(m): "top_k": 10, } + # Capture stdout to assert search space output stdout_capture = io.StringIO() with contextlib.redirect_stdout(stdout_capture): model, searcher_state = prune_minitron(model, constraints, config, channel_divisor) + # Assert expected search space output is present captured_output = stdout_capture.getvalue() print(captured_output) if rank == 0: - assert "Search space for num_layers:" in captured_output - assert "Search space for hidden_size:" in captured_output - assert "Search space for mamba_num_heads:" in captured_output - assert "Search space for mamba_head_dim:" in captured_output - assert "Search space for num_moe_experts:" in captured_output - assert "Search space for moe_ffn_hidden_size:" in captured_output - - assert get_mcore_param_count(model) <= param_count * 0.7 + assert "Search space for num_layers: [3, 4]" in captured_output + assert "Search space for hidden_size: [12, 16]" in captured_output + assert "Search space for mamba_num_heads: [6, 8]" in captured_output + assert "Search space for mamba_head_dim: [12, 16]" in captured_output + assert "Search space for num_moe_experts: [5, 6, 7, 8]" in captured_output + assert "Search space for moe_ffn_hidden_size: [12, 16]" in captured_output + assert "Total search space in consideration: 512" in captured_output + + # NOTE: Slight variation in layer ordering for MoE / Attention / MLP depending on PP configuration + # This affects param counts when num_layers is pruned + sorted_layers = [ + layer + for layer, _ in sorted( + searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True + ) + ] + # fmt: off + if sorted_layers == [1, 4, 3, 2]: # PP 1/2 + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10482.0, 112.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 24}, 10472.0, 118.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 8, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, 10400.0, 112.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, 10388.0, 123.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10376.0, 114.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10370.0, 117.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10338.0, 123.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10292.0, 119.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10268.0, 125.0], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, 10242.0, 113.0], # noqa: E501 + ] + else: + raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers=}") + # fmt: on + + assert get_mcore_param_count(model) == 10268.0 top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]] assert len(top_k) == 10 - for candidate in top_k: - assert candidate.params <= constraints["params"] - assert candidate.score is not None + for actual, (ss_config, params, score) in zip(top_k, expected_top_k): + assert actual.ss_config == ss_config, (actual.ss_config, ss_config) + assert actual.params == params, (actual.params, params) + assert actual.score == score, (actual.score, score) @pytest.mark.skipif( From e86e66b6bb6c6ff9143009987390ba50b6dfd5a5 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:50:12 -0700 Subject: [PATCH 3/9] minor cleanup Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 2 + modelopt/torch/nas/plugins/megatron.py | 43 ++++++++++++++++++- .../torch/nas/plugins/transformer_engine.py | 39 ----------------- modelopt/torch/utils/plugins/mbridge.py | 15 +++---- tests/_test_utils/torch/megatron/models.py | 20 +++------ .../test_megatron_mamba_dynamic_modules.py | 4 +- 6 files changed, 57 insertions(+), 66 deletions(-) delete mode 100644 modelopt/torch/nas/plugins/transformer_engine.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 483f1dcb8..e955aee1f 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -33,6 +33,8 @@ NVIDIA Model Optimizer Changelog - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. +- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. +- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that from user's perspective, this is only internal implementation improvement and does not affect the usage of the pruning workflow. **Misc** diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 5177981ba..d67936b0e 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -15,12 +15,14 @@ """Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba.""" +import copy import types from abc import ABC from collections.abc import Callable, Sequence import torch import torch.nn as nn +import transformer_engine as te from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, TEDotProductAttention, @@ -29,6 +31,7 @@ ) from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage from megatron.core.tensor_parallel.layers import ( ColumnParallelLinear, @@ -43,6 +46,7 @@ from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP +from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer from modelopt.torch.nas.modules import DynamicModuleList @@ -53,7 +57,7 @@ from modelopt.torch.utils import make_divisible from ..hparams.concat import build_concat_hp -from ..modules import _DynamicLayerNorm # noqa: F401 (re-exported for tests) +from ..modules import _DynamicLayerNorm from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices from ..registry import DMRegistry from ..traced_hp import TracedHp @@ -63,6 +67,9 @@ try: import mamba_ssm # noqa: F401 from megatron.core.models.mamba import MambaModel + from megatron.core.models.mamba.mamba_layer_specs import ( + mamba_stack_spec as _te_mamba_stack_spec, + ) from megatron.core.ssm.mamba_layer import MambaLayer from megatron.core.ssm.mamba_mixer import ExtendedRMSNorm, MambaMixer @@ -72,7 +79,23 @@ except ImportError: HAS_MAMBA = False -__all__ = [] +__all__ = ["get_te_mamba_stack_spec"] + + +# TODO: Maybe upstream this to Megatron-LM +def get_te_mamba_stack_spec(moe_grouped_gemm: bool = False) -> ModuleSpec: + """Return the TE Mamba stack spec.""" + assert HAS_MAMBA + if moe_grouped_gemm: + return _te_mamba_stack_spec + + # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE. + # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency). + te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec) + te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec( + use_te=True, num_experts=8, moe_grouped_gemm=False + ) + return te_mamba_stack_spec # Local Parallel Linear DynamicModules ########################################################################## @@ -242,6 +265,22 @@ def export(self) -> torch.nn.Module: return super().export() +# TE Normalization DynamicModule ################################################################### +@DMRegistry.register( + {te.pytorch.LayerNorm: "te.pytorch.LayerNorm", te.pytorch.RMSNorm: "te.pytorch.RMSNorm"} +) +class _DynamicTENorm(_DynamicLayerNorm): + """A ``te.pytorch.{Layer/RMS}Norm`` layer with dynamic hyperparams.""" + + def _setup(self, *, num_features: TracedHp): + """Setup the TENorm dynamic module with pre-defined num_features hparam.""" + self._register_hparam("num_features", num_features) + # register dynamic attributes + self._register_dynamic_attribute("weight", self._cut_to_active_features) + if hasattr(self, "bias"): # Bias is not present in RMSNorm + self._register_dynamic_attribute("bias", self._cut_to_active_features) + + # MLP DynamicModule ################################################################################ @DMRegistry.register( { diff --git a/modelopt/torch/nas/plugins/transformer_engine.py b/modelopt/torch/nas/plugins/transformer_engine.py deleted file mode 100644 index 3392c0858..000000000 --- a/modelopt/torch/nas/plugins/transformer_engine.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Plugin to add NAS support for Transformer Engine modules.""" - -import transformer_engine as te - -from ..modules import _DynamicLayerNorm -from ..registry import DMRegistry -from ..traced_hp import TracedHp - -__all__ = ["_DynamicTENorm"] - - -@DMRegistry.register( - {te.pytorch.LayerNorm: "te.pytorch.LayerNorm", te.pytorch.RMSNorm: "te.pytorch.RMSNorm"} -) -class _DynamicTENorm(_DynamicLayerNorm): - """A ``te.pytorch.{Layer/RMS}Norm`` layer with dynamic hyperparams.""" - - def _setup(self, *, num_features: TracedHp): - """Setup the TENorm dynamic module with pre-defined num_features hparam.""" - self._register_hparam("num_features", num_features) - # register dynamic attributes - self._register_dynamic_attribute("weight", self._cut_to_active_features) - if hasattr(self, "bias"): # Bias is not present in RMSNorm - self._register_dynamic_attribute("bias", self._cut_to_active_features) diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index 94cdf87cf..04293704c 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -23,12 +23,9 @@ from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig from megatron.bridge.data.loaders import setup_data_iterators from megatron.bridge.data.utils import get_dataset_provider -from megatron.bridge.models.gpt_provider import GPTModelProvider, modelopt_transformer_layer_spec +from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.utils import is_safe_repo -from megatron.bridge.models.mamba.mamba_provider import ( - MambaModelProvider, - modelopt_mamba_stack_spec, -) +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider from megatron.bridge.training.config import ( CheckpointConfig, @@ -50,6 +47,7 @@ from megatron.core.utils import unwrap_model from transformers import AutoTokenizer +from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0 __all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"] @@ -94,12 +92,9 @@ def load_mbridge_model_from_hf( assert hasattr(provider, key), f"{type(provider)} does not have attribute {key}" setattr(provider, key, value) - print_rank_0("Setting ModelOpt spec for model provider") if isinstance(provider, MambaModelProvider): - provider.mamba_stack_spec = modelopt_mamba_stack_spec - else: - provider.transformer_layer_spec = modelopt_transformer_layer_spec - + # disable moe_grouped_gemm in default TE spec until its supported + provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False) provider.finalize() if init_model_parallel: provider.initialize_model_parallel(seed=0) diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 5ccbe411a..e20f8b793 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from warnings import warn import torch @@ -25,7 +24,6 @@ get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, ) -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -33,6 +31,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf +from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec try: from megatron.core.extensions.transformer_engine import TENorm @@ -44,19 +43,9 @@ HAS_TE = False try: - from megatron.core.models.mamba.mamba_layer_specs import ( - mamba_stack_spec as _te_mamba_stack_spec, - ) from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec from megatron.core.ssm.mamba_layer import MambaLayer # noqa: F401 - # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE. - # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency). - te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec) - te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec( - use_te=True, num_experts=8, moe_grouped_gemm=False - ) - HAS_MAMBA = True except ImportError as e: warn(f"Mamba not installed: {e}") @@ -152,6 +141,7 @@ def get_mcore_gpt_model( bf16: bool = True, use_te: bool = False, # MoE-specific parameters + moe_grouped_gemm: bool = False, moe_ffn_hidden_size: int | None = None, moe_shared_expert_intermediate_size: int | None = None, num_moe_experts: int | None = None, @@ -195,6 +185,7 @@ def squared_relu(x): bf16=bf16, # MoE-specific parameters moe_router_dtype=None, + moe_grouped_gemm=moe_grouped_gemm, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, moe_router_enable_expert_bias=True, @@ -217,6 +208,7 @@ def squared_relu(x): assert HAS_APEX, "Apex not installed" transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, normalization=normalization, ) else: @@ -320,6 +312,7 @@ def get_mcore_mamba_hybrid_model( mamba_num_groups: int = 2, # MoE-specific parameters skip_moe: bool = False, + moe_grouped_gemm: bool = False, moe_ffn_hidden_size: int | None = 64, moe_shared_expert_intermediate_size: int | None = 32, num_moe_experts: int | None = 8, @@ -353,6 +346,7 @@ def get_mcore_mamba_hybrid_model( mamba_head_dim=mamba_head_dim, mamba_num_groups=mamba_num_groups, num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, add_bias_linear=False, @@ -391,7 +385,7 @@ def get_mcore_mamba_hybrid_model( print(f"Using `{hybrid_override_pattern=}` for building MambaModel") if transformer_impl == "transformer_engine": - mamba_spec = te_mamba_stack_spec + mamba_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm) else: mamba_spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True) diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 4be325cf7..db8b9e10b 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -31,11 +31,11 @@ _DynamicColumnParallelLinear, _DynamicEmbedding, _DynamicExtendedRMSNorm, - _DynamicLayerNorm, _DynamicMambaLayer, _DynamicMambaMixer, _DynamicMCoreLanguageModel, _DynamicTELayerNormColumnParallelLinear, + _DynamicTENorm, _DynamicTERowParallelLinear, ) from modelopt.torch.nas.traced_hp import TracedHp @@ -104,7 +104,7 @@ def _test_mamba_search_space(rank, size): if layer.mixer.rmsnorm: assert isinstance(layer.mixer.norm, _DynamicExtendedRMSNorm) if is_pipeline_last_stage(): - assert isinstance(model.decoder.final_norm, _DynamicLayerNorm) + assert isinstance(model.decoder.final_norm, _DynamicTENorm) assert isinstance(model.output_layer, _DynamicColumnParallelLinear) # NOTE: `search_space_size` does not reduce across TP/PP groups From f9cb2d60551d5eaa3121624b4c7d35ab61f98940 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:59:57 -0700 Subject: [PATCH 4/9] Fix TELayernorm importance hook Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 2 +- .../torch/prune/plugins/mcore_minitron.py | 90 ++++++++++++++----- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e955aee1f..389a943c8 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -34,7 +34,7 @@ NVIDIA Model Optimizer Changelog - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. - Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. -- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that from user's perspective, this is only internal implementation improvement and does not affect the usage of the pruning workflow. +- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. **Misc** diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 1027a8d6e..1c6ea73a4 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -34,6 +34,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( @@ -865,6 +866,11 @@ def cleanup(self) -> None: handle.remove() self._hooks.clear() + # Unpatch return_layernorm_output on fused TELayerNormColumnParallelLinear modules + for m in self.model.modules(): + if isinstance(m, TELayerNormColumnParallelLinear): + m.return_layernorm_output = False + def get_layer_scores(self) -> dict[int, torch.Tensor]: """Get the layer scores (1-indexed) from the model. @@ -941,25 +947,48 @@ def _register_hidden_size_importance( """Register importance estimators for Language Model (GPT/Mamba) modules.""" module._register_temp_attribute("_activations", {}) - def _emb_layernorm_forward_hook(mod, module_inner, input, output): - """Hook to collect activations for importance estimation. + def _collect_activations(mod, module_id, activations_tensor): + """Accumulate activation importance scores for a given module.""" + activations_tensor = activations_tensor.to(torch.float32) + activations = activations_tensor.abs().mean(dim=0) # [batch_size, hidden_size] + activations = activations.pow(2).sum(dim=0) + if module_id not in mod._activations: + mod._activations[module_id] = activations + else: + mod._activations[module_id] += ( + activations # aggregate sum instead of mean of scores for simplicity + ) - Activations are computed as mean over seq_len and then squared and summed over batch_size. - Later we take the square root of the sum to get the L2 norm. + def _fused_ln_linear_forward_hook(mod, module_inner, input, output): + """Hook on TELayerNormColumnParallelLinear with return_layernorm_output=True. + + Extracts the exact layernorm output from TE's fused kernel and restores + the normal return format so downstream code is not affected. """ + # Output format with return_layernorm_output=True: + # te_return_bias=True: MCore returns (linear_out, bias, ln_out) + # te_return_bias=False: MCore returns ((linear_out, ln_out), None) + if module_inner.te_return_bias: + linear_out, bias, ln_out = output + fixed_output = (linear_out, bias) + else: + (linear_out, ln_out), bias = output + fixed_output = (linear_out, bias) + + # Gather over all TP regions + # NOTE: This is not used at the moment since we restrict to TP=1 + ln_out = gather_from_tensor_model_parallel_region(ln_out).detach() + _collect_activations(mod, id(module_inner), ln_out) + + # Return the normal output format so downstream code (e.g. SelfAttention) is not affected + return fixed_output + + def _layernorm_forward_hook(mod, module_inner, input, output): + """Hook on separate layernorm modules (e.g. TENorm for MoE pre_mlp_layernorm).""" # Gather output [seq_len, batch_size, hidden_size] over all TP regions # NOTE: This is not used at the moment since we restrict to TP=1 output = gather_from_tensor_model_parallel_region(output).detach() - - output = output.to(torch.float32) # use full precision to avoid overflow - activations = output.abs().mean(dim=0) # [batch_size, hidden_size] - activations = activations.pow(2).sum(dim=0) - if id(module_inner) not in mod._activations: - mod._activations[id(module_inner)] = activations - else: - mod._activations[id(module_inner)] += ( - activations # aggregate sum instead of mean of scores for simplicity - ) + _collect_activations(mod, id(module_inner), output) def _estimate_hidden_size_importance(mod): """Return the activation magnitude-based importance of the hidden_size.""" @@ -973,27 +1002,44 @@ def _estimate_hidden_size_importance(mod): torch.distributed.all_reduce(activations, op=torch.distributed.ReduceOp.SUM) return activations - # Register hooks for all layers - # For TE spec, layernorms may be IdentityOp (fused into column parallel linear layers). - # Hooking on IdentityOp still works — it gives pre-layernorm activations. + # Register hooks to collect post-layernorm activations for hidden_size importance. + # Layernorms are fused into TELayerNormColumnParallelLinear. We temporarily + # patch return_layernorm_output=True so TE's fused kernel returns the layernorm output. + # For MoE layers, pre_mlp_layernorm is a separate TENorm — use a regular forward hook. + for m in module.modules(): + if isinstance(m, TELayerNormColumnParallelLinear): + m.return_layernorm_output = True + for layer in module.decoder.layers: if isinstance(layer, _DynamicTransformerLayer): if isinstance(layer.self_attention, _DynamicSelfAttention): + # input_layernorm is fused into self_attention.linear_qkv registry.register_hook( - layer.input_layernorm, - partial(_emb_layernorm_forward_hook, module), + layer.self_attention.linear_qkv, + partial(_fused_ln_linear_forward_hook, module), hook_type="forward", ) - if isinstance(layer.mlp, (_DynamicMLP, _DynamicSequentialMLP, _DynamicMoELayer)): + if isinstance(layer.mlp, _DynamicMoELayer): + # MoE layers have a separate pre_mlp_layernorm (TENorm, not IdentityOp) registry.register_hook( layer.pre_mlp_layernorm, - partial(_emb_layernorm_forward_hook, module), + partial(_layernorm_forward_hook, module), + hook_type="forward", + ) + elif isinstance(layer.mlp, _DynamicMLP): + # Dense MLP: pre_mlp_layernorm is fused into mlp.linear_fc1 + registry.register_hook( + layer.mlp.linear_fc1, + partial(_fused_ln_linear_forward_hook, module), hook_type="forward", ) elif isinstance(layer, _DynamicMambaLayer): + # Mamba norm is fused into mixer.in_proj registry.register_hook( - layer.norm, partial(_emb_layernorm_forward_hook, module), hook_type="forward" + layer.mixer.in_proj, + partial(_fused_ln_linear_forward_hook, module), + hook_type="forward", ) registry.register_importance( From 3a62d4c13329c97160535314c5124674c0c88567 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:25:23 -0700 Subject: [PATCH 5/9] Fix for Nemotron Nano on 26.02.01 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 12 +++++++++--- examples/megatron_bridge/README.md | 6 +++--- examples/megatron_bridge/prune_minitron.py | 4 ++-- modelopt/torch/utils/plugins/mbridge.py | 16 ++++++++++------ tests/_test_utils/torch/megatron/models.py | 1 + 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 389a943c8..16598837b 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,14 @@ NVIDIA Model Optimizer Changelog ================================ -0.43 (2026-03-xx) +0.44 (2026-05-xx) +^^^^^^^^^^^^^^^^^ + +**New Features** + +- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. + +0.43 (2026-04-09) ^^^^^^^^^^^^^^^^^ **Bug Fixes** @@ -34,14 +41,13 @@ NVIDIA Model Optimizer Changelog - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. - Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. -- Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. **Misc** - Migrated project metadata from ``setup.py`` to a fully declarative ``pyproject.toml``. - Enable experimental Python 3.13 wheel support and unit tests in CI/CD. -0.42 (2026-02-xx) +0.42 (2026-03-10) ^^^^^^^^^^^^^^^^^ **Bug Fixes** diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index db9b60090..3c2c0034c 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -16,9 +16,9 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br ## Pre-Requisites -Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. +Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02.01`) which has all the dependencies installed. -To get the latest ModelOpt features and examples scripts, mount your Model-Optimizer repo to the container. +To get the ModelOpt examples scripts, mount your Model-Optimizer repo to the container as follows: ```bash export MODELOPT_DIR=${PWD}/Model-Optimizer # or set to your local Model-Optimizer repository path if you have cloned it @@ -26,7 +26,7 @@ if [ ! -d "${MODELOPT_DIR}" ]; then git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} fi -export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02.01 docker run \ --gpus all \ --shm-size=16GB \ diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index deecf0f8d..662b9b8b9 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -241,7 +241,7 @@ def main(args: argparse.Namespace): }, init_model_parallel=True, ) - print_rank_0(f"\nPruning {unwrapped_model=}") + print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}") print_rank_0( f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" ) @@ -317,7 +317,7 @@ def score_func_mmlu(m): else "hybrid_layer_pattern" ) setattr(provider, hybrid_key, getattr(unwrapped_model, hybrid_key)) - print_rank_0(f"\nPruned {unwrapped_model=}") + print_rank_0(f"\nPruned model (showing PP rank0): {unwrapped_model}") print_rank_0( f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" ) diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index 04293704c..da8c773da 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -26,7 +26,6 @@ from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.utils import is_safe_repo from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider -from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider from megatron.bridge.training.config import ( CheckpointConfig, ConfigContainer, @@ -41,6 +40,7 @@ from megatron.bridge.training.state import GlobalState from megatron.bridge.training.tokenizers.config import TokenizerConfig from megatron.core.models.gpt import GPTModel +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer.module import MegatronModule @@ -92,9 +92,15 @@ def load_mbridge_model_from_hf( assert hasattr(provider, key), f"{type(provider)} does not have attribute {key}" setattr(provider, key, value) + # disable moe_grouped_gemm in default TE spec until its supported if isinstance(provider, MambaModelProvider): - # disable moe_grouped_gemm in default TE spec until its supported provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False) + else: + provider.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=provider.num_moe_experts, + moe_grouped_gemm=False, + qk_layernorm=provider.qk_layernorm, + ) provider.finalize() if init_model_parallel: provider.initialize_model_parallel(seed=0) @@ -174,9 +180,6 @@ def get_hf_mbridge_calibration_loop( global_batch_size = micro_batch_size num_iters = num_samples // global_batch_size - # NOTE: Issue with NemotronH tokenizer's len() hence using use_fast=True as a WAR - use_fast_tokenizer = isinstance(provider, NemotronHModelProvider) - cfg = ConfigContainer( model=provider, train=TrainingConfig( @@ -198,9 +201,10 @@ def get_hf_mbridge_calibration_loop( tokenizer=TokenizerConfig( tokenizer_type="HuggingFaceTokenizer", tokenizer_model=hf_model_name_or_path, + # NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR hf_tokenizer_kwargs={ "trust_remote_code": trust_remote_code, - "use_fast": use_fast_tokenizer, + "use_fast": tokenizer.is_fast, }, ), # Unused diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index e20f8b793..2d6fe2a6e 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -221,6 +221,7 @@ def squared_relu(x): else: transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, ) model = GPTModel( From f11d23e07c9d019c733ea580f4cb8872bb657f1a Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 12 Mar 2026 07:00:07 -0700 Subject: [PATCH 6/9] Fix NAS search for MoE models Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/nas/plugins/megatron.py | 17 +++++++++++++---- modelopt/torch/prune/plugins/mcore_minitron.py | 1 - 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index d67936b0e..85e471fea 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -664,10 +664,16 @@ def _setup(self, *, hidden_size: TracedHp): def forward(self, *args, **kwargs): """Forward pass for the MoE layer.""" - # Dont allow forward if model is sorted / trimmed unless exported (reinitializing token dispatcher correctly) - if isinstance(self, DynamicModule) and ( - self.get_hparam("num_moe_experts")._slice_order is not None - or self.get_hparam("num_moe_experts").active != self.get_hparam("num_moe_experts").max + # Dont allow forward if model is sorted / trimmed unless the token dispatcher has been + # reinitialized (via _export_reinit_token_dispatcher in _prune or export). + if ( + isinstance(self, DynamicModule) + and not getattr(self, "_token_dispatcher_reinitialized", False) + and ( + self.get_hparam("num_moe_experts")._slice_order is not None + or self.get_hparam("num_moe_experts").active + != self.get_hparam("num_moe_experts").max + ) ): raise RuntimeError("Only run forward after exporting the pruned model") return super().forward(*args, **kwargs) @@ -699,6 +705,9 @@ def _export_reinit_token_dispatcher(self) -> None: if self.use_shared_expert and self.shared_expert_overlap: self.token_dispatcher.set_shared_experts(self.shared_experts) + # Allow forward after token dispatcher reinitialization + self._token_dispatcher_reinitialized = True + def export(self) -> torch.nn.Module: """Export the dynamic module to a standard MoELayer.""" self.router.export() diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 1c6ea73a4..35a3b4933 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -385,7 +385,6 @@ def _prune(self, export_config: dict, prune_depth: bool = True) -> None: for m in self.model.modules(): if isinstance(m, _DynamicMoELayer): m._export_reinit_token_dispatcher() - break def search_best_arch_by_params(self) -> dict: """Search for the best architecture based on the given parameters constraints. From c2699e9d852e34aaab847b0387df2148ef9d5652 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 12 Mar 2026 10:50:09 -0700 Subject: [PATCH 7/9] Only save local activations in per-rank ckpt Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../torch/prune/plugins/mcore_minitron.py | 55 +++++++++---------- .../test_mcore_gpt_minitron_pruning.py | 2 +- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 35a3b4933..b40f97a2d 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -183,7 +183,7 @@ class MCoreMinitronSearcher(BaseSearcher): - `top_k`: Number of candidates to consider for score_func validation (default: 10). """ - activations_per_rank: list[dict[str, torch.Tensor]] + local_activations: dict[str, torch.Tensor] layer_scores: dict[int, torch.Tensor] sorted_layers: list[int] | None # 1-indexed sorted list of layer numbers # Dict from params constraint to list of tuples (ss_config, params, score) @@ -208,7 +208,7 @@ def default_search_config(self) -> SearchConfig: def default_state_dict(self) -> SearchStateDict: """Return default state dict for importance scores and activations from forward loop.""" return { - "activations_per_rank": [], + "local_activations": {}, "layer_scores": {}, "sorted_layers": None, "top_k_candidates_per_constraint": {}, @@ -274,8 +274,10 @@ def before_search(self) -> None: def run_search(self) -> None: """Run forward loop to collect activations, sort parameters, and prune the model.""" registry = ImportanceEstimatorRegistry(self.model) - if self.layer_scores and self.activations_per_rank: # Available from checkpoint - registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores) + if self.local_activations and self.layer_scores: # Available from per-rank checkpoint + registry.set_local_activations_and_layer_scores( + self.local_activations, self.layer_scores + ) elif not self.config["skip_sorting"]: assert self.forward_loop is not None is_training = self.model.training @@ -285,8 +287,8 @@ def run_search(self) -> None: self.model.train(is_training) # Store activations and layer scores for re-pruning with different export configs - self.activations_per_rank, self.layer_scores = ( - registry.get_activations_and_layer_scores() + self.local_activations, self.layer_scores = ( + registry.get_local_activations_and_layer_scores() ) self.save_search_checkpoint(verbose=True) @@ -898,45 +900,38 @@ def get_layer_scores(self) -> dict[int, torch.Tensor]: return layer_scores - def get_activations_and_layer_scores( + def get_local_activations_and_layer_scores( self, - ) -> tuple[list[dict[str, torch.Tensor]], dict[int, torch.Tensor]]: - """Get the per-rank activations and layer scores from the model.""" - local_activations = {} - for n, m in self.model.named_modules(): - if hasattr(m, "_activations"): - local_activations[n] = m._activations - activations_per_rank = dist.allgather( - local_activations, group=get_pipeline_model_parallel_group() - ) - assert len(activations_per_rank) == get_pipeline_model_parallel_world_size() + ) -> tuple[dict[str, torch.Tensor], dict[int, torch.Tensor]]: + """Get this rank's local activations and global layer scores from the model. + Each rank saves its own activations to its per-rank checkpoint file (no allgather needed). + Layer scores are gathered across all PP ranks to produce a global ranking. + """ + local_activations = { + n: m._activations for n, m in self.model.named_modules() if hasattr(m, "_activations") + } layer_scores = self.get_layer_scores() - return activations_per_rank, layer_scores + return local_activations, layer_scores - def set_activations_and_layer_scores( + def set_local_activations_and_layer_scores( self, - activations_per_rank: list[dict[str, torch.Tensor]], + local_activations: dict[str, torch.Tensor], layer_scores: dict[int, torch.Tensor], ) -> None: - """Set the pre-computed layer_scores and per-rank activations instead of running forward. + """Set the pre-computed layer_scores and local activations instead of running forward. Args: - activations_per_rank: List of dicts from module name to activations. Should match PP size. - layer_scores: Dict from layer_number (1-indexed) to score. + local_activations: Dict from module name to activations for this rank. + layer_scores: Dict from layer_number (1-indexed) to score (global across all PP ranks). """ - print_rank_0("Loading activations and scores per rank from checkpoint...") - rank = get_pipeline_model_parallel_rank() - pp_size = get_pipeline_model_parallel_world_size() - assert len(activations_per_rank) == pp_size, ( - f"Expected same PP size for stored pruning scores ({len(activations_per_rank)}) as current ({pp_size})!" - ) + print_rank_0("Loading activations and scores from per-rank checkpoint...") for layer in self.model.decoder.layers: layer._scores = layer_scores[layer.layer_number] for n, m in self.model.named_modules(): if hasattr(m, "_activations"): - m._activations = activations_per_rank[rank][n] + m._activations = local_activations[n] # Module-specific registration functions diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index be2f055dc..3dab58f5b 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -208,7 +208,7 @@ def forward_loop(m): model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor) if not skip_sorting: assert pruning_scores["layer_scores"] - assert pruning_scores["activations_per_rank"] + assert pruning_scores["local_activations"] # Assert weights are pruned correctly for layer in model.decoder.layers: From 74184b8c4f4bdcfb47514163c35ab804c966298d Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 19 Mar 2026 08:25:25 -0700 Subject: [PATCH 8/9] Other fixes for MoE / Nemotron-3-Nano Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 5 ++- examples/megatron_bridge/prune_minitron.py | 3 +- modelopt/torch/nas/plugins/megatron.py | 1 + .../torch/prune/plugins/mcore_minitron.py | 33 ++++++++++++------- tests/_test_utils/torch/megatron/models.py | 2 ++ .../test_mcore_mamba_minitron_pruning.py | 2 +- 6 files changed, 32 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 16598837b..bd347506a 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,10 @@ NVIDIA Model Optimizer Changelog - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. +**Bug Fixes** + +- Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this. + 0.43 (2026-04-09) ^^^^^^^^^^^^^^^^^ @@ -40,7 +44,6 @@ NVIDIA Model Optimizer Changelog - Add support for Nemotron-3 (NemotronHForCausalLM) model quantization and support for NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. - Add support for block-granular RHT for non-power-of-2 dimensions. - Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes. -- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. **Misc** diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 662b9b8b9..28cefb0b8 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -264,10 +264,11 @@ def main(args: argparse.Namespace): } if args.prune_target_params is not None: # Restrict search space to a smaller set of candidates + # Allow more choices for MoE FFN as they are generally smaller # NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model. ss_config = mtp.mcore_minitron.get_mcore_minitron_config( hidden_size_divisor=256, - ffn_hidden_size_divisor=512, + ffn_hidden_size_divisor=256 if (provider.num_moe_experts or 0) > 0 else 512, mamba_head_dim_divisor=8, num_moe_experts_divisor=8, num_layers_divisor=2, diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index 85e471fea..19f836f38 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -92,6 +92,7 @@ def get_te_mamba_stack_spec(moe_grouped_gemm: bool = False) -> ModuleSpec: # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE. # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency). te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec) + # num_experts needs to be non-zero te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec( use_te=True, num_experts=8, moe_grouped_gemm=False ) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index b40f97a2d..14c7a8783 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -186,8 +186,8 @@ class MCoreMinitronSearcher(BaseSearcher): local_activations: dict[str, torch.Tensor] layer_scores: dict[int, torch.Tensor] sorted_layers: list[int] | None # 1-indexed sorted list of layer numbers - # Dict from params constraint to list of tuples (ss_config, params, score) - top_k_candidates_per_constraint: dict[float, list[CandidateSubnet]] + # Dict from params constraint to list of all CandidateSubnets fitting that constraint + all_candidates_per_constraint: dict[float, list[CandidateSubnet]] @property def default_search_config(self) -> SearchConfig: @@ -211,7 +211,7 @@ def default_state_dict(self) -> SearchStateDict: "local_activations": {}, "layer_scores": {}, "sorted_layers": None, - "top_k_candidates_per_constraint": {}, + "all_candidates_per_constraint": {}, } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: @@ -423,10 +423,7 @@ def search_best_arch_by_params(self) -> dict: ) # 2. Perform grid-search over the search space to find subnets fitting the constraints - if ( - max_params not in self.top_k_candidates_per_constraint - or len(self.top_k_candidates_per_constraint[max_params]) != top_k - ): + if max_params not in self.all_candidates_per_constraint: max_num_layers = self.model.get_hparam("num_layers").max search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( hp_choices, @@ -438,7 +435,7 @@ def search_best_arch_by_params(self) -> dict: selected = [] for ss_config in tqdm( search_space_configs, - desc=f"Finding top {top_k} (`config['top_k']`) candidates fitting the constraints...", + desc="Finding all candidates fitting the constraints...", disable=not dist.is_master(), ): self._prune(ss_config, prune_depth=False) @@ -451,13 +448,13 @@ def search_best_arch_by_params(self) -> dict: sample(self.model, sample_func=max) # reset to max subnet assert len(selected) > 0, "No subnets found fitting the constraints!" print_rank_0(f"Found {len(selected)} candidates fitting the constraints!") - self.top_k_candidates_per_constraint[max_params] = sorted( + self.all_candidates_per_constraint[max_params] = sorted( selected, key=lambda x: x.params, reverse=True - )[:top_k] + ) self.save_search_checkpoint(verbose=True) else: print_rank_0(f"\nUsing top {top_k} candidates from checkpoint") - top_k_candidates = self.top_k_candidates_per_constraint[max_params] + top_k_candidates = self.all_candidates_per_constraint[max_params][:top_k] print_rank_0(f"\n====================\nTop {top_k} candidates:") for candidate in top_k_candidates: @@ -477,6 +474,17 @@ def search_best_arch_by_params(self) -> dict: ) # 4. Validate top-k candidates using the score_func and return the best subnet + # WAR for Nemotron-3-Nano-30B-A3B-BF16. Disable expert bias during candidate eval to prevent in-place + # __setattr__ on dynamically-sliced buffers from corrupting their shape (128 -> 120 elements). + _routers_with_expert_bias = [] + for n, m in self.model.named_modules(): + if hasattr(m, "enable_expert_bias") and m.enable_expert_bias: + print( + f"Temporarily disabling expert bias for {n} on rank {dist.rank()} for candidate evaluation..." + ) + m.enable_expert_bias = False + _routers_with_expert_bias.append(m) + for candidate in tqdm( top_k_candidates, desc=f"Validating top {top_k} candidates on given score_func (this will take some time)...", @@ -501,6 +509,9 @@ def search_best_arch_by_params(self) -> dict: f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n" ) + for m in _routers_with_expert_bias: + m.enable_expert_bias = True + print_rank_0(f"\n====================\nTop {top_k} candidates with scores:") for candidate in top_k_candidates: print_rank_0( diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index 2d6fe2a6e..552463432 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -350,6 +350,8 @@ def get_mcore_mamba_hybrid_model( moe_grouped_gemm=moe_grouped_gemm, moe_ffn_hidden_size=moe_ffn_hidden_size, moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + moe_router_enable_expert_bias=True, + moe_router_score_function="sigmoid", add_bias_linear=False, pipeline_dtype=torch.bfloat16 if bf16 else torch.float32, bf16=bf16, diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index c27cace7b..dc58d30e6 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -352,7 +352,7 @@ def score_func(m): assert get_mcore_param_count(model) == 10268.0 - top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]] + top_k = searcher_state["all_candidates_per_constraint"][constraints["params"]][:10] assert len(top_k) == 10 for actual, (ss_config, params, score) in zip(top_k, expected_top_k): assert actual.ss_config == ss_config, (actual.ss_config, ss_config) From da2f5dc06ad2aef90f783484df6873def5052543 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:42:58 -0700 Subject: [PATCH 9/9] minor Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/megatron_bridge/prune_minitron.py | 1 + examples/pruning/README.md | 1 + modelopt/torch/utils/plugins/mbridge.py | 7 +++++-- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 28cefb0b8..f99c15ba8 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -240,6 +240,7 @@ def main(args: argparse.Namespace): "seq_length": args.seq_length, }, init_model_parallel=True, + moe_grouped_gemm=False, ) print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}") print_rank_0( diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 41ca6249d..67db15daf 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -64,6 +64,7 @@ bridge, provider, model, unwrapped_model, tokenizer = load_mbridge_model_from_hf "pipeline_dtype": torch.bfloat16, "seq_length": 4096, }, + moe_grouped_gemm=False, ) # Set up the forward loop to run on 1024 train samples diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index da8c773da..06c3466b4 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -59,6 +59,7 @@ def load_mbridge_model_from_hf( trust_remote_code: bool = False, provider_overrides: dict[str, Any] | None = None, init_model_parallel: bool = True, + moe_grouped_gemm: bool = True, ) -> tuple[ AutoBridge, GPTModelProvider | MambaModelProvider, @@ -73,6 +74,8 @@ def load_mbridge_model_from_hf( trust_remote_code: Whether to trust remote code. provider_overrides: Overrides for the provider. init_model_parallel: Whether to initialize model parallel. + moe_grouped_gemm: Whether to use grouped GEMM for MoE. + Pruning does not support grouped GEMM yet. Returns: A tuple of (bridge, provider, model, unwrapped_model, tokenizer). @@ -94,11 +97,11 @@ def load_mbridge_model_from_hf( # disable moe_grouped_gemm in default TE spec until its supported if isinstance(provider, MambaModelProvider): - provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False) + provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm) else: provider.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=provider.num_moe_experts, - moe_grouped_gemm=False, + moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=provider.qk_layernorm, ) provider.finalize()