
Commit 98d5291

minor cleanup
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent: cc584fe

6 files changed: 56 additions & 66 deletions


CHANGELOG.rst
Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ NVIDIA Model Optimizer Changelog
 - Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
 - Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
 - Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
+- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), removing the need for a custom ModelOpt spec. From the user's perspective this is an internal implementation improvement only and does not change how the pruning workflow is used.

 **Misc**
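For reference, the user-facing Minitron pruning entry point the changelog refers to stays the same. A rough sketch follows; the ``export_config`` values, the return-value handling, and the ``model``/``forward_loop`` placeholders are illustrative assumptions and are not part of this commit:

    import modelopt.torch.prune as mtp

    # `model` is an mcore GPT/Mamba-hybrid model and `forward_loop` runs a few
    # calibration batches through it (both defined elsewhere).
    pruned_model, _ = mtp.prune(
        model,
        mode="mcore_minitron",
        constraints={"export_config": {"hidden_size": 3072, "ffn_hidden_size": 9216}},
        dummy_input=None,  # not used by this mode
        config={"forward_loop": forward_loop},
    )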

modelopt/torch/nas/plugins/megatron.py
Lines changed: 41 additions & 2 deletions

@@ -15,12 +15,14 @@

 """Plugin to add NAS/Pruning support for megatron-core Language models like GPT and Mamba."""

+import copy
 import types
 from abc import ABC
 from collections.abc import Callable, Sequence

 import torch
 import torch.nn as nn
+import transformer_engine as te
 from megatron.core.extensions.transformer_engine import (
     TEColumnParallelLinear,
     TEDotProductAttention,
@@ -29,6 +31,7 @@
 )
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.gpt import GPTModel
+from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
 from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage
 from megatron.core.tensor_parallel.layers import (
     ColumnParallelLinear,
@@ -43,6 +46,7 @@
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.moe.router import TopKRouter
 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
+from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer

 from modelopt.torch.nas.modules import DynamicModuleList
@@ -53,7 +57,7 @@
 from modelopt.torch.utils import make_divisible

 from ..hparams.concat import build_concat_hp
-from ..modules import _DynamicLayerNorm  # noqa: F401 (re-exported for tests)
+from ..modules import _DynamicLayerNorm
 from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices
 from ..registry import DMRegistry
 from ..traced_hp import TracedHp
@@ -63,6 +67,9 @@
 try:
     import mamba_ssm  # noqa: F401
     from megatron.core.models.mamba import MambaModel
+    from megatron.core.models.mamba.mamba_layer_specs import (
+        mamba_stack_spec as _te_mamba_stack_spec,
+    )
     from megatron.core.ssm.mamba_layer import MambaLayer
     from megatron.core.ssm.mamba_mixer import ExtendedRMSNorm, MambaMixer

@@ -72,7 +79,23 @@
 except ImportError:
     HAS_MAMBA = False

-__all__ = []
+__all__ = ["get_te_mamba_stack_spec"]
+
+
+# TODO: Maybe upstream this to Megatron-LM
+def get_te_mamba_stack_spec(moe_grouped_gemm: bool = False) -> ModuleSpec:
+    """Return the TE Mamba stack spec."""
+    assert HAS_MAMBA
+    if moe_grouped_gemm:
+        return _te_mamba_stack_spec
+
+    # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE.
+    # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency).
+    te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec)
+    te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec(
+        use_te=True, num_experts=8, moe_grouped_gemm=False
+    )
+    return te_mamba_stack_spec


 # Local Parallel Linear DynamicModules ##########################################################################
@@ -242,6 +265,22 @@ def export(self) -> torch.nn.Module:
         return super().export()


+# TE Normalization DynamicModule ###################################################################
+@DMRegistry.register(
+    {te.pytorch.LayerNorm: "te.pytorch.LayerNorm", te.pytorch.RMSNorm: "te.pytorch.RMSNorm"}
+)
+class _DynamicTENorm(_DynamicLayerNorm):
+    """A ``te.pytorch.{Layer/RMS}Norm`` layer with dynamic hyperparams."""
+
+    def _setup(self, *, num_features: TracedHp):
+        """Setup the TENorm dynamic module with pre-defined num_features hparam."""
+        self._register_hparam("num_features", num_features)
+        # register dynamic attributes
+        self._register_dynamic_attribute("weight", self._cut_to_active_features)
+        if hasattr(self, "bias"):  # Bias is not present in RMSNorm
+            self._register_dynamic_attribute("bias", self._cut_to_active_features)
+
+
 # MLP DynamicModule ################################################################################
 @DMRegistry.register(
     {
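As an aside, a minimal usage sketch of the new helper, mirroring the call sites updated below in mbridge.py and the test fixtures; it assumes ``mamba_ssm`` and Transformer Engine are installed so that ``HAS_MAMBA`` is true:

    from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec

    # With moe_grouped_gemm=False the MoE experts use SequentialMLP built from
    # TE linear layers, so no grouped-GEMM kernels are required.
    mamba_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False)

    # The spec is then handed to the model provider / constructor, e.g.
    # provider.mamba_stack_spec = mamba_spec  (see the mbridge.py change below).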

modelopt/torch/nas/plugins/transformer_engine.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

modelopt/torch/utils/plugins/mbridge.py
Lines changed: 5 additions & 10 deletions

@@ -23,12 +23,9 @@
 from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
 from megatron.bridge.data.loaders import setup_data_iterators
 from megatron.bridge.data.utils import get_dataset_provider
-from megatron.bridge.models.gpt_provider import GPTModelProvider, modelopt_transformer_layer_spec
+from megatron.bridge.models.gpt_provider import GPTModelProvider
 from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
-from megatron.bridge.models.mamba.mamba_provider import (
-    MambaModelProvider,
-    modelopt_mamba_stack_spec,
-)
+from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
 from megatron.bridge.models.nemotronh.nemotron_h_provider import NemotronHModelProvider
 from megatron.bridge.training.config import (
     CheckpointConfig,
@@ -50,6 +47,7 @@
 from megatron.core.utils import unwrap_model
 from transformers import AutoTokenizer

+from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec
 from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0

 __all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"]
@@ -94,12 +92,9 @@ def load_mbridge_model_from_hf(
         assert hasattr(provider, key), f"{type(provider)} does not have attribute {key}"
         setattr(provider, key, value)

-    print_rank_0("Setting ModelOpt spec for model provider")
     if isinstance(provider, MambaModelProvider):
-        provider.mamba_stack_spec = modelopt_mamba_stack_spec
-    else:
-        provider.transformer_layer_spec = modelopt_transformer_layer_spec
-
+        # disable moe_grouped_gemm in default TE spec until its supported
+        provider.mamba_stack_spec = get_te_mamba_stack_spec(moe_grouped_gemm=False)
     provider.finalize()
     if init_model_parallel:
         provider.initialize_model_parallel(seed=0)

tests/_test_utils/torch/megatron/models.py
Lines changed: 7 additions & 13 deletions

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from warnings import warn

 import torch
@@ -25,14 +24,14 @@
     get_gpt_layer_local_spec,
     get_gpt_layer_with_transformer_engine_spec,
 )
-from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
 from megatron.core.models.mamba import MambaModel
 from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig

 from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf
+from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec

 try:
     from megatron.core.extensions.transformer_engine import TENorm
@@ -44,19 +43,9 @@
     HAS_TE = False

 try:
-    from megatron.core.models.mamba.mamba_layer_specs import (
-        mamba_stack_spec as _te_mamba_stack_spec,
-    )
     from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec
     from megatron.core.ssm.mamba_layer import MambaLayer  # noqa: F401

-    # The upstream TE mamba stack spec hardcodes TEGroupedMLP for MoE.
-    # Replace it with SequentialMLP (TE linear layers, no grouped gemm dependency).
-    te_mamba_stack_spec = copy.deepcopy(_te_mamba_stack_spec)
-    te_mamba_stack_spec.submodules.moe_layer.submodules.mlp = get_moe_module_spec(
-        use_te=True, num_experts=8, moe_grouped_gemm=False
-    )
-
     HAS_MAMBA = True
 except ImportError as e:
     warn(f"Mamba not installed: {e}")
@@ -152,6 +141,7 @@ def get_mcore_gpt_model(
     bf16: bool = True,
     use_te: bool = False,
     # MoE-specific parameters
+    moe_grouped_gemm: bool = False,
     moe_ffn_hidden_size: int | None = None,
     moe_shared_expert_intermediate_size: int | None = None,
     num_moe_experts: int | None = None,
@@ -195,6 +185,7 @@ def squared_relu(x):
         bf16=bf16,
         # MoE-specific parameters
         moe_router_dtype=None,
+        moe_grouped_gemm=moe_grouped_gemm,
         moe_ffn_hidden_size=moe_ffn_hidden_size,
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
         moe_router_enable_expert_bias=True,
@@ -217,6 +208,7 @@ def squared_relu(x):
         assert HAS_APEX, "Apex not installed"
         transformer_layer_spec = get_gpt_layer_local_spec(
             num_experts=num_moe_experts,
+            moe_grouped_gemm=moe_grouped_gemm,
             normalization=normalization,
         )
     else:
@@ -320,6 +312,7 @@ def get_mcore_mamba_hybrid_model(
     mamba_num_groups: int = 2,
     # MoE-specific parameters
     skip_moe: bool = False,
+    moe_grouped_gemm: bool = False,
     moe_ffn_hidden_size: int | None = 64,
     moe_shared_expert_intermediate_size: int | None = 32,
     num_moe_experts: int | None = 8,
@@ -353,6 +346,7 @@ def get_mcore_mamba_hybrid_model(
         mamba_head_dim=mamba_head_dim,
         mamba_num_groups=mamba_num_groups,
         num_moe_experts=num_moe_experts,
+        moe_grouped_gemm=moe_grouped_gemm,
         moe_ffn_hidden_size=moe_ffn_hidden_size,
         moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
         add_bias_linear=False,
@@ -391,7 +385,7 @@ def get_mcore_mamba_hybrid_model(
     print(f"Using `{hybrid_override_pattern=}` for building MambaModel")

     if transformer_impl == "transformer_engine":
-        mamba_spec = te_mamba_stack_spec
+        mamba_spec = get_te_mamba_stack_spec(moe_grouped_gemm=moe_grouped_gemm)
     else:
         mamba_spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True)
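For illustration, a sketch of driving the updated fixture with the new flag; the keyword-only call is illustrative (arguments not shown fall back to the defaults in the signature above) and is not part of this commit:

    from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model

    # transformer_impl="transformer_engine" routes spec selection through
    # get_te_mamba_stack_spec(moe_grouped_gemm=...); the local/ModelOpt spec
    # path is unchanged.
    model = get_mcore_mamba_hybrid_model(
        transformer_impl="transformer_engine",
        moe_grouped_gemm=False,
    )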

tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
Lines changed: 2 additions & 2 deletions

@@ -31,11 +31,11 @@
     _DynamicColumnParallelLinear,
     _DynamicEmbedding,
     _DynamicExtendedRMSNorm,
-    _DynamicLayerNorm,
     _DynamicMambaLayer,
     _DynamicMambaMixer,
     _DynamicMCoreLanguageModel,
     _DynamicTELayerNormColumnParallelLinear,
+    _DynamicTENorm,
     _DynamicTERowParallelLinear,
 )
 from modelopt.torch.nas.traced_hp import TracedHp
@@ -104,7 +104,7 @@ def _test_mamba_search_space(rank, size):
         if layer.mixer.rmsnorm:
             assert isinstance(layer.mixer.norm, _DynamicExtendedRMSNorm)
     if is_pipeline_last_stage():
-        assert isinstance(model.decoder.final_norm, _DynamicLayerNorm)
+        assert isinstance(model.decoder.final_norm, _DynamicTENorm)
         assert isinstance(model.output_layer, _DynamicColumnParallelLinear)

     # NOTE: `search_space_size` does not reduce across TP/PP groups
