From 551a46c2d7f9358f2d15394e48ebf9cc831609bf Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 17 Mar 2026 00:32:35 +0000 Subject: [PATCH 1/9] Add FakeBaseModel for offline speculative decoding and Kimi-K2.5 fixes Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/eagle_utils.py | 2 +- examples/speculative_decoding/launch_train.sh | 13 +- .../speculative/plugins/modeling_fakebase.py | 160 ++++++++++++++++++ .../torch/speculative/plugins/transformers.py | 7 +- modelopt/torch/speculative/utils.py | 10 ++ 5 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 modelopt/torch/speculative/plugins/modeling_fakebase.py diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8c96a19a7..7f28e68f7 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -154,7 +154,7 @@ def make_eagle_supervised_data_module( assert not data_args.vlm_processor, "Offline data is not supported for VLM." offline_data_path = Path(data_args.offline_data_path) - dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + dumped_files = [str(p) for p in offline_data_path.rglob("*.pt")] if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 074151c5a..92a4f614e 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -126,9 +126,16 @@ set -x SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" NUM_NODES=${NUM_NODES:-1} -GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} -TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) -echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +if [[ "$NUM_NODES" != 1 ]]; then + #Multi Node Training + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +else + #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES + TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (Single Node Training)" +fi # Calculate save_steps DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py new file mode 100644 index 000000000..45c5e9f18 --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lightweight fake base model for offline speculative decoding training.""" + +import json + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file as safetensors_load_file +from transformers import PretrainedConfig, PreTrainedModel + + +class FakeBaseConfig(PretrainedConfig): + """Minimal config for FakeBaseModel that supports offline speculative decoding training.""" + + model_type = "fake_base_model" + + def __init__( + self, + num_hidden_layers=None, + hidden_size=None, + vocab_size=None, + max_position_embeddings=None, + dtype=torch.bfloat16, + tie_word_embeddings=False, + **kwargs, + ): + """Initialize FakeBaseConfig with minimal model configuration parameters.""" + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype + + +class FakeBaseModel(PreTrainedModel): + """Minimal base model for offline speculative decoding. + + Contains only lm_head, embed_tokens, and necessary configs. + + This lightweight class should works ootb for convert, train, save/reload, and + export in offline speculative decoding workflow, while allowing: + 1. Faster initialization and loading by omitting full model layers. + 2. Compatibility with cases where standard HuggingFace loading is incomplete or unsupported. + + Subclasses should override/define the following attributes: + SOURCE_HF_REPO (str): HuggingFace repository ID for weight retrieval. + INDEX_FILENAME (str): Name of the JSON file listing sharded weight files. + LM_HEAD_KEY (str): Key for the language modeling head in the safetensors state dict. + EMBED_TOKENS_KEY (str): Key for the embedding tokens in the safetensors state dict. + """ + + config_class = FakeBaseConfig + + # Default values; subclasses should override as needed. + SOURCE_HF_REPO: str = None + INDEX_FILENAME: str = "model.safetensors.index.json" + LM_HEAD_KEY: str = "lm_head.weight" + EMBED_TOKENS_KEY: str = "model.embed_tokens.weight" + + def __init__(self, config: FakeBaseConfig): + """Initialize FakeBaseModel and download lm_head/embed_tokens weights from HuggingFace. + + Args: + config (FakeBaseConfig): Model configuration. + """ + super().__init__(config) + self.config = config + self.model = nn.Module() + self.model.layers = nn.ModuleList() + self.model.dtype = config.dtype + self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + try: + lm_head_w, embed_tokens_w = self._download_lm_head_and_embed_tokens() + assert lm_head_w.shape == (self.config.vocab_size, self.config.hidden_size) + assert embed_tokens_w.shape == (self.config.vocab_size, self.config.hidden_size) + self.lm_head.weight.data.copy_(lm_head_w) + self.embed_tokens.weight.data.copy_(embed_tokens_w) + except Exception as e: + raise ValueError(f"Failed to initialize lm_head and embed_tokens: {e}") + + @classmethod + def from_base_config(cls, base_config: PretrainedConfig): + """Create a FakeBaseModel instance using a configuration from a full, real model. + + Args: + base_config (PretrainedConfig): The original model configuration. + + Returns: + FakeBaseModel: A new instance with the minimal configuration. + """ + config_params = { + "num_hidden_layers": getattr(base_config, "num_hidden_layers", None), + "hidden_size": getattr(base_config, "hidden_size", None), + "vocab_size": getattr(base_config, "vocab_size", None), + "max_position_embeddings": getattr(base_config, "max_position_embeddings", None), + "dtype": getattr(base_config, "dtype", torch.bfloat16), + "tie_word_embeddings": getattr(base_config, "tie_word_embeddings", False), + } + return cls(FakeBaseConfig(**config_params)) + + def _download_lm_head_and_embed_tokens(self): + if self.SOURCE_HF_REPO is None: + raise ValueError("Set SOURCE_HF_REPO as a class attribute or in a subclass.") + + index_json_file = hf_hub_download( + repo_id=self.SOURCE_HF_REPO, + filename=self.INDEX_FILENAME, + ) + with open(index_json_file) as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + lm_head_file = weight_map.get(self.LM_HEAD_KEY) + embed_tokens_file = weight_map.get(self.EMBED_TOKENS_KEY) + + if not lm_head_file or not embed_tokens_file: + raise RuntimeError(f"{self.LM_HEAD_KEY} or {self.EMBED_TOKENS_KEY} not found in index!") + + lm_head_shard_file = hf_hub_download(repo_id=self.SOURCE_HF_REPO, filename=lm_head_file) + embed_tokens_shard_file = hf_hub_download( + repo_id=self.SOURCE_HF_REPO, filename=embed_tokens_file + ) + + lm_head_state = safetensors_load_file(lm_head_shard_file, device="cpu") + embed_tokens_state = safetensors_load_file(embed_tokens_shard_file, device="cpu") + + lm_head_weight = lm_head_state[self.LM_HEAD_KEY] + embed_tokens_weight = embed_tokens_state[self.EMBED_TOKENS_KEY] + + return lm_head_weight, embed_tokens_weight + + def forward(self, *args, **kwargs): + """Not implemented: FakeBaseModel omits full model weights and cannot run inference.""" + raise NotImplementedError("FakeBaseModel forward is not implemented.") + + +class KimiK25FakeBaseModel(FakeBaseModel): + """FakeBaseModel subclass tailored for Kimi-K2.5.""" + + SOURCE_HF_REPO = "moonshotai/Kimi-K2.5" + LM_HEAD_KEY = "language_model.lm_head.weight" + EMBED_TOKENS_KEY = "language_model.model.embed_tokens.weight" diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1b85c342e..442f26277 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -465,12 +465,15 @@ def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { "base_model_path": [ + "language_model.model", "model.language_model", "model", "backbone", "language_model.backbone", ], "base_model_embeddings_path": [ + "embed_tokens", + "language_model.model.embed_tokens", "model.embed_tokens", "backbone.embeddings", "language_model.backbone.embeddings", @@ -580,7 +583,7 @@ def modify( # Patch for Kimi-K2-Thinking, avoid quantizing drafter quant_config = getattr(self.config, "quantization_config", None) if isinstance(quant_config, CompressedTensorsConfig): - quant_config.ignore.append("re:.*eagle_module.*") + quant_config.quantization_config.ignore.append("re:.*eagle_module.*") # Set default aux_hidden_state layers if ( @@ -869,7 +872,7 @@ def forward( assert "base_model_outputs" in kwargs base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) if base_outputs.logits is None: - base_outputs.logits = self.lm_head(base_outputs.out_hiddens) + base_outputs.logits = self._base_model_lm_head(base_outputs.out_hiddens) past_key_values = None else: base_outputs, past_key_values = self._base_model_forward( diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 72c5b5dbc..eaf953b4e 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -443,6 +443,16 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention._init_rope = lambda self: None kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init + # Kimi implementation is based on older transformers which use "past_key_value" argument + # We patch it to "past_key_values" for compatibility + original_decoder_layer_forward = kimi_k2_module.DeepseekV3DecoderLayer.forward + + def patched_decoder_layer_fwd(self, *args, **kwargs): + kwargs["past_key_value"] = kwargs.get("past_key_values") + return original_decoder_layer_forward(self, *args, **kwargs) + + kimi_k2_module.DeepseekV3DecoderLayer.forward = patched_decoder_layer_fwd + return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") From dcbfcf4137d9f32cae82856a6add2c3ac9e04f37 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 17 Mar 2026 02:55:10 +0000 Subject: [PATCH 2/9] refactor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 23 +-- .../scripts/ar_validate.py | 4 +- .../scripts/export_hf_checkpoint.py | 4 +- .../speculative/plugins/modeling_fakebase.py | 187 ++++++++++-------- modelopt/torch/speculative/utils.py | 56 +++++- 5 files changed, 162 insertions(+), 112 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 25817ee94..793126178 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -47,10 +47,8 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.utils import ( - load_vlm_or_llm_with_kwargs, - patch_transformers5_params_loading, -) +from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseArguments +from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -140,9 +138,10 @@ def train(): TrainingArguments, MedusaArguments, EagleArguments, + FakeBaseArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( + model_args, data_args, training_args, medusa_args, eagle_args, fake_base_args = ( parser.parse_args_into_dataclasses() ) training_args.parallelism_config = ParallelismConfig( @@ -169,25 +168,19 @@ def train(): if checkpoint: with patch_transformers5_params_loading(): - _, model = load_vlm_or_llm_with_kwargs( - checkpoint, torch_dtype="auto", trust_remote_code=True - ) + model = load_vlm_or_llm(checkpoint, torch_dtype="auto", trust_remote_code=True) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). - offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model_config, model = load_vlm_or_llm_with_kwargs( + model = load_vlm_or_llm( model_args.model_name_or_path, + use_offline_training=use_offline_training, + fake_base_args=fake_base_args, torch_dtype="auto", device_map="cpu", trust_remote_code=True, - **offline_kwargs, ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings - model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index d5c37a895..d1bf31a1a 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -22,7 +22,7 @@ import modelopt.torch.opt as mto from modelopt.torch.speculative.plugins.transformers import HFARValidation -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm mto.enable_huggingface_checkpointing() @@ -72,7 +72,7 @@ def main(): accelerator = Accelerator() # Load model and tokenizer - _, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto") + model = load_vlm_or_llm(args.model_path, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 23a7560f7..925f4b73d 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -21,7 +21,7 @@ import modelopt.torch.opt as mto from modelopt.torch.export import export_speculative_decoding -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm def parse_args(): @@ -38,7 +38,7 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") +model = load_vlm_or_llm(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): export_speculative_decoding( diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index 45c5e9f18..383165598 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -16,14 +16,57 @@ """Lightweight fake base model for offline speculative decoding training.""" import json +import os +from dataclasses import dataclass, field import torch import torch.nn as nn -from huggingface_hub import hf_hub_download +import transformers from safetensors.torch import load_file as safetensors_load_file from transformers import PretrainedConfig, PreTrainedModel +@dataclass +class FakeBaseArguments: + """Arguments for FakeBaseModel used during offline speculative decoding training. + + Pass ``--use_fake_base_model`` to enable. Override the default weight key names for models + that use a non-standard layout (e.g. VLMs with a ``language_model`` prefix). + """ + + use_fake_base_model: bool = field( + default=False, + metadata={ + "help": ( + "Use FakeBaseModel for offline training instead of loading full model weights. " + "Only effective when --offline_data_path is set." + ) + }, + ) + lm_head_key: str = field( + default="lm_head.weight", + metadata={"help": "Safetensors key for the lm_head weight in the checkpoint."}, + ) + embed_tokens_key: str = field( + default="model.embed_tokens.weight", + metadata={"help": "Safetensors key for the embed_tokens weight in the checkpoint."}, + ) + index_filename: str = field( + default="model.safetensors.index.json", + metadata={"help": "Name of the sharded safetensors index JSON file."}, + ) + base_config_attr: str | None = field( + default=None, + metadata={ + "help": ( + "Attribute name on model_config to use as the base config " + "(e.g. 'text_config', 'language_config'). " + "If None, model_config itself is used." + ) + }, + ) + + class FakeBaseConfig(PretrainedConfig): """Minimal config for FakeBaseModel that supports offline speculative decoding training.""" @@ -51,110 +94,80 @@ def __init__( class FakeBaseModel(PreTrainedModel): """Minimal base model for offline speculative decoding. - Contains only lm_head, embed_tokens, and necessary configs. - - This lightweight class should works ootb for convert, train, save/reload, and - export in offline speculative decoding workflow, while allowing: - 1. Faster initialization and loading by omitting full model layers. - 2. Compatibility with cases where standard HuggingFace loading is incomplete or unsupported. + Contains only ``lm_head``, ``embed_tokens``, and the minimal config needed by the EAGLE + training loop. The full model weights are never loaded, keeping memory usage low. - Subclasses should override/define the following attributes: - SOURCE_HF_REPO (str): HuggingFace repository ID for weight retrieval. - INDEX_FILENAME (str): Name of the JSON file listing sharded weight files. - LM_HEAD_KEY (str): Key for the language modeling head in the safetensors state dict. - EMBED_TOKENS_KEY (str): Key for the embedding tokens in the safetensors state dict. + Weights are loaded from a local HuggingFace checkpoint directory. The weight key names + default to standard LLaMA-style paths; override ``lm_head_key`` and ``embed_tokens_key`` + for models with a different layout (e.g. VLMs with a ``language_model`` prefix). """ config_class = FakeBaseConfig - # Default values; subclasses should override as needed. - SOURCE_HF_REPO: str = None - INDEX_FILENAME: str = "model.safetensors.index.json" - LM_HEAD_KEY: str = "lm_head.weight" - EMBED_TOKENS_KEY: str = "model.embed_tokens.weight" - - def __init__(self, config: FakeBaseConfig): - """Initialize FakeBaseModel and download lm_head/embed_tokens weights from HuggingFace. + def __init__(self, source: str, args: "FakeBaseArguments"): + """Load lm_head and embed_tokens from a local HuggingFace checkpoint directory. Args: - config (FakeBaseConfig): Model configuration. + source: Path to a local HuggingFace checkpoint directory. + args: :class:`FakeBaseArguments` controlling key names and config lookup. """ - super().__init__(config) - self.config = config + model_config = transformers.AutoConfig.from_pretrained(source) + base_cfg = ( + getattr(model_config, args.base_config_attr) if args.base_config_attr else model_config + ) + hf_config = FakeBaseConfig( + num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None), + hidden_size=getattr(base_cfg, "hidden_size", None), + vocab_size=getattr(base_cfg, "vocab_size", None), + max_position_embeddings=getattr(base_cfg, "max_position_embeddings", None), + dtype=getattr(base_cfg, "dtype", torch.bfloat16), + tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False), + ) + super().__init__(hf_config) self.model = nn.Module() self.model.layers = nn.ModuleList() - self.model.dtype = config.dtype - self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) - self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.model.dtype = hf_config.dtype + self.embed_tokens = nn.Embedding(hf_config.vocab_size, hf_config.hidden_size) + self.lm_head = nn.Linear(hf_config.hidden_size, hf_config.vocab_size, bias=False) try: - lm_head_w, embed_tokens_w = self._download_lm_head_and_embed_tokens() - assert lm_head_w.shape == (self.config.vocab_size, self.config.hidden_size) - assert embed_tokens_w.shape == (self.config.vocab_size, self.config.hidden_size) + lm_head_w, embed_tokens_w = self._load_weights( + source, args.lm_head_key, args.embed_tokens_key, args.index_filename + ) + assert lm_head_w.shape == (hf_config.vocab_size, hf_config.hidden_size) + assert embed_tokens_w.shape == (hf_config.vocab_size, hf_config.hidden_size) self.lm_head.weight.data.copy_(lm_head_w) self.embed_tokens.weight.data.copy_(embed_tokens_w) except Exception as e: raise ValueError(f"Failed to initialize lm_head and embed_tokens: {e}") - @classmethod - def from_base_config(cls, base_config: PretrainedConfig): - """Create a FakeBaseModel instance using a configuration from a full, real model. - - Args: - base_config (PretrainedConfig): The original model configuration. - - Returns: - FakeBaseModel: A new instance with the minimal configuration. - """ - config_params = { - "num_hidden_layers": getattr(base_config, "num_hidden_layers", None), - "hidden_size": getattr(base_config, "hidden_size", None), - "vocab_size": getattr(base_config, "vocab_size", None), - "max_position_embeddings": getattr(base_config, "max_position_embeddings", None), - "dtype": getattr(base_config, "dtype", torch.bfloat16), - "tie_word_embeddings": getattr(base_config, "tie_word_embeddings", False), - } - return cls(FakeBaseConfig(**config_params)) - - def _download_lm_head_and_embed_tokens(self): - if self.SOURCE_HF_REPO is None: - raise ValueError("Set SOURCE_HF_REPO as a class attribute or in a subclass.") - - index_json_file = hf_hub_download( - repo_id=self.SOURCE_HF_REPO, - filename=self.INDEX_FILENAME, - ) - with open(index_json_file) as f: - index_data = json.load(f) - - weight_map = index_data.get("weight_map", {}) - lm_head_file = weight_map.get(self.LM_HEAD_KEY) - embed_tokens_file = weight_map.get(self.EMBED_TOKENS_KEY) - - if not lm_head_file or not embed_tokens_file: - raise RuntimeError(f"{self.LM_HEAD_KEY} or {self.EMBED_TOKENS_KEY} not found in index!") - - lm_head_shard_file = hf_hub_download(repo_id=self.SOURCE_HF_REPO, filename=lm_head_file) - embed_tokens_shard_file = hf_hub_download( - repo_id=self.SOURCE_HF_REPO, filename=embed_tokens_file - ) - - lm_head_state = safetensors_load_file(lm_head_shard_file, device="cpu") - embed_tokens_state = safetensors_load_file(embed_tokens_shard_file, device="cpu") - - lm_head_weight = lm_head_state[self.LM_HEAD_KEY] - embed_tokens_weight = embed_tokens_state[self.EMBED_TOKENS_KEY] - - return lm_head_weight, embed_tokens_weight + def _load_weights( + self, + source: str, + lm_head_key: str, + embed_tokens_key: str, + index_filename: str, + ): + """Load lm_head and embed_tokens weights from a local checkpoint directory.""" + index_path = os.path.join(source, index_filename) + + if os.path.isfile(index_path): + with open(index_path) as f: + index_data = json.load(f) + weight_map = index_data.get("weight_map", {}) + lm_head_file = weight_map.get(lm_head_key) + embed_tokens_file = weight_map.get(embed_tokens_key) + if not lm_head_file or not embed_tokens_file: + raise RuntimeError(f"{lm_head_key} or {embed_tokens_key} not found in index!") + lm_head_state = safetensors_load_file(os.path.join(source, lm_head_file), device="cpu") + embed_tokens_state = safetensors_load_file( + os.path.join(source, embed_tokens_file), device="cpu" + ) + else: + raise FileNotFoundError(f"No {index_filename} found in {source!r}.") + + return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] def forward(self, *args, **kwargs): """Not implemented: FakeBaseModel omits full model weights and cannot run inference.""" raise NotImplementedError("FakeBaseModel forward is not implemented.") - - -class KimiK25FakeBaseModel(FakeBaseModel): - """FakeBaseModel subclass tailored for Kimi-K2.5.""" - - SOURCE_HF_REPO = "moonshotai/Kimi-K2.5" - LM_HEAD_KEY = "language_model.lm_head.weight" - EMBED_TOKENS_KEY = "language_model.model.embed_tokens.weight" diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index eaf953b4e..bb8ad5283 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -484,21 +484,65 @@ def enable_cp_ttt_patch(): modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False -def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs): - """Load a VLM or LLM with kwargs. Returns the model and model config.""" +def load_vlm_or_llm( + model_name_or_path: str, + use_offline_training: bool = False, + fake_base_args=None, + torch_dtype: str | torch.dtype | None = None, + device_map: str | None = None, + trust_remote_code: bool = False, +): + """Load a VLM or LLM. Returns the model. + + When ``use_offline_training=True`` and ``fake_base_args.use_fake_base_model=True``, returns a + :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` containing only + ``lm_head`` and ``embed_tokens``. Otherwise, falls back to loading with + ``num_hidden_layers=0`` for memory efficiency. + + Args: + model_name_or_path: Local path or HuggingFace repo ID of the model. + use_offline_training: Whether to load a memory-efficient model for offline training. + fake_base_args: Optional + :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseArguments`. + If provided and ``use_offline_training=True``, a + :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` is + returned instead of the full model. + torch_dtype: dtype to use when loading the model. + device_map: Device map passed to ``from_pretrained``. + trust_remote_code: Whether to trust remote code. + """ + if use_offline_training and fake_base_args is not None and fake_base_args.use_fake_base_model: + from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel + + return FakeBaseModel(model_name_or_path, fake_base_args) + model_config = transformers.AutoConfig.from_pretrained( - model_name_or_path, trust_remote_code=True + model_name_or_path, trust_remote_code=trust_remote_code ) if "vl" in model_config.model_type.lower(): model_cls = transformers.AutoModelForVision2Seq else: model_cls = transformers.AutoModelForCausalLM - if kwargs.get("num_hidden_layers") == 0: + extra = {} + if use_offline_training: + extra["num_hidden_layers"] = 0 if hasattr(model_config, "layer_types"): - kwargs["layer_types"] = [] + extra["layer_types"] = [] + + model = model_cls.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map=device_map, + **extra, + ) + + if use_offline_training: + # Preserve the original layer count since we loaded with num_hidden_layers=0 + model.config.num_orig_hidden_layers = model_config.num_hidden_layers - return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs) + return model @contextlib.contextmanager From d28279f5ccc140140f0b017042747e1ca9ca07cb Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:26:41 +0000 Subject: [PATCH 3/9] refactor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 9 +- .../speculative/plugins/modeling_fakebase.py | 150 ++++++++---------- .../torch/speculative/plugins/transformers.py | 20 +-- modelopt/torch/speculative/utils.py | 17 +- 4 files changed, 83 insertions(+), 113 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 793126178..a4965fbd4 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -47,7 +47,6 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseArguments from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 @@ -58,6 +57,9 @@ @dataclass class ModelArguments: model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + use_fake_base_for_offline: bool = field( + default=True, metadata={"help": "Whether to use fake base for offline training."} + ) @dataclass @@ -138,10 +140,9 @@ def train(): TrainingArguments, MedusaArguments, EagleArguments, - FakeBaseArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args, fake_base_args = ( + model_args, data_args, training_args, medusa_args, eagle_args = ( parser.parse_args_into_dataclasses() ) training_args.parallelism_config = ParallelismConfig( @@ -175,8 +176,8 @@ def train(): # Model will be moved to GPU during HF trainer.init(). model = load_vlm_or_llm( model_args.model_name_or_path, + use_fake_base=model_args.use_fake_base_for_offline, use_offline_training=use_offline_training, - fake_base_args=fake_base_args, torch_dtype="auto", device_map="cpu", trust_remote_code=True, diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index 383165598..cd94c8eaa 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -17,7 +17,6 @@ import json import os -from dataclasses import dataclass, field import torch import torch.nn as nn @@ -25,46 +24,25 @@ from safetensors.torch import load_file as safetensors_load_file from transformers import PretrainedConfig, PreTrainedModel - -@dataclass -class FakeBaseArguments: - """Arguments for FakeBaseModel used during offline speculative decoding training. - - Pass ``--use_fake_base_model`` to enable. Override the default weight key names for models - that use a non-standard layout (e.g. VLMs with a ``language_model`` prefix). - """ - - use_fake_base_model: bool = field( - default=False, - metadata={ - "help": ( - "Use FakeBaseModel for offline training instead of loading full model weights. " - "Only effective when --offline_data_path is set." - ) - }, - ) - lm_head_key: str = field( - default="lm_head.weight", - metadata={"help": "Safetensors key for the lm_head weight in the checkpoint."}, - ) - embed_tokens_key: str = field( - default="model.embed_tokens.weight", - metadata={"help": "Safetensors key for the embed_tokens weight in the checkpoint."}, - ) - index_filename: str = field( - default="model.safetensors.index.json", - metadata={"help": "Name of the sharded safetensors index JSON file."}, - ) - base_config_attr: str | None = field( - default=None, - metadata={ - "help": ( - "Attribute name on model_config to use as the base config " - "(e.g. 'text_config', 'language_config'). " - "If None, model_config itself is used." - ) - }, - ) +# Candidate module paths searched in order — shared with HFEagleModel._find_base_model_parts +_EMBED_TOKENS_PATHS = [ + "embed_tokens", + "language_model.model.embed_tokens", + "model.embed_tokens", + "backbone.embeddings", + "language_model.backbone.embeddings", + "model.language_model.embed_tokens", +] +_LM_HEAD_PATHS = ["lm_head", "language_model.lm_head"] +_BASE_MODEL_PATHS = [ + "language_model.model", + "model.language_model", + "model", + "backbone", + "language_model.backbone", +] +_VLM_CONFIG_ATTRS = ["text_config", "llm_config"] +_SAFETENSORS_INDEX_FILENAME = "model.safetensors.index.json" class FakeBaseConfig(PretrainedConfig): @@ -97,25 +75,30 @@ class FakeBaseModel(PreTrainedModel): Contains only ``lm_head``, ``embed_tokens``, and the minimal config needed by the EAGLE training loop. The full model weights are never loaded, keeping memory usage low. - Weights are loaded from a local HuggingFace checkpoint directory. The weight key names - default to standard LLaMA-style paths; override ``lm_head_key`` and ``embed_tokens_key`` - for models with a different layout (e.g. VLMs with a ``language_model`` prefix). + Weights are loaded from a local HuggingFace checkpoint directory. Weight key names and + VLM config nesting are auto-detected from the shared path constants. """ config_class = FakeBaseConfig - def __init__(self, source: str, args: "FakeBaseArguments"): + def __init__(self, source: str): """Load lm_head and embed_tokens from a local HuggingFace checkpoint directory. Args: source: Path to a local HuggingFace checkpoint directory. - args: :class:`FakeBaseArguments` controlling key names and config lookup. """ - model_config = transformers.AutoConfig.from_pretrained(source) - base_cfg = ( - getattr(model_config, args.base_config_attr) if args.base_config_attr else model_config + orig_config = transformers.AutoConfig.from_pretrained(source) + # For vlms, detect language model config based on _VLM_CONFIG_ATTRS + base_cfg = next( + ( + getattr(orig_config, attr) + for attr in _VLM_CONFIG_ATTRS + if getattr(orig_config, attr, None) is not None + ), + orig_config, ) - hf_config = FakeBaseConfig( + # Extract necessary info for spec training from base config + config = FakeBaseConfig( num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None), hidden_size=getattr(base_cfg, "hidden_size", None), vocab_size=getattr(base_cfg, "vocab_size", None), @@ -123,49 +106,52 @@ def __init__(self, source: str, args: "FakeBaseArguments"): dtype=getattr(base_cfg, "dtype", torch.bfloat16), tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False), ) - super().__init__(hf_config) + super().__init__(config) + # Initialize dummy module and attributes for compatibility with HFEagleModel self.model = nn.Module() self.model.layers = nn.ModuleList() - self.model.dtype = hf_config.dtype - self.embed_tokens = nn.Embedding(hf_config.vocab_size, hf_config.hidden_size) - self.lm_head = nn.Linear(hf_config.hidden_size, hf_config.vocab_size, bias=False) + self.model.dtype = config.dtype + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) try: - lm_head_w, embed_tokens_w = self._load_weights( - source, args.lm_head_key, args.embed_tokens_key, args.index_filename - ) - assert lm_head_w.shape == (hf_config.vocab_size, hf_config.hidden_size) - assert embed_tokens_w.shape == (hf_config.vocab_size, hf_config.hidden_size) + lm_head_w, embed_tokens_w = self._load_weights(source) + assert lm_head_w.shape == (config.vocab_size, config.hidden_size) + assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size) self.lm_head.weight.data.copy_(lm_head_w) self.embed_tokens.weight.data.copy_(embed_tokens_w) except Exception as e: raise ValueError(f"Failed to initialize lm_head and embed_tokens: {e}") - def _load_weights( - self, - source: str, - lm_head_key: str, - embed_tokens_key: str, - index_filename: str, - ): + @staticmethod + def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: + """Return the first ``path + '.weight'`` found in ``weight_map``.""" + for path in paths: + key = path + ".weight" + if key in weight_map: + return key + tried = [p + ".weight" for p in paths] + raise RuntimeError(f"Cannot find {label} in checkpoint; tried: {tried}") + + def _load_weights(self, source: str): """Load lm_head and embed_tokens weights from a local checkpoint directory.""" - index_path = os.path.join(source, index_filename) - - if os.path.isfile(index_path): - with open(index_path) as f: - index_data = json.load(f) - weight_map = index_data.get("weight_map", {}) - lm_head_file = weight_map.get(lm_head_key) - embed_tokens_file = weight_map.get(embed_tokens_key) - if not lm_head_file or not embed_tokens_file: - raise RuntimeError(f"{lm_head_key} or {embed_tokens_key} not found in index!") - lm_head_state = safetensors_load_file(os.path.join(source, lm_head_file), device="cpu") - embed_tokens_state = safetensors_load_file( - os.path.join(source, embed_tokens_file), device="cpu" - ) - else: - raise FileNotFoundError(f"No {index_filename} found in {source!r}.") + index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) + if not os.path.isfile(index_path): + raise FileNotFoundError(f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}.") + + with open(index_path) as f: + weight_map = json.load(f).get("weight_map", {}) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens") + + lm_head_state = safetensors_load_file( + os.path.join(source, weight_map[lm_head_key]), device="cpu" + ) + embed_tokens_state = safetensors_load_file( + os.path.join(source, weight_map[embed_tokens_key]), device="cpu" + ) return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] def forward(self, *args, **kwargs): diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 442f26277..060eb89b9 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -68,6 +68,7 @@ get_ttt_msk_func, temporary_set_config_value, ) +from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS __all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] @@ -464,22 +465,9 @@ def get_exporter(self) -> SpeculativeDecodingExporter: def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": [ - "language_model.model", - "model.language_model", - "model", - "backbone", - "language_model.backbone", - ], - "base_model_embeddings_path": [ - "embed_tokens", - "language_model.model.embed_tokens", - "model.embed_tokens", - "backbone.embeddings", - "language_model.backbone.embeddings", - "model.language_model.embed_tokens", - ], - "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + "base_model_path": _BASE_MODEL_PATHS, + "base_model_embeddings_path": _EMBED_TOKENS_PATHS, + "base_model_lm_head_path": _LM_HEAD_PATHS, } for name, paths in base_model_parts_mapping.items(): diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index bb8ad5283..ec250c634 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -486,35 +486,30 @@ def enable_cp_ttt_patch(): def load_vlm_or_llm( model_name_or_path: str, + use_fake_base: bool = True, use_offline_training: bool = False, - fake_base_args=None, torch_dtype: str | torch.dtype | None = None, device_map: str | None = None, trust_remote_code: bool = False, ): """Load a VLM or LLM. Returns the model. - When ``use_offline_training=True`` and ``fake_base_args.use_fake_base_model=True``, returns a + When ``use_offline_training=True``, returns a :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` containing only - ``lm_head`` and ``embed_tokens``. Otherwise, falls back to loading with - ``num_hidden_layers=0`` for memory efficiency. + ``lm_head`` and ``embed_tokens``, auto-detecting weight paths from the checkpoint. + Otherwise, falls back to loading with ``num_hidden_layers=0`` for memory efficiency. Args: model_name_or_path: Local path or HuggingFace repo ID of the model. use_offline_training: Whether to load a memory-efficient model for offline training. - fake_base_args: Optional - :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseArguments`. - If provided and ``use_offline_training=True``, a - :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` is - returned instead of the full model. torch_dtype: dtype to use when loading the model. device_map: Device map passed to ``from_pretrained``. trust_remote_code: Whether to trust remote code. """ - if use_offline_training and fake_base_args is not None and fake_base_args.use_fake_base_model: + if use_offline_training and use_fake_base: from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel - return FakeBaseModel(model_name_or_path, fake_base_args) + return FakeBaseModel(model_name_or_path) model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code From 45a5aa9a00c03029d78d72dd5621adb4d6d0bc33 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:29:55 +0000 Subject: [PATCH 4/9] polish Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../torch/speculative/plugins/modeling_fakebase.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index cd94c8eaa..ccefffa7b 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -114,14 +114,12 @@ def __init__(self, source: str): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - try: - lm_head_w, embed_tokens_w = self._load_weights(source) - assert lm_head_w.shape == (config.vocab_size, config.hidden_size) - assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size) - self.lm_head.weight.data.copy_(lm_head_w) - self.embed_tokens.weight.data.copy_(embed_tokens_w) - except Exception as e: - raise ValueError(f"Failed to initialize lm_head and embed_tokens: {e}") + # Load lm_head and embed_tokens only from checkpoint + lm_head_w, embed_tokens_w = self._load_weights(source) + assert lm_head_w.shape == (config.vocab_size, config.hidden_size) + assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size) + self.lm_head.weight.data.copy_(lm_head_w) + self.embed_tokens.weight.data.copy_(embed_tokens_w) @staticmethod def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: From 9221267fe081cb211d4a90e4beb75345cce7b03c Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 17 Mar 2026 08:32:55 +0000 Subject: [PATCH 5/9] remove kimi patches Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/plugins/transformers.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 060eb89b9..ce2fc28db 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -48,7 +48,6 @@ ) from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from transformers.utils.quantization_config import CompressedTensorsConfig from ...export.plugins.hf_spec_export import ( EagleExporter, @@ -568,11 +567,6 @@ def modify( if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" - # Patch for Kimi-K2-Thinking, avoid quantizing drafter - quant_config = getattr(self.config, "quantization_config", None) - if isinstance(quant_config, CompressedTensorsConfig): - quant_config.quantization_config.ignore.append("re:.*eagle_module.*") - # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state @@ -582,7 +576,7 @@ def modify( # Freeze all parameters if self.eagle_freeze_base_model: - for name, param in self.named_parameters(): + for _, param in self.named_parameters(): param.requires_grad = False self.eagle_module = EagleModule( @@ -733,8 +727,6 @@ def _compute_ttt_attention_mask( tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) - # Note: (hg) repeat mask for kimi-k2 compatibility - tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask def _base_model_forward( From 0df75e2d395fd67a6fd6a312b4a39215e0b0eff0 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 18 Mar 2026 06:20:19 +0000 Subject: [PATCH 6/9] add tests Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/launch_train.sh | 13 ++- examples/speculative_decoding/main.py | 17 ++- .../speculative/plugins/modeling_fakebase.py | 70 ++++++++---- modelopt/torch/speculative/utils.py | 2 +- .../speculative_decoding/test_eagle.py | 104 ++++++++++++++++++ 5 files changed, 179 insertions(+), 27 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 92a4f614e..6045b40c8 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -114,6 +114,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi MIX_HIDDEN_STATES="${1#*=}" ;; + --use_fake_base_for_offline*) + if [[ "$1" != *=* ]]; then shift; fi + USE_FAKE_BASE_FOR_OFFLINE="${1#*=}" + ;; + --trust_remote_code*) + if [[ "$1" != *=* ]]; then shift; fi + TRUST_REMOTE_CODE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -161,7 +169,8 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} - +USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} +TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -247,6 +256,8 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ --mix_hidden_states $MIX_HIDDEN_STATES \ + --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ + --trust_remote_code $TRUST_REMOTE_CODE \ $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index a4965fbd4..1093e577d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -58,7 +58,10 @@ class ModelArguments: model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") use_fake_base_for_offline: bool = field( - default=True, metadata={"help": "Whether to use fake base for offline training."} + default=False, metadata={"help": "Whether to use fake base for offline training."} + ) + trust_remote_code: bool = field( + default=False, metadata={"help": "Whether to trust remote code."} ) @@ -169,8 +172,12 @@ def train(): if checkpoint: with patch_transformers5_params_loading(): - model = load_vlm_or_llm(checkpoint, torch_dtype="auto", trust_remote_code=True) - tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) + model = load_vlm_or_llm( + checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code + ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + checkpoint, trust_remote_code=model_args.trust_remote_code + ) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). @@ -180,12 +187,12 @@ def train(): use_offline_training=use_offline_training, torch_dtype="auto", device_map="cpu", - trust_remote_code=True, + trust_remote_code=model_args.trust_remote_code, ) tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, - trust_remote_code=True, + trust_remote_code=model_args.trust_remote_code, ) if training_args.mode == "medusa": config = { diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index ccefffa7b..5c31c80bf 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn import transformers +from huggingface_hub import hf_hub_download from safetensors.torch import load_file as safetensors_load_file from transformers import PretrainedConfig, PreTrainedModel @@ -81,13 +82,19 @@ class FakeBaseModel(PreTrainedModel): config_class = FakeBaseConfig - def __init__(self, source: str): - """Load lm_head and embed_tokens from a local HuggingFace checkpoint directory. + def __init__(self, source: str, trust_remote_code: bool = False): + """Load lm_head and embed_tokens from a local directory or HuggingFace Hub repo. Args: - source: Path to a local HuggingFace checkpoint directory. + source: Path to a local HuggingFace checkpoint directory, or a HuggingFace Hub + repo ID (e.g. ``"meta-llama/Llama-3.1-8B"``). The source type is detected + automatically: if ``source`` is an existing local directory it is treated as a + local checkpoint; otherwise it is treated as a Hub repo ID and the required + files are downloaded via ``huggingface_hub``. """ - orig_config = transformers.AutoConfig.from_pretrained(source) + orig_config = transformers.AutoConfig.from_pretrained( + source, trust_remote_code=trust_remote_code + ) # For vlms, detect language model config based on _VLM_CONFIG_ATTRS base_cfg = next( ( @@ -132,24 +139,47 @@ def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: raise RuntimeError(f"Cannot find {label} in checkpoint; tried: {tried}") def _load_weights(self, source: str): - """Load lm_head and embed_tokens weights from a local checkpoint directory.""" - index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) - - if not os.path.isfile(index_path): - raise FileNotFoundError(f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}.") - - with open(index_path) as f: - weight_map = json.load(f).get("weight_map", {}) + """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo. - lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") - embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens") + For remote repos the index file and the two required weight shards are downloaded via + ``huggingface_hub`` and cached locally; subsequent calls reuse the cache. + """ + if os.path.isdir(source): + index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) + if not os.path.isfile(index_path): + raise FileNotFoundError(f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}.") + with open(index_path) as f: + weight_map = json.load(f).get("weight_map", {}) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key( + weight_map, _EMBED_TOKENS_PATHS, "embed_tokens" + ) + + lm_head_state = safetensors_load_file( + os.path.join(source, weight_map[lm_head_key]), device="cpu" + ) + embed_tokens_state = safetensors_load_file( + os.path.join(source, weight_map[embed_tokens_key]), device="cpu" + ) + else: + # Treat source as a HuggingFace Hub repo ID + index_path = hf_hub_download(repo_id=source, filename=_SAFETENSORS_INDEX_FILENAME) + with open(index_path) as f: + weight_map = json.load(f).get("weight_map", {}) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key( + weight_map, _EMBED_TOKENS_PATHS, "embed_tokens" + ) + + lm_head_shard = hf_hub_download(repo_id=source, filename=weight_map[lm_head_key]) + embed_tokens_shard = hf_hub_download( + repo_id=source, filename=weight_map[embed_tokens_key] + ) + lm_head_state = safetensors_load_file(lm_head_shard, device="cpu") + embed_tokens_state = safetensors_load_file(embed_tokens_shard, device="cpu") - lm_head_state = safetensors_load_file( - os.path.join(source, weight_map[lm_head_key]), device="cpu" - ) - embed_tokens_state = safetensors_load_file( - os.path.join(source, weight_map[embed_tokens_key]), device="cpu" - ) return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] def forward(self, *args, **kwargs): diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index ec250c634..655c25a68 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -509,7 +509,7 @@ def load_vlm_or_llm( if use_offline_training and use_fake_base: from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel - return FakeBaseModel(model_name_or_path) + return FakeBaseModel(model_name_or_path, trust_remote_code=trust_remote_code) model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index a3542fa25..a79719dc2 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -15,6 +15,7 @@ import json import os +from pathlib import Path import pytest import safetensors.torch @@ -25,6 +26,44 @@ from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER +def generate_offline_pt_data( + output_dir, + num_files: int = 8, + seq_len: int = 128, + hidden_size: int = 512, + vocab_size: int = 32000, + num_aux_layers: int = 2, +) -> Path: + """Generate fake offline training .pt files for EAGLE3 offline training tests. + + Each file contains the keys expected by OfflineSupervisedDataset: + - input_ids: LongTensor of shape (seq_len,) + - hidden_states: FloatTensor of shape (seq_len, hidden_size) + - aux_hidden_states: FloatTensor of shape (seq_len, hidden_size*num_aux_layers) + + Args: + output_dir: Directory to write .pt files into. + num_files: Number of .pt files to generate. + seq_len: Sequence length. Defaults to 128. + hidden_size: Hidden size matching the base model. Defaults to 512 (tiny_llama). + vocab_size: Vocabulary size matching the base model. Defaults to 32000 (tiny_llama). + num_aux_layers: Number of auxiliary layers. Defaults to 2. + Returns: + Path to the output directory. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + torch.manual_seed(42) + for i in range(num_files): + sample = { + "input_ids": torch.randint(0, vocab_size, (seq_len,)), + "hidden_states": torch.randn(seq_len, hidden_size), + "aux_hidden_states": torch.randn(seq_len, hidden_size * num_aux_layers), + } + torch.save(sample, output_dir / f"sample_{i:04d}.pt") + return output_dir + + @pytest.fixture(scope="module") def eagle_output_dir(tmp_path_factory): """Eagle output directory shared in this module.""" @@ -164,3 +203,68 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir): ], "speculative_decoding", ) + + +@pytest.mark.parametrize( + ("model_source", "use_fake_base"), + [ + (None, False), # tiny_llama (from fixture), no FakeBase + ("moonshotai/Kimi-K2.5", True), # remote HF repo, FakeBaseModel + ("moonshotai/Kimi-K2-Thinking", True), # remote HF repo, no FakeBaseModel + ("MiniMaxAI/MiniMax-M2.5", True), + ], + ids=["tinyllama", "kimi-k2.5","kimi-k2-thinking","minimax-m2.5"], +) +def test_offline_eagle3_training( + tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, + model_source, use_fake_base, +): + """Test Eagle3 training with pre-computed hidden states (offline mode / FakeBaseModel).""" + import transformers + + model_path = tiny_llama_path if model_source is None else model_source + model_id = "tinyllama" if model_source is None else "kimi-k2.5" + output_subdir = eagle_output_dir / f"eagle-{model_id}-offline" + + cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if model_source=="moonshotai/Kimi-K2.5": + #vlm, get text config + cfg = cfg.text_config + + offline_data_dir = generate_offline_pt_data( + tmp_path / "offline_data", + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + num_aux_layers=min(cfg.num_hidden_layers, 3), + ) + + tiny_eagle_config = { + "max_position_embeddings": 128, + "num_hidden_layers": 1, + "intermediate_size": 64, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 64, + } + config_file = tmp_path / "tiny_eagle_config_offline.json" + with open(config_file, "w") as f: + json.dump(tiny_eagle_config, f) + + cmd = [ + "./launch_train.sh", + "--model", model_path, + "--data", tiny_daring_anteater_path, + "--offline-data", offline_data_dir, + "--num_epochs", "0.1", + "--lr", "1e-5", + "--mode", "eagle3", + "--eagle_config", str(config_file), + "--output_dir", output_subdir, + "--training_seq_len", "64", + "--trust_remote_code", "True", + ] + if use_fake_base: + cmd += ["--use_fake_base_for_offline", "true"] + run_example_command(cmd, "speculative_decoding") + assert os.path.exists(output_subdir / "config.json") From d9b25e199203cd08fe98feafa6b7ebb986b72923 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:38:59 +0000 Subject: [PATCH 7/9] refactor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative/plugins/modeling_fakebase.py | 81 ++++++++++--------- modelopt/torch/speculative/utils.py | 4 +- .../speculative_decoding/test_eagle.py | 2 +- 3 files changed, 47 insertions(+), 40 deletions(-) diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index 5c31c80bf..0db11a455 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -22,6 +22,7 @@ import torch.nn as nn import transformers from huggingface_hub import hf_hub_download +from huggingface_hub.errors import EntryNotFoundError from safetensors.torch import load_file as safetensors_load_file from transformers import PretrainedConfig, PreTrainedModel @@ -138,47 +139,53 @@ def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: tried = [p + ".weight" for p in paths] raise RuntimeError(f"Cannot find {label} in checkpoint; tried: {tried}") - def _load_weights(self, source: str): - """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo. - - For remote repos the index file and the two required weight shards are downloaded via - ``huggingface_hub`` and cached locally; subsequent calls reuse the cache. - """ + @staticmethod + def _load_index(source: str) -> dict: + """Load weight_map from model.safetensors.index.json (local directory or Hub repo).""" if os.path.isdir(source): index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) if not os.path.isfile(index_path): - raise FileNotFoundError(f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}.") - with open(index_path) as f: - weight_map = json.load(f).get("weight_map", {}) - - lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") - embed_tokens_key = self._find_weight_key( - weight_map, _EMBED_TOKENS_PATHS, "embed_tokens" - ) - - lm_head_state = safetensors_load_file( - os.path.join(source, weight_map[lm_head_key]), device="cpu" - ) - embed_tokens_state = safetensors_load_file( - os.path.join(source, weight_map[embed_tokens_key]), device="cpu" - ) + raise FileNotFoundError( + f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) else: - # Treat source as a HuggingFace Hub repo ID - index_path = hf_hub_download(repo_id=source, filename=_SAFETENSORS_INDEX_FILENAME) - with open(index_path) as f: - weight_map = json.load(f).get("weight_map", {}) - - lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") - embed_tokens_key = self._find_weight_key( - weight_map, _EMBED_TOKENS_PATHS, "embed_tokens" - ) - - lm_head_shard = hf_hub_download(repo_id=source, filename=weight_map[lm_head_key]) - embed_tokens_shard = hf_hub_download( - repo_id=source, filename=weight_map[embed_tokens_key] - ) - lm_head_state = safetensors_load_file(lm_head_shard, device="cpu") - embed_tokens_state = safetensors_load_file(embed_tokens_shard, device="cpu") + try: + index_path = hf_hub_download(repo_id=source, filename=_SAFETENSORS_INDEX_FILENAME) + except EntryNotFoundError: + raise ValueError( + f"Repository {source!r} does not contain {_SAFETENSORS_INDEX_FILENAME}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) from None + with open(index_path) as f: + return json.load(f).get("weight_map", {}) + + @staticmethod + def _resolve_shard_paths(source: str, shard_filenames: list[str]) -> list[str]: + """Return local filesystem paths for each shard filename. + + For a local directory the paths are joined directly; for a HuggingFace Hub repo ID the + shards are downloaded via ``hf_hub_download`` (cached on subsequent calls). + """ + if os.path.isdir(source): + return [os.path.join(source, name) for name in shard_filenames] + return [hf_hub_download(repo_id=source, filename=name) for name in shard_filenames] + + def _load_weights(self, source: str): + """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo.""" + weight_map = self._load_index(source) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens") + + lm_head_path, embed_tokens_path = self._resolve_shard_paths( + source, [weight_map[lm_head_key], weight_map[embed_tokens_key]] + ) + + lm_head_state = safetensors_load_file(lm_head_path, device="cpu") + embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu") return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 655c25a68..f9ffd2487 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -448,7 +448,7 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): original_decoder_layer_forward = kimi_k2_module.DeepseekV3DecoderLayer.forward def patched_decoder_layer_fwd(self, *args, **kwargs): - kwargs["past_key_value"] = kwargs.get("past_key_values") + kwargs["past_key_value"] = kwargs.pop("past_key_values", None) return original_decoder_layer_forward(self, *args, **kwargs) kimi_k2_module.DeepseekV3DecoderLayer.forward = patched_decoder_layer_fwd @@ -486,7 +486,7 @@ def enable_cp_ttt_patch(): def load_vlm_or_llm( model_name_or_path: str, - use_fake_base: bool = True, + use_fake_base: bool = False, use_offline_training: bool = False, torch_dtype: str | torch.dtype | None = None, device_map: str | None = None, diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index a79719dc2..02c416fee 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -223,7 +223,7 @@ def test_offline_eagle3_training( import transformers model_path = tiny_llama_path if model_source is None else model_source - model_id = "tinyllama" if model_source is None else "kimi-k2.5" + model_id = "tinyllama" if model_source is None else model_source.split("/")[-1] output_subdir = eagle_output_dir / f"eagle-{model_id}-offline" cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True) From a023e6e5a5d4bd1aad4e54d5d0fd4372f0f00315 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 19 Mar 2026 06:58:05 +0000 Subject: [PATCH 8/9] ddp for offline Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/launch_train.sh | 7 ++++++- tests/examples/speculative_decoding/test_eagle.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 6045b40c8..94284db29 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -122,6 +122,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi TRUST_REMOTE_CODE="${1#*=}" ;; + --fsdp*) + if [[ "$1" != *=* ]]; then shift; fi + FSDP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -171,6 +175,7 @@ DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} +FSDP=${FSDP:-"False"} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -201,7 +206,7 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 ]]; then +if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" else diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 02c416fee..4a7b89457 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -263,6 +263,7 @@ def test_offline_eagle3_training( "--output_dir", output_subdir, "--training_seq_len", "64", "--trust_remote_code", "True", + "--fsdp", "False", ] if use_fake_base: cmd += ["--use_fake_base_for_offline", "true"] From 99946d5965304ac3d0fc8f9a727a714b103aecd7 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Thu, 19 Mar 2026 07:14:41 +0000 Subject: [PATCH 9/9] new unit tests Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative/plugins/test_fakebase.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/unit/torch/speculative/plugins/test_fakebase.py diff --git a/tests/unit/torch/speculative/plugins/test_fakebase.py b/tests/unit/torch/speculative/plugins/test_fakebase.py new file mode 100644 index 000000000..b19ce7b10 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_fakebase.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FakeBaseModel and the fake-base / offline paths in load_vlm_or_llm.""" + +import json + +import pytest +import safetensors.torch +import torch +import transformers + +from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel +from modelopt.torch.speculative.utils import load_vlm_or_llm + +_HIDDEN_SIZE = 16 +_VOCAB_SIZE = 32 + + +@pytest.fixture +def fake_config(monkeypatch): + """Monkeypatch AutoConfig.from_pretrained to return a minimal fake config.""" + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.hidden_size = _HIDDEN_SIZE + cfg.vocab_size = _VOCAB_SIZE + cfg.num_hidden_layers = 2 + cfg.max_position_embeddings = 128 + cfg.tie_word_embeddings = False + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + return cfg + + +@pytest.fixture +def fake_checkpoint(tmp_path, fake_config): + """Minimal local safetensors checkpoint loadable by FakeBaseModel.""" + tensors = { + "lm_head.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + "embed_tokens.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + } + shard = tmp_path / "model-00001-of-00001.safetensors" + safetensors.torch.save_file(tensors, shard) + index = {"weight_map": dict.fromkeys(tensors, shard.name)} + (tmp_path / "model.safetensors.index.json").write_text(json.dumps(index)) + return tmp_path + + +def test_fakebase_local_happy_path(fake_checkpoint): + model = FakeBaseModel(str(fake_checkpoint)) + assert model.lm_head.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + assert model.embed_tokens.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + + +def test_fakebase_missing_index_raises(tmp_path, fake_config): + with pytest.raises(FileNotFoundError, match="safetensors"): + FakeBaseModel(str(tmp_path)) + + +def test_load_vlm_or_llm_returns_fakebase(fake_checkpoint): + model = load_vlm_or_llm(str(fake_checkpoint), use_offline_training=True, use_fake_base=True) + assert isinstance(model, FakeBaseModel) + + +def test_load_vlm_or_llm_offline_zero_layers(monkeypatch): + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.num_hidden_layers = 4 + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + + captured_kwargs = {} + + class _FakeModel: + config = cfg + + def _fake_from_pretrained(*args, **kwargs): + captured_kwargs.update(kwargs) + return _FakeModel() + + monkeypatch.setattr(transformers.AutoModelForCausalLM, "from_pretrained", _fake_from_pretrained) + + model = load_vlm_or_llm("fake-model", use_offline_training=True, use_fake_base=False) + assert captured_kwargs.get("num_hidden_layers") == 0 + assert model.config.num_orig_hidden_layers == 4