
[Feat] FakeBaseModel for offline eagle; Kimi-K2.5 fixes; #1052

Open
h-guo18 wants to merge 9 commits into main from haoguo/fakebasemodel

Conversation


@h-guo18 h-guo18 commented Mar 17, 2026

What does this PR do?

Adds FakeBaseModel for offline EAGLE training and several Kimi-K2.5 compatibility fixes.

  • New: FakeBaseModel — lightweight model that loads only lm_head and embed_tokens from a local checkpoint, avoiding full model weight loading during offline training. Configured via FakeBaseArguments and integrated into load_vlm_or_llm.
  • Fix: _find_base_model_parts — support Kimi-K2.5 VLM layout (language_model.model path)
  • Fix: offline mode lm_head access and CompressedTensors ignore path
  • Fix: Kimi-K2.5 decoder past_key_value/past_key_values argument mismatch
  • Fix: rglob for .pt discovery in nested offline data dirs; single-node GPU count respects CUDA_VISIBLE_DEVICES
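
The FakeBaseModel idea can be sketched in a few lines: keep only the embedding and LM head, load just those two tensors, and leave the forward pass unimplemented. This is an illustrative stand-in, not the PR's actual class; the real implementation builds its config from the base checkpoint and loads from safetensors shards.

```python
import torch
import torch.nn as nn

class FakeBaseModelSketch(nn.Module):
    """Illustrative stand-in: only the embedding and LM head of a base model.

    Names and structure here are assumptions for illustration; the real
    FakeBaseModel in this PR reads these weights from safetensors shards.
    """

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def load_minimal_weights(self, state_dict: dict) -> None:
        # Copy only the two tensors offline EAGLE training needs.
        self.embed_tokens.weight.data.copy_(state_dict["embed_tokens.weight"])
        self.lm_head.weight.data.copy_(state_dict["lm_head.weight"])

    def forward(self, *args, **kwargs):
        # The full forward pass is intentionally unimplemented (per the PR).
        raise NotImplementedError("FakeBaseModel has no decoder layers")

# Usage: simulate a checkpoint state dict and load just the two tensors.
ckpt = {
    "embed_tokens.weight": torch.randn(32, 8),
    "lm_head.weight": torch.randn(32, 8),
}
model = FakeBaseModelSketch(vocab_size=32, hidden_size=8)
model.load_minimal_weights(ckpt)
assert torch.equal(model.lm_head.weight.data, ckpt["lm_head.weight"])
```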

Type of change: Bug fix, new feature

Testing

Tested offline EAGLE training for Kimi-K2.5 end-to-end.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ❌

Additional Information

Summary by CodeRabbit

  • New Features

    • Added lightweight fake-base model support for offline training
    • Introduced --use_fake_base_for_offline and --trust_remote_code CLI flags
  • Improvements

    • Enhanced GPU detection and single-node logging for training runs
    • Expanded offline data discovery to include nested subdirectories
    • Updated model/tokenizer loading to respect remote-code trust and offline flags; tooling updated to the new loader
  • Bug Fixes

    • Improved compatibility with legacy transformer call signatures
  • Tests

    • Added tests for fake-base loading and offline training workflows

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

copy-pr-bot bot commented Mar 17, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 17, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fb58196d-bcc1-457e-9d59-b13039a1c6e2

📥 Commits

Reviewing files that changed from the base of the PR and between a023e6e and 99946d5.

📒 Files selected for processing (1)
  • tests/unit/torch/speculative/plugins/test_fakebase.py

📝 Walkthrough

Walkthrough

Adds a lightweight FakeBaseModel and loader API, replaces load_vlm_or_llm_with_kwargs with load_vlm_or_llm (explicit flags), threads use_fake_base_for_offline and trust_remote_code through scripts, updates offline .pt discovery, and adds tests and utilities to support offline/fake-base speculative decoding workflows.
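
A rough sketch of the dispatch the new explicit-flag API implies (the returned strings stand in for real model objects; the actual internals live in modelopt/torch/speculative/utils.py and are not shown on this page):

```python
# Hypothetical sketch of the loader's three code paths; parameter names
# follow the summary above, everything else is an assumption.
def load_vlm_or_llm_sketch(
    model_name_or_path: str,
    use_fake_base: bool = True,
    use_offline_training: bool = False,
    trust_remote_code: bool = False,  # defaults to False per SECURITY.md
):
    if use_offline_training and use_fake_base:
        # Lightweight path: only embed_tokens + lm_head are materialized.
        return f"FakeBaseModel({model_name_or_path})"
    if use_offline_training:
        # Stripped model: num_hidden_layers forced to 0 before from_pretrained.
        return f"StrippedModel({model_name_or_path}, layers=0)"
    # Regular path: full load honoring trust_remote_code.
    return f"FullModel({model_name_or_path}, trust_remote_code={trust_remote_code})"

assert load_vlm_or_llm_sketch("m", use_offline_training=True) == "FakeBaseModel(m)"
```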

Changes

  • API Callsites (examples/speculative_decoding/main.py, examples/speculative_decoding/scripts/ar_validate.py, examples/speculative_decoding/scripts/export_hf_checkpoint.py):
    Replaced imports/calls of load_vlm_or_llm_with_kwargs with load_vlm_or_llm; adjusted call signatures and removed unpacking of a (config, model) tuple (now returns only model); propagate trust_remote_code and new flags.
  • Launch & Utilities (examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/eagle_utils.py):
    Added CLI flags --use_fake_base_for_offline and --trust_remote_code; improved TOTAL_GPU calculation for single vs multi-node; conditionally enable FSDP only when requested; changed .pt discovery to recursive rglob("*.pt").
  • Model Loading Core (modelopt/torch/speculative/utils.py):
    Removed load_vlm_or_llm_with_kwargs, added load_vlm_or_llm with explicit params (use_fake_base, use_offline_training, torch_dtype, device_map, trust_remote_code); returns only model; added DeepseekV3DecoderLayer forward compatibility patch; offline-load paths handle zero-layer fake/stripped models and record original layer counts.
  • Fake Base Plugin (modelopt/torch/speculative/plugins/modeling_fakebase.py):
    New file: FakeBaseConfig and FakeBaseModel that construct a minimal config and load embed_tokens / lm_head weights from local safetensors or HF Hub index/shards; forward is unimplemented.
  • Transformers Plugin Updates (modelopt/torch/speculative/plugins/transformers.py):
    Centralized base/embedding/lm-head path constants, removed prior quantization patch, switched offline logits path to _base_model_lm_head, eliminated duplicated attention-mask code; variable renaming only.
  • Tests / Offline Data (tests/examples/speculative_decoding/test_eagle.py, tests/unit/torch/speculative/plugins/test_fakebase.py):
    Added generate_offline_pt_data helper and test_offline_eagle3_training that creates .pt samples and runs launch_train.sh with offline data; added unit tests for FakeBaseModel loading, safetensors index handling, and load_vlm_or_llm offline/fake-base behaviors.
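
The two small launch-script fixes can be illustrated directly. This is a Python sketch of the intended behavior (the real changes are in launch_train.sh and eagle_utils.py; the fallback of 8 GPUs is an assumption for illustration):

```python
import os
import tempfile
from pathlib import Path

def discover_offline_pt(data_dir: str) -> list:
    # Recursive discovery so nested offline data dirs are picked up;
    # a non-recursive glob("*.pt") only sees the top level.
    return sorted(Path(data_dir).rglob("*.pt"))

def single_node_gpu_count() -> int:
    # Respect CUDA_VISIBLE_DEVICES when set, instead of assuming all
    # GPUs on the node are available.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible:
        return len([d for d in visible.split(",") if d.strip()])
    return 8  # assumed per-node default for this sketch

# Usage: nested .pt files are found alongside top-level ones.
with tempfile.TemporaryDirectory() as d:
    (Path(d) / "shard_0").mkdir()
    (Path(d) / "shard_0" / "sample_0.pt").touch()
    (Path(d) / "sample_1.pt").touch()
    assert len(discover_offline_pt(d)) == 2

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
assert single_node_gpu_count() == 3
```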

Sequence Diagram(s)

sequenceDiagram
  participant Launcher as launch_train.sh
  participant Example as main.py / scripts
  participant Loader as load_vlm_or_llm
  participant FakeBase as FakeBaseModel
  participant HFHub as HuggingFace Hub / Filesystem

  Launcher->>Example: start training (flags: use_fake_base_for_offline, trust_remote_code, offline-data)
  Example->>Loader: load_vlm_or_llm(model_path, use_fake_base=..., use_offline_training=..., trust_remote_code=...)
  alt use_fake_base && use_offline_training
    Loader->>FakeBase: instantiate FakeBaseModel(source, trust_remote_code)
    FakeBase->>HFHub: fetch model.safetensors.index.json (local or hub)
    FakeBase->>HFHub: download shard files
    HFHub-->>FakeBase: return weights
    FakeBase-->>Loader: return FakeBaseModel
  else regular path
    Loader->>HFHub: AutoConfig.from_pretrained(..., trust_remote_code)
    Loader->>HFHub: from_pretrained with modified config (num_hidden_layers=0 for offline)
    HFHub-->>Loader: model artifacts
    Loader-->>Example: return model
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

  • Security Anti-Patterns: ❌ Error. Hardcoded trust_remote_code=True in tests/examples/speculative_decoding/test_eagle.py:229 without justification violates SECURITY.md instruction #3. Resolution: make trust_remote_code configurable with a False default, or add an inline comment explaining why the test requires trusting remote code.
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately captures the two main features introduced: FakeBaseModel for offline EAGLE training and Kimi-K2.5 fixes, which are the primary objectives of the PR.

Important

Merge conflicts detected (Beta)

  • Resolve merge conflict in branch haoguo/fakebasemodel

@h-guo18 h-guo18 changed the title from "Add FakeBaseModel for offline speculative decoding and Kimi-K2.5 fixes" to "[Feat] FakeBaseModel for offline eagle; Kimi-K2.5 fixes;" on Mar 17, 2026
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

codecov bot commented Mar 17, 2026

Codecov Report

❌ Patch coverage is 5.88235% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.28%. Comparing base (7b34de6) to head (0df75e2).
⚠️ Report is 8 commits behind head on main.

Files with missing lines:
  • modelopt/torch/speculative/utils.py: patch 5.88%, 16 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1052      +/-   ##
==========================================
+ Coverage   70.09%   70.28%   +0.19%     
==========================================
  Files         221      227       +6     
  Lines       25541    25873     +332     
==========================================
+ Hits        17902    18185     +283     
- Misses       7639     7688      +49     

☔ View full report in Codecov by Sentry.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 self-assigned this Mar 17, 2026
h-guo18 added 3 commits March 17, 2026 08:29
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 marked this pull request as ready for review March 18, 2026 17:45
@h-guo18 h-guo18 requested a review from a team as a code owner March 18, 2026 17:45
@h-guo18 h-guo18 requested a review from yeyu-nvidia March 18, 2026 17:45

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
modelopt/torch/speculative/utils.py (1)

487-508: LGTM: trust_remote_code properly parameterized in load_vlm_or_llm.

The function correctly exposes trust_remote_code as a caller-configurable parameter defaulting to False, complying with SECURITY.md guidelines.

Minor: The docstring is missing the use_fake_base parameter description.

📝 Proposed docstring fix
     Args:
         model_name_or_path: Local path or HuggingFace repo ID of the model.
+        use_fake_base: Whether to use FakeBaseModel for offline training (default True).
         use_offline_training: Whether to load a memory-efficient model for offline training.
         torch_dtype: dtype to use when loading the model.
         device_map: Device map passed to ``from_pretrained``.
         trust_remote_code: Whether to trust remote code.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/utils.py` around lines 487 - 508, The docstring
for load_vlm_or_llm is missing a description for the use_fake_base parameter;
update the Args section to add a one-line explanation for use_fake_base (what it
toggles and its default behavior), e.g., indicate that use_fake_base controls
whether a FakeBaseModel is used when loading (default True) and how it interacts
with use_offline_training, so readers can understand its purpose alongside
model_name_or_path, use_offline_training, torch_dtype, device_map, and
trust_remote_code.
modelopt/torch/speculative/plugins/modeling_fakebase.py (1)

126-129: Consider using explicit exceptions instead of assert for shape validation.

Assertions can be disabled with -O (optimized mode). For production code that validates external checkpoint data, explicit ValueError or RuntimeError would be more robust.

♻️ Proposed refactor
-        assert lm_head_w.shape == (config.vocab_size, config.hidden_size)
-        assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size)
+        if lm_head_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"lm_head shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {lm_head_w.shape}"
+            )
+        if embed_tokens_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"embed_tokens shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {embed_tokens_w.shape}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 126 -
129, Replace the two assert checks with explicit runtime validation that raises
a clear exception (e.g., ValueError) when shapes mismatch: check that
lm_head_w.shape == (config.vocab_size, config.hidden_size) and
embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not raise
an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 225-227: The code sets model_id to the hardcoded string
"kimi-k2.5" when model_source is provided, causing remote-model runs to share
the same eagle_output_dir; change the assignment for model_id to derive a unique
identifier from model_source (e.g., use the repository/name suffix:
model_source.split("/")[-1] or a sanitized version of that string) so model_id
is unique per model_source and output_subdir = eagle_output_dir /
f"eagle-{model_id}-offline" writes to distinct directories; update the model_id
assignment near the model_path/model_id definitions used in the test file
(model_path, model_id, eagle_output_dir).

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 126-129: Replace the two assert checks with explicit runtime
validation that raises a clear exception (e.g., ValueError) when shapes
mismatch: check that lm_head_w.shape == (config.vocab_size, config.hidden_size)
and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not
raise an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.

In `@modelopt/torch/speculative/utils.py`:
- Around line 487-508: The docstring for load_vlm_or_llm is missing a
description for the use_fake_base parameter; update the Args section to add a
one-line explanation for use_fake_base (what it toggles and its default
behavior), e.g., indicate that use_fake_base controls whether a FakeBaseModel is
used when loading (default True) and how it interacts with use_offline_training,
so readers can understand its purpose alongside model_name_or_path,
use_offline_training, torch_dtype, device_map, and trust_remote_code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d0e27af8-82ec-41a0-a296-d037db3dbd86

📥 Commits

Reviewing files that changed from the base of the PR and between 7b34de6 and 0df75e2.

📒 Files selected for processing (9)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py


@ChenhanYu ChenhanYu left a comment


This PR introduces FakeBaseModel for memory-efficient offline EAGLE training (loads only lm_head + embed_tokens), refactors load_vlm_or_llm_with_kwargs → load_vlm_or_llm with a cleaner API, consolidates model path constants, and adds several Kimi-K2.5 compatibility fixes.

Note: load_vlm_or_llm_with_kwargs → load_vlm_or_llm is a breaking API change — any downstream callers passing extra kwargs will break.

Test coverage is insufficient. The PR adds one integration test that exercises the happy path, but:

  • FakeBaseModel has no unit tests (weight loading, path auto-detection, error cases like missing safetensors index or unrecognized weight keys)
  • load_vlm_or_llm has no unit tests for its three distinct code paths (fake base, offline with num_hidden_layers=0, normal load)
  • The removed CompressedTensorsConfig ignore patch and attention mask repeat have no regression tests proving they're safe to remove
  • patched_decoder_layer_fwd (past_key_value/past_key_values fix) is untested
  • 3 of 4 test cases download from remote HF repos — needs @pytest.mark.slow or similar for CI reliability

Please add:

  1. Unit tests for FakeBaseModel (local path happy path, missing index, wrong keys)
  2. Unit tests for load_vlm_or_llm code paths
  3. @pytest.mark.slow or network markers for the remote-model tests
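
As a starting point for item 1, a unit test for the missing-index error case might look like the sketch below. FakeBaseModel itself is replaced here by a stand-in index lookup, since the real class, its constructor, and its exception types are not shown on this page:

```python
# Hypothetical test sketch; load_index is a stand-in for the lookup the
# real FakeBaseModel would perform, and the exception type is assumed.
import json
import tempfile
from pathlib import Path

def load_index(checkpoint_dir: str) -> dict:
    """Stand-in: read model.safetensors.index.json from a local checkpoint."""
    index_path = Path(checkpoint_dir) / "model.safetensors.index.json"
    if not index_path.exists():
        raise FileNotFoundError(f"missing safetensors index: {index_path}")
    return json.loads(index_path.read_text())

def test_missing_index() -> bool:
    # Error case: an empty checkpoint dir should fail loudly.
    with tempfile.TemporaryDirectory() as d:
        try:
            load_index(d)
        except FileNotFoundError:
            return True
    return False

def test_happy_path() -> bool:
    # Happy path: a well-formed index maps weight keys to shard files.
    with tempfile.TemporaryDirectory() as d:
        index = {"weight_map": {"lm_head.weight": "model-00001.safetensors"}}
        (Path(d) / "model.safetensors.index.json").write_text(json.dumps(index))
        return load_index(d)["weight_map"]["lm_head.weight"] == "model-00001.safetensors"

assert test_missing_index()
assert test_happy_path()
```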

h-guo18 added 2 commits March 19, 2026 06:38
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 21-27: The module currently hard-imports optional packages
(transformers, huggingface_hub, safetensors) at top-level; change to guarded
imports following the codebase pattern: wrap imports in try/except and set
boolean flags (e.g., HAS_TRANSFORMERS, HAS_HF_HUB, HAS_SAFETENSORS) or move
those imports into the functions/methods that use them (e.g., where
hf_hub_download, EntryNotFoundError, safetensors_load_file, PretrainedConfig,
PreTrainedModel are referenced or where FakeBaseModel is constructed); ensure
any code that requires these libs checks the flags and raises a clear
ImportError if used without the [hf] extras.
- Line 115: The _load_weights() routine currently requires both 'lm_head.weight'
and 'embed_tokens.weight' even when config.tie_word_embeddings is True; update
_load_weights() to check self.config.tie_word_embeddings and, if True, call
_find_weight_key() once for the existing key (prefer 'embed_tokens.weight' but
accept either) and then assign that same tensor to both lm_head and embed_tokens
rather than requiring both keys; keep the existing behavior (separately finding
both keys) when tie_word_embeddings is False and continue to use
_find_weight_key() for key discovery and error reporting.
- Around line 109-115: The config construction and module creation currently
read getattr(base_cfg, "dtype", ...) and omit passing dtype to modules; change
to read getattr(base_cfg, "torch_dtype", None) (falling back to torch.bfloat16
only if None) when creating FakeBaseConfig and then ensure that created modules
(Embedding, Linear, and any parameter tensors) are instantiated with that dtype
so weights are allocated in the checkpoint dtype; update references around
FakeBaseConfig construction and places that instantiate nn.Embedding / nn.Linear
/ torch.nn.Parameter so they pass dtype=self.model.dtype (or the local
torch_dtype) to preserve fp16/bf16 and keep self.model.dtype accurate.

In `@modelopt/torch/speculative/utils.py`:
- Around line 514-516: The code currently reads attributes like
num_hidden_layers and layer_types directly from model_config (created via
transformers.AutoConfig.from_pretrained), but composite VLM configs nest these
under text_config/llm_config; apply the same unwrapping used in FakeBaseModel by
iterating _VLM_CONFIG_ATTRS (or using the same helper) to drill into the actual
LM config before checking hasattr(model_config, "layer_types") and before
assigning num_orig_hidden_layers, so that you first replace model_config with
the unwrapped sub-config (if present) and then proceed to read num_hidden_layers
and layer_types.
- Around line 450-452: In patched_decoder_layer_fwd, avoid clobbering an
existing legacy kwarg by only mapping the new name when present: if
"past_key_values" exists in kwargs, set kwargs["past_key_value"] =
kwargs.pop("past_key_values"); otherwise leave kwargs["past_key_value"]
untouched (do not set it to None). Update the logic in patched_decoder_layer_fwd
(which calls original_decoder_layer_forward) to perform this conditional
translation so callers that pass the old "past_key_value" continue to work.
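
The conditional translation described above amounts to a small wrapper; this sketch uses a placeholder forward rather than the real decoder layer:

```python
# Illustrative only: legacy_forward stands in for a decoder layer whose
# signature still uses the old "past_key_value" name.
def make_patched_forward(original_forward):
    def patched_forward(*args, **kwargs):
        # Map the new name to the legacy one only when it is present,
        # so callers already passing past_key_value are left untouched.
        if "past_key_values" in kwargs:
            kwargs["past_key_value"] = kwargs.pop("past_key_values")
        return original_forward(*args, **kwargs)
    return patched_forward

def legacy_forward(hidden_states, past_key_value=None):
    return hidden_states, past_key_value

patched = make_patched_forward(legacy_forward)
assert patched("h", past_key_values="cache") == ("h", "cache")  # new name mapped
assert patched("h", past_key_value="cache") == ("h", "cache")   # legacy name untouched
```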

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d3fe6aa2-e68e-478f-9a0b-7ecfa5ff4b6c

📥 Commits

Reviewing files that changed from the base of the PR and between 0df75e2 and a023e6e.

📒 Files selected for processing (4)
  • examples/speculative_decoding/launch_train.sh
  • modelopt/torch/speculative/plugins/modeling_fakebase.py
  • modelopt/torch/speculative/utils.py
  • tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/launch_train.sh
  • tests/examples/speculative_decoding/test_eagle.py

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>