diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8c96a19a7..7f28e68f7 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -154,7 +154,7 @@ def make_eagle_supervised_data_module( assert not data_args.vlm_processor, "Offline data is not supported for VLM." offline_data_path = Path(data_args.offline_data_path) - dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + dumped_files = [str(p) for p in offline_data_path.rglob("*.pt")] if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 074151c5a..94284db29 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -114,6 +114,18 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi MIX_HIDDEN_STATES="${1#*=}" ;; + --use_fake_base_for_offline*) + if [[ "$1" != *=* ]]; then shift; fi + USE_FAKE_BASE_FOR_OFFLINE="${1#*=}" + ;; + --trust_remote_code*) + if [[ "$1" != *=* ]]; then shift; fi + TRUST_REMOTE_CODE="${1#*=}" + ;; + --fsdp*) + if [[ "$1" != *=* ]]; then shift; fi + FSDP="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -126,9 +138,16 @@ set -x SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" NUM_NODES=${NUM_NODES:-1} -GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} -TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) -echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +if [[ "$NUM_NODES" != 1 ]]; then + #Multi Node Training + GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} + TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) + echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" +else + #Single Node Training, GPU can be 
specified by $CUDA_VISIBLE_DEVICES + TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (Single Node Training)" +fi # Calculate save_steps DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) @@ -154,7 +173,9 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} - +USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} +TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} +FSDP=${FSDP:-"False"} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -185,7 +206,7 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 ]]; then +if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then #Use FSDP2 when multi GPU available FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" else @@ -240,6 +261,8 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ --mix_hidden_states $MIX_HIDDEN_STATES \ + --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ + --trust_remote_code $TRUST_REMOTE_CODE \ $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 25817ee94..1093e577d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -47,10 +47,7 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.utils import ( - load_vlm_or_llm_with_kwargs, - patch_transformers5_params_loading, -) +from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 torch.manual_seed(0) @@ -60,6 +57,12 @@ @dataclass class ModelArguments: model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + 
use_fake_base_for_offline: bool = field( + default=False, metadata={"help": "Whether to use fake base for offline training."} + ) + trust_remote_code: bool = field( + default=False, metadata={"help": "Whether to trust remote code."} + ) @dataclass @@ -169,29 +172,27 @@ def train(): if checkpoint: with patch_transformers5_params_loading(): - _, model = load_vlm_or_llm_with_kwargs( - checkpoint, torch_dtype="auto", trust_remote_code=True + model = load_vlm_or_llm( + checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code ) - tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) + tokenizer = transformers.AutoTokenizer.from_pretrained( + checkpoint, trust_remote_code=model_args.trust_remote_code + ) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). - offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model_config, model = load_vlm_or_llm_with_kwargs( + model = load_vlm_or_llm( model_args.model_name_or_path, + use_fake_base=model_args.use_fake_base_for_offline, + use_offline_training=use_offline_training, torch_dtype="auto", device_map="cpu", - trust_remote_code=True, - **offline_kwargs, + trust_remote_code=model_args.trust_remote_code, ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings - model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, - trust_remote_code=True, + trust_remote_code=model_args.trust_remote_code, ) if training_args.mode == "medusa": config = { diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index d5c37a895..d1bf31a1a 100644 --- 
a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -22,7 +22,7 @@ import modelopt.torch.opt as mto from modelopt.torch.speculative.plugins.transformers import HFARValidation -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm mto.enable_huggingface_checkpointing() @@ -72,7 +72,7 @@ def main(): accelerator = Accelerator() # Load model and tokenizer - _, model = load_vlm_or_llm_with_kwargs(args.model_path, device_map="auto") + model = load_vlm_or_llm(args.model_path, device_map="auto") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model.eval() model = accelerator.prepare(model) diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py index 23a7560f7..925f4b73d 100644 --- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py +++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py @@ -21,7 +21,7 @@ import modelopt.torch.opt as mto from modelopt.torch.export import export_speculative_decoding -from modelopt.torch.speculative.utils import load_vlm_or_llm_with_kwargs +from modelopt.torch.speculative.utils import load_vlm_or_llm def parse_args(): @@ -38,7 +38,7 @@ def parse_args(): mto.enable_huggingface_checkpointing() args = parse_args() -_, model = load_vlm_or_llm_with_kwargs(args.model_path, torch_dtype="auto") +model = load_vlm_or_llm(args.model_path, torch_dtype="auto") model.eval() with torch.inference_mode(): export_speculative_decoding( diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py new file mode 100644 index 000000000..0db11a455 --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lightweight fake base model for offline speculative decoding training.""" + +import json +import os + +import torch +import torch.nn as nn +import transformers +from huggingface_hub import hf_hub_download +from huggingface_hub.errors import EntryNotFoundError +from safetensors.torch import load_file as safetensors_load_file +from transformers import PretrainedConfig, PreTrainedModel + +# Candidate module paths searched in order — shared with HFEagleModel._find_base_model_parts +_EMBED_TOKENS_PATHS = [ + "embed_tokens", + "language_model.model.embed_tokens", + "model.embed_tokens", + "backbone.embeddings", + "language_model.backbone.embeddings", + "model.language_model.embed_tokens", +] +_LM_HEAD_PATHS = ["lm_head", "language_model.lm_head"] +_BASE_MODEL_PATHS = [ + "language_model.model", + "model.language_model", + "model", + "backbone", + "language_model.backbone", +] +_VLM_CONFIG_ATTRS = ["text_config", "llm_config"] +_SAFETENSORS_INDEX_FILENAME = "model.safetensors.index.json" + + +class FakeBaseConfig(PretrainedConfig): + """Minimal config for FakeBaseModel that supports offline speculative decoding training.""" + + model_type = "fake_base_model" + + def __init__( + self, + num_hidden_layers=None, + hidden_size=None, + vocab_size=None, + max_position_embeddings=None, + dtype=torch.bfloat16, + tie_word_embeddings=False, + **kwargs, + ): + """Initialize FakeBaseConfig with 
minimal model configuration parameters.""" + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype + + +class FakeBaseModel(PreTrainedModel): + """Minimal base model for offline speculative decoding. + + Contains only ``lm_head``, ``embed_tokens``, and the minimal config needed by the EAGLE + training loop. The full model weights are never loaded, keeping memory usage low. + + Weights are loaded from a local HuggingFace checkpoint directory. Weight key names and + VLM config nesting are auto-detected from the shared path constants. + """ + + config_class = FakeBaseConfig + + def __init__(self, source: str, trust_remote_code: bool = False): + """Load lm_head and embed_tokens from a local directory or HuggingFace Hub repo. + + Args: + source: Path to a local HuggingFace checkpoint directory, or a HuggingFace Hub + repo ID (e.g. ``"meta-llama/Llama-3.1-8B"``). The source type is detected + automatically: if ``source`` is an existing local directory it is treated as a + local checkpoint; otherwise it is treated as a Hub repo ID and the required + files are downloaded via ``huggingface_hub``. 
+ """ + orig_config = transformers.AutoConfig.from_pretrained( + source, trust_remote_code=trust_remote_code + ) + # For vlms, detect language model config based on _VLM_CONFIG_ATTRS + base_cfg = next( + ( + getattr(orig_config, attr) + for attr in _VLM_CONFIG_ATTRS + if getattr(orig_config, attr, None) is not None + ), + orig_config, + ) + # Extract necessary info for spec training from base config + config = FakeBaseConfig( + num_hidden_layers=getattr(base_cfg, "num_hidden_layers", None), + hidden_size=getattr(base_cfg, "hidden_size", None), + vocab_size=getattr(base_cfg, "vocab_size", None), + max_position_embeddings=getattr(base_cfg, "max_position_embeddings", None), + dtype=getattr(base_cfg, "dtype", torch.bfloat16), + tie_word_embeddings=getattr(base_cfg, "tie_word_embeddings", False), + ) + super().__init__(config) + # Initialize dummy module and attributes for compatibility with HFEagleModel + self.model = nn.Module() + self.model.layers = nn.ModuleList() + self.model.dtype = config.dtype + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Load lm_head and embed_tokens only from checkpoint + lm_head_w, embed_tokens_w = self._load_weights(source) + assert lm_head_w.shape == (config.vocab_size, config.hidden_size) + assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size) + self.lm_head.weight.data.copy_(lm_head_w) + self.embed_tokens.weight.data.copy_(embed_tokens_w) + + @staticmethod + def _find_weight_key(weight_map: dict, paths: list[str], label: str) -> str: + """Return the first ``path + '.weight'`` found in ``weight_map``.""" + for path in paths: + key = path + ".weight" + if key in weight_map: + return key + tried = [p + ".weight" for p in paths] + raise RuntimeError(f"Cannot find {label} in checkpoint; tried: {tried}") + + @staticmethod + def _load_index(source: str) -> dict: + """Load weight_map from model.safetensors.index.json 
(local directory or Hub repo).""" + if os.path.isdir(source): + index_path = os.path.join(source, _SAFETENSORS_INDEX_FILENAME) + if not os.path.isfile(index_path): + raise FileNotFoundError( + f"No {_SAFETENSORS_INDEX_FILENAME} found in {source!r}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) + else: + try: + index_path = hf_hub_download(repo_id=source, filename=_SAFETENSORS_INDEX_FILENAME) + except EntryNotFoundError: + raise ValueError( + f"Repository {source!r} does not contain {_SAFETENSORS_INDEX_FILENAME}. " + "FakeBaseModel only supports safetensors checkpoints. " + "Checkpoints using pytorch_model.bin or single-file formats are not supported." + ) from None + with open(index_path) as f: + return json.load(f).get("weight_map", {}) + + @staticmethod + def _resolve_shard_paths(source: str, shard_filenames: list[str]) -> list[str]: + """Return local filesystem paths for each shard filename. + + For a local directory the paths are joined directly; for a HuggingFace Hub repo ID the + shards are downloaded via ``hf_hub_download`` (cached on subsequent calls). 
+ """ + if os.path.isdir(source): + return [os.path.join(source, name) for name in shard_filenames] + return [hf_hub_download(repo_id=source, filename=name) for name in shard_filenames] + + def _load_weights(self, source: str): + """Load lm_head and embed_tokens weights from a local directory or HuggingFace Hub repo.""" + weight_map = self._load_index(source) + + lm_head_key = self._find_weight_key(weight_map, _LM_HEAD_PATHS, "lm_head") + embed_tokens_key = self._find_weight_key(weight_map, _EMBED_TOKENS_PATHS, "embed_tokens") + + lm_head_path, embed_tokens_path = self._resolve_shard_paths( + source, [weight_map[lm_head_key], weight_map[embed_tokens_key]] + ) + + lm_head_state = safetensors_load_file(lm_head_path, device="cpu") + embed_tokens_state = safetensors_load_file(embed_tokens_path, device="cpu") + + return lm_head_state[lm_head_key], embed_tokens_state[embed_tokens_key] + + def forward(self, *args, **kwargs): + """Not implemented: FakeBaseModel omits full model weights and cannot run inference.""" + raise NotImplementedError("FakeBaseModel forward is not implemented.") diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 1b85c342e..ce2fc28db 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -48,7 +48,6 @@ ) from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput -from transformers.utils.quantization_config import CompressedTensorsConfig from ...export.plugins.hf_spec_export import ( EagleExporter, @@ -68,6 +67,7 @@ get_ttt_msk_func, temporary_set_config_value, ) +from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS __all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] @@ -464,19 +464,9 @@ def get_exporter(self) -> SpeculativeDecodingExporter: def _find_base_model_parts(self): """Find model parts from different models and set 
base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": [ - "model.language_model", - "model", - "backbone", - "language_model.backbone", - ], - "base_model_embeddings_path": [ - "model.embed_tokens", - "backbone.embeddings", - "language_model.backbone.embeddings", - "model.language_model.embed_tokens", - ], - "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + "base_model_path": _BASE_MODEL_PATHS, + "base_model_embeddings_path": _EMBED_TOKENS_PATHS, + "base_model_lm_head_path": _LM_HEAD_PATHS, } for name, paths in base_model_parts_mapping.items(): @@ -577,11 +567,6 @@ def modify( if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" - # Patch for Kimi-K2-Thinking, avoid quantizing drafter - quant_config = getattr(self.config, "quantization_config", None) - if isinstance(quant_config, CompressedTensorsConfig): - quant_config.ignore.append("re:.*eagle_module.*") - # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state @@ -591,7 +576,7 @@ def modify( # Freeze all parameters if self.eagle_freeze_base_model: - for name, param in self.named_parameters(): + for _, param in self.named_parameters(): param.requires_grad = False self.eagle_module = EagleModule( @@ -742,8 +727,6 @@ def _compute_ttt_attention_mask( tensor_mask, 0, dtype=self._base_llm_config.dtype, device=self.device ).masked_fill(~tensor_mask, dtypemin) - # Note: (hg) repeat mask for kimi-k2 compatibility - tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask def _base_model_forward( @@ -869,7 +852,7 @@ def forward( assert "base_model_outputs" in kwargs base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) if base_outputs.logits is None: - base_outputs.logits = self.lm_head(base_outputs.out_hiddens) + base_outputs.logits = self._base_model_lm_head(base_outputs.out_hiddens) past_key_values = None else: base_outputs, past_key_values = 
self._base_model_forward( diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 72c5b5dbc..f9ffd2487 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -443,6 +443,16 @@ def patched_fwd_with_lazy_rope_init(self, *args, **kwargs): kimi_k2_module.DeepseekV3Attention._init_rope = lambda self: None kimi_k2_module.DeepseekV3Attention.forward = patched_fwd_with_lazy_rope_init + # Kimi implementation is based on older transformers which use "past_key_value" argument + # We patch it to "past_key_values" for compatibility + original_decoder_layer_forward = kimi_k2_module.DeepseekV3DecoderLayer.forward + + def patched_decoder_layer_fwd(self, *args, **kwargs): + kwargs["past_key_value"] = kwargs.pop("past_key_values", None) + return original_decoder_layer_forward(self, *args, **kwargs) + + kimi_k2_module.DeepseekV3DecoderLayer.forward = patched_decoder_layer_fwd + return getattr(kimi_k2_module, "DeepseekV3DecoderLayer") @@ -474,21 +484,60 @@ def enable_cp_ttt_patch(): modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False -def load_vlm_or_llm_with_kwargs(model_name_or_path: str, **kwargs): - """Load a VLM or LLM with kwargs. Returns the model and model config.""" +def load_vlm_or_llm( + model_name_or_path: str, + use_fake_base: bool = False, + use_offline_training: bool = False, + torch_dtype: str | torch.dtype | None = None, + device_map: str | None = None, + trust_remote_code: bool = False, +): + """Load a VLM or LLM. Returns the model. + + When ``use_offline_training=True``, returns a + :class:`~modelopt.torch.speculative.plugins.modeling_fakebase.FakeBaseModel` containing only + ``lm_head`` and ``embed_tokens``, auto-detecting weight paths from the checkpoint. + Otherwise, falls back to loading with ``num_hidden_layers=0`` for memory efficiency. + + Args: + model_name_or_path: Local path or HuggingFace repo ID of the model. 
+ use_offline_training: Whether to load a memory-efficient model for offline training. + torch_dtype: dtype to use when loading the model. + device_map: Device map passed to ``from_pretrained``. + trust_remote_code: Whether to trust remote code. + """ + if use_offline_training and use_fake_base: + from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel + + return FakeBaseModel(model_name_or_path, trust_remote_code=trust_remote_code) + model_config = transformers.AutoConfig.from_pretrained( - model_name_or_path, trust_remote_code=True + model_name_or_path, trust_remote_code=trust_remote_code ) if "vl" in model_config.model_type.lower(): model_cls = transformers.AutoModelForVision2Seq else: model_cls = transformers.AutoModelForCausalLM - if kwargs.get("num_hidden_layers") == 0: + extra = {} + if use_offline_training: + extra["num_hidden_layers"] = 0 if hasattr(model_config, "layer_types"): - kwargs["layer_types"] = [] + extra["layer_types"] = [] + + model = model_cls.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, + device_map=device_map, + **extra, + ) + + if use_offline_training: + # Preserve the original layer count since we loaded with num_hidden_layers=0 + model.config.num_orig_hidden_layers = model_config.num_hidden_layers - return model_config, model_cls.from_pretrained(model_name_or_path, **kwargs) + return model @contextlib.contextmanager diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index a3542fa25..4a7b89457 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -15,6 +15,7 @@ import json import os +from pathlib import Path import pytest import safetensors.torch @@ -25,6 +26,44 @@ from modelopt.torch.export.plugins.hf_spec_export import LLAMA_EAGLE_SINGLE_LAYER +def generate_offline_pt_data( + output_dir, + num_files: int = 8, + seq_len: int = 
128,
+    hidden_size: int = 512,
+    vocab_size: int = 32000,
+    num_aux_layers: int = 2,
+) -> Path:
+    """Generate fake offline training .pt files for EAGLE3 offline training tests.
+
+    Each file contains the keys expected by OfflineSupervisedDataset:
+    - input_ids: LongTensor of shape (seq_len,)
+    - hidden_states: FloatTensor of shape (seq_len, hidden_size)
+    - aux_hidden_states: FloatTensor of shape (seq_len, hidden_size*num_aux_layers)
+
+    Args:
+        output_dir: Directory to write .pt files into.
+        num_files: Number of .pt files to generate.
+        seq_len: Sequence length. Defaults to 128.
+        hidden_size: Hidden size matching the base model. Defaults to 512 (tiny_llama).
+        vocab_size: Vocabulary size matching the base model. Defaults to 32000 (tiny_llama).
+        num_aux_layers: Number of auxiliary layers. Defaults to 2.
+    Returns:
+        Path to the output directory.
+    """
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    torch.manual_seed(42)
+    for i in range(num_files):
+        sample = {
+            "input_ids": torch.randint(0, vocab_size, (seq_len,)),
+            "hidden_states": torch.randn(seq_len, hidden_size),
+            "aux_hidden_states": torch.randn(seq_len, hidden_size * num_aux_layers),
+        }
+        torch.save(sample, output_dir / f"sample_{i:04d}.pt")
+    return output_dir
+
+
 @pytest.fixture(scope="module")
 def eagle_output_dir(tmp_path_factory):
     """Eagle output directory shared in this module."""
@@ -164,3 +203,69 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir):
         ],
         "speculative_decoding",
     )
+
+
+@pytest.mark.parametrize(
+    ("model_source", "use_fake_base"),
+    [
+        (None, False),  # tiny_llama (from fixture), no FakeBase
+        ("moonshotai/Kimi-K2.5", True),  # remote HF repo, FakeBaseModel
+        ("moonshotai/Kimi-K2-Thinking", True),  # remote HF repo, FakeBaseModel
+        ("MiniMaxAI/MiniMax-M2.5", True),
+    ],
+    ids=["tinyllama", "kimi-k2.5", "kimi-k2-thinking", "minimax-m2.5"],
+)
+def test_offline_eagle3_training(
+    tiny_llama_path, tiny_daring_anteater_path, tmp_path, 
eagle_output_dir,
+    model_source, use_fake_base,
+):
+    """Test Eagle3 training with pre-computed hidden states (offline mode / FakeBaseModel)."""
+    import transformers
+
+    model_path = tiny_llama_path if model_source is None else model_source
+    model_id = "tinyllama" if model_source is None else model_source.split("/")[-1]
+    output_subdir = eagle_output_dir / f"eagle-{model_id}-offline"
+
+    cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+
+    if model_source == "moonshotai/Kimi-K2.5":
+        # VLM checkpoint: use the nested text config for hidden/vocab sizes
+        cfg = cfg.text_config
+
+    offline_data_dir = generate_offline_pt_data(
+        tmp_path / "offline_data",
+        hidden_size=cfg.hidden_size,
+        vocab_size=cfg.vocab_size,
+        num_aux_layers=min(cfg.num_hidden_layers, 3),
+    )
+
+    tiny_eagle_config = {
+        "max_position_embeddings": 128,
+        "num_hidden_layers": 1,
+        "intermediate_size": 64,
+        "num_attention_heads": 2,
+        "num_key_value_heads": 2,
+        "head_dim": 64,
+    }
+    config_file = tmp_path / "tiny_eagle_config_offline.json"
+    with open(config_file, "w") as f:
+        json.dump(tiny_eagle_config, f)
+
+    cmd = [
+        "./launch_train.sh",
+        "--model", model_path,
+        "--data", tiny_daring_anteater_path,
+        "--offline-data", offline_data_dir,
+        "--num_epochs", "0.1",
+        "--lr", "1e-5",
+        "--mode", "eagle3",
+        "--eagle_config", str(config_file),
+        "--output_dir", output_subdir,
+        "--training_seq_len", "64",
+        "--trust_remote_code", "True",
+        "--fsdp", "False",
+    ]
+    if use_fake_base:
+        cmd += ["--use_fake_base_for_offline", "true"]
+    run_example_command(cmd, "speculative_decoding")
+    assert os.path.exists(output_subdir / "config.json")
diff --git a/tests/unit/torch/speculative/plugins/test_fakebase.py b/tests/unit/torch/speculative/plugins/test_fakebase.py
new file mode 100644
index 000000000..b19ce7b10
--- /dev/null
+++ b/tests/unit/torch/speculative/plugins/test_fakebase.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FakeBaseModel and the fake-base / offline paths in load_vlm_or_llm.""" + +import json + +import pytest +import safetensors.torch +import torch +import transformers + +from modelopt.torch.speculative.plugins.modeling_fakebase import FakeBaseModel +from modelopt.torch.speculative.utils import load_vlm_or_llm + +_HIDDEN_SIZE = 16 +_VOCAB_SIZE = 32 + + +@pytest.fixture +def fake_config(monkeypatch): + """Monkeypatch AutoConfig.from_pretrained to return a minimal fake config.""" + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.hidden_size = _HIDDEN_SIZE + cfg.vocab_size = _VOCAB_SIZE + cfg.num_hidden_layers = 2 + cfg.max_position_embeddings = 128 + cfg.tie_word_embeddings = False + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + return cfg + + +@pytest.fixture +def fake_checkpoint(tmp_path, fake_config): + """Minimal local safetensors checkpoint loadable by FakeBaseModel.""" + tensors = { + "lm_head.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + "embed_tokens.weight": torch.zeros(_VOCAB_SIZE, _HIDDEN_SIZE), + } + shard = tmp_path / "model-00001-of-00001.safetensors" + safetensors.torch.save_file(tensors, shard) + index = {"weight_map": dict.fromkeys(tensors, shard.name)} + (tmp_path / "model.safetensors.index.json").write_text(json.dumps(index)) + return tmp_path + + +def 
test_fakebase_local_happy_path(fake_checkpoint): + model = FakeBaseModel(str(fake_checkpoint)) + assert model.lm_head.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + assert model.embed_tokens.weight.shape == torch.Size([_VOCAB_SIZE, _HIDDEN_SIZE]) + + +def test_fakebase_missing_index_raises(tmp_path, fake_config): + with pytest.raises(FileNotFoundError, match="safetensors"): + FakeBaseModel(str(tmp_path)) + + +def test_load_vlm_or_llm_returns_fakebase(fake_checkpoint): + model = load_vlm_or_llm(str(fake_checkpoint), use_offline_training=True, use_fake_base=True) + assert isinstance(model, FakeBaseModel) + + +def test_load_vlm_or_llm_offline_zero_layers(monkeypatch): + cfg = transformers.PretrainedConfig() + cfg.model_type = "llama" + cfg.num_hidden_layers = 4 + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", lambda *a, **kw: cfg) + + captured_kwargs = {} + + class _FakeModel: + config = cfg + + def _fake_from_pretrained(*args, **kwargs): + captured_kwargs.update(kwargs) + return _FakeModel() + + monkeypatch.setattr(transformers.AutoModelForCausalLM, "from_pretrained", _fake_from_pretrained) + + model = load_vlm_or_llm("fake-model", use_offline_training=True, use_fake_base=False) + assert captured_kwargs.get("num_hidden_layers") == 0 + assert model.config.num_orig_hidden_layers == 4