[Feat]FakeBaseModel for offline eagle; Kimi-K2.5 fixes;#1052
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Auto-sync is disabled for draft pull requests in this repository; workflows must be run manually.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info — Configuration used: Path: .coderabbit.yaml | Review profile: CHILL | Plan: Pro

📒 Files selected for processing (1)
📝 Walkthrough

Adds a lightweight FakeBaseModel and loader API, replaces `load_vlm_or_llm_with_kwargs` with `load_vlm_or_llm` (explicit flags), threads `use_fake_base_for_offline` and `trust_remote_code` through scripts, updates offline `.pt` discovery, and adds tests and utilities to support offline/fake-base speculative decoding workflows.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Launcher as launch_train.sh
    participant Example as main.py / scripts
    participant Loader as load_vlm_or_llm
    participant FakeBase as FakeBaseModel
    participant HFHub as HuggingFace Hub / Filesystem
    Launcher->>Example: start training (flags: use_fake_base_for_offline, trust_remote_code, offline-data)
    Example->>Loader: load_vlm_or_llm(model_path, use_fake_base=..., use_offline_training=..., trust_remote_code=...)
    alt use_fake_base && use_offline_training
        Loader->>FakeBase: instantiate FakeBaseModel(source, trust_remote_code)
        FakeBase->>HFHub: fetch model.safetensors.index.json (local or hub)
        FakeBase->>HFHub: download shard files
        HFHub-->>FakeBase: return weights
        FakeBase-->>Loader: return FakeBaseModel
    else regular path
        Loader->>HFHub: AutoConfig.from_pretrained(..., trust_remote_code)
        Loader->>HFHub: from_pretrained with modified config (num_hidden_layers=0 for offline)
        HFHub-->>Loader: model artifacts
        Loader-->>Example: return model
    end
```
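As a rough illustration of the branching in the diagram above, the alt/else dispatch can be sketched in plain Python. The flag names come from the diagram; returning a label instead of a model object is a simplification for illustration only:

```python
def choose_load_path(use_fake_base: bool, use_offline_training: bool) -> str:
    """Pick which of the three loader paths applies (sketch, not the real loader)."""
    if use_offline_training and use_fake_base:
        # FakeBaseModel path: only lm_head + embed_tokens are materialized.
        return "fake_base"
    if use_offline_training:
        # Real model loaded with num_hidden_layers=0 to save memory.
        return "offline_truncated"
    # Regular from_pretrained load of the full model.
    return "full_model"
```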
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: pre-merge checks failed (1 error, 1 warning); please resolve all errors before merging. Addressing warnings is optional. ✅ 2 checks passed.

Important: merge conflicts detected (Beta).
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #1052    +/-   ##
========================================
+ Coverage   70.09%   70.28%   +0.19%
========================================
  Files         221      227       +6
  Lines       25541    25873     +332
========================================
+ Hits        17902    18185     +283
- Misses       7639     7688      +49
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/speculative/utils.py (1)
487-508: LGTM: `trust_remote_code` properly parameterized in `load_vlm_or_llm`.

The function correctly exposes `trust_remote_code` as a caller-configurable parameter defaulting to `False`, complying with SECURITY.md guidelines. Minor: the docstring is missing the `use_fake_base` parameter description.

📝 Proposed docstring fix

```diff
     Args:
         model_name_or_path: Local path or HuggingFace repo ID of the model.
+        use_fake_base: Whether to use FakeBaseModel for offline training (default True).
         use_offline_training: Whether to load a memory-efficient model for offline training.
         torch_dtype: dtype to use when loading the model.
         device_map: Device map passed to ``from_pretrained``.
         trust_remote_code: Whether to trust remote code.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/utils.py` around lines 487-508: the docstring for `load_vlm_or_llm` is missing a description for the `use_fake_base` parameter; update the Args section to add a one-line explanation for `use_fake_base` (what it toggles and its default behavior), e.g., indicate that `use_fake_base` controls whether a FakeBaseModel is used when loading (default True) and how it interacts with `use_offline_training`, so readers can understand its purpose alongside `model_name_or_path`, `use_offline_training`, `torch_dtype`, `device_map`, and `trust_remote_code`.

modelopt/torch/speculative/plugins/modeling_fakebase.py (1)

126-129: Consider using explicit exceptions instead of `assert` for shape validation.

Assertions can be disabled with `-O` (optimized mode). For production code that validates external checkpoint data, an explicit `ValueError` or `RuntimeError` would be more robust.

♻️ Proposed refactor

```diff
-        assert lm_head_w.shape == (config.vocab_size, config.hidden_size)
-        assert embed_tokens_w.shape == (config.vocab_size, config.hidden_size)
+        if lm_head_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"lm_head shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {lm_head_w.shape}"
+            )
+        if embed_tokens_w.shape != (config.vocab_size, config.hidden_size):
+            raise ValueError(
+                f"embed_tokens shape mismatch: expected {(config.vocab_size, config.hidden_size)}, "
+                f"got {embed_tokens_w.shape}"
+            )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_fakebase.py` around lines 126-129: replace the two assert checks with explicit runtime validation that raises a clear exception (e.g., ValueError) when shapes mismatch; check that lm_head_w.shape == (config.vocab_size, config.hidden_size) and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and, if not, raise an error that includes the actual shapes and expected dimensions; keep the subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid checkpoint data fails loudly in production rather than being skipped when Python assertions are disabled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 225-227: The code sets model_id to the hardcoded string
"kimi-k2.5" when model_source is provided, causing remote-model runs to share
the same eagle_output_dir; change the assignment for model_id to derive a unique
identifier from model_source (e.g., use the repository/name suffix:
model_source.split("/")[-1] or a sanitized version of that string) so model_id
is unique per model_source and output_subdir = eagle_output_dir /
f"eagle-{model_id}-offline" writes to distinct directories; update the model_id
assignment near the model_path/model_id definitions used in the test file
(model_path, model_id, eagle_output_dir).
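A minimal sketch of the suggested fix, assuming `model_source` is a HuggingFace repo id; the sanitization rule and helper names are illustrative, not the test file's actual code:

```python
import re

def derive_model_id(model_source: str) -> str:
    # Take the repo-name suffix and sanitize it for filesystem use,
    # so each model_source maps to a distinct identifier.
    tail = model_source.rstrip("/").split("/")[-1]
    return re.sub(r"[^A-Za-z0-9_.-]", "-", tail).lower()

def output_subdir(eagle_output_dir: str, model_source: str) -> str:
    # Mirrors the f"eagle-{model_id}-offline" layout from the test.
    return f"{eagle_output_dir}/eagle-{derive_model_id(model_source)}-offline"
```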
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 126-129: Replace the two assert checks with explicit runtime
validation that raises a clear exception (e.g., ValueError) when shapes
mismatch: check that lm_head_w.shape == (config.vocab_size, config.hidden_size)
and embed_tokens_w.shape == (config.vocab_size, config.hidden_size) and if not
raise an error that includes the actual shapes and expected dimensions; keep the
subsequent assignments to self.lm_head.weight.data.copy_(lm_head_w) and
self.embed_tokens.weight.data.copy_(embed_tokens_w) unchanged so that invalid
checkpoint data fails loudly in production rather than being skipped when Python
assertions are disabled.
In `@modelopt/torch/speculative/utils.py`:
- Around line 487-508: The docstring for load_vlm_or_llm is missing a
description for the use_fake_base parameter; update the Args section to add a
one-line explanation for use_fake_base (what it toggles and its default
behavior), e.g., indicate that use_fake_base controls whether a FakeBaseModel is
used when loading (default True) and how it interacts with use_offline_training,
so readers can understand its purpose alongside model_name_or_path,
use_offline_training, torch_dtype, device_map, and trust_remote_code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d0e27af8-82ec-41a0-a296-d037db3dbd86
📒 Files selected for processing (9)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/plugins/transformers.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
ChenhanYu left a comment
This PR introduces FakeBaseModel for memory-efficient offline EAGLE training (loads only lm_head + embed_tokens), refactors load_vlm_or_llm_with_kwargs → load_vlm_or_llm with a cleaner API, consolidates model path constants, and adds several Kimi-K2.5 compatibility fixes.
Note: load_vlm_or_llm_with_kwargs → load_vlm_or_llm is a breaking API change — any downstream callers passing extra kwargs will break.
Test coverage is insufficient. The PR adds one integration test that exercises the happy path, but:

- `FakeBaseModel` has no unit tests (weight loading, path auto-detection, error cases like missing safetensors index or unrecognized weight keys)
- `load_vlm_or_llm` has no unit tests for its three distinct code paths (fake base, offline with `num_hidden_layers=0`, normal load)
- The removed `CompressedTensorsConfig` ignore patch and attention mask `repeat` have no regression tests proving they're safe to remove
- `patched_decoder_layer_fwd` (`past_key_value`/`past_key_values` fix) is untested
- 3 of 4 test cases download from remote HF repos; they need `@pytest.mark.slow` or similar for CI reliability

Please add:

- Unit tests for `FakeBaseModel` (local path happy path, missing index, wrong keys)
- Unit tests for `load_vlm_or_llm` code paths
- `@pytest.mark.slow` or network markers for the remote-model tests
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/modeling_fakebase.py`:
- Around line 21-27: The module currently hard-imports optional packages
(transformers, huggingface_hub, safetensors) at top-level; change to guarded
imports following the codebase pattern: wrap imports in try/except and set
boolean flags (e.g., HAS_TRANSFORMERS, HAS_HF_HUB, HAS_SAFETENSORS) or move
those imports into the functions/methods that use them (e.g., where
hf_hub_download, EntryNotFoundError, safetensors_load_file, PretrainedConfig,
PreTrainedModel are referenced or where FakeBaseModel is constructed); ensure
any code that requires these libs checks the flags and raises a clear
ImportError if used without the [hf] extras.
- Line 115: The _load_weights() routine currently requires both 'lm_head.weight'
and 'embed_tokens.weight' even when config.tie_word_embeddings is True; update
_load_weights() to check self.config.tie_word_embeddings and, if True, call
_find_weight_key() once for the existing key (prefer 'embed_tokens.weight' but
accept either) and then assign that same tensor to both lm_head and embed_tokens
rather than requiring both keys; keep the existing behavior (separately finding
both keys) when tie_word_embeddings is False and continue to use
_find_weight_key() for key discovery and error reporting.
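One way to sketch that tied-embeddings logic, with `_find_weight_key` simplified to a suffix match over the checkpoint's key list (names, error wording, and the matching rule are illustrative assumptions):

```python
def resolve_weight_keys(state_keys, tie_word_embeddings):
    """Return (lm_head_key, embed_tokens_key), sharing one key when tied."""
    def find(suffix):
        # Simplified stand-in for _find_weight_key: first suffix match wins.
        for k in state_keys:
            if k.endswith(suffix):
                return k
        return None

    embed_key = find("embed_tokens.weight")
    head_key = find("lm_head.weight")
    if tie_word_embeddings:
        shared = embed_key or head_key  # prefer embed_tokens, accept either
        if shared is None:
            raise KeyError("no embedding weight found for tied embeddings")
        return shared, shared
    if head_key is None or embed_key is None:
        raise KeyError("both lm_head.weight and embed_tokens.weight are required")
    return head_key, embed_key
```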
- Around line 109-115: The config construction and module creation currently
read getattr(base_cfg, "dtype", ...) and omit passing dtype to modules; change
to read getattr(base_cfg, "torch_dtype", None) (falling back to torch.bfloat16
only if None) when creating FakeBaseConfig and then ensure that created modules
(Embedding, Linear, and any parameter tensors) are instantiated with that dtype
so weights are allocated in the checkpoint dtype; update references around
FakeBaseConfig construction and places that instantiate nn.Embedding / nn.Linear
/ torch.nn.Parameter so they pass dtype=self.model.dtype (or the local
torch_dtype) to preserve fp16/bf16 and keep self.model.dtype accurate.
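The dtype-resolution part of that fix can be sketched with a plain namespace standing in for the HF config; string dtypes replace `torch` dtypes here to keep the sketch dependency-free, and the function name is illustrative:

```python
from types import SimpleNamespace

def resolve_checkpoint_dtype(base_cfg, fallback="bfloat16"):
    # Prefer the checkpoint's declared torch_dtype; fall back to bfloat16
    # only when the attribute is absent or None.
    dtype = getattr(base_cfg, "torch_dtype", None)
    return fallback if dtype is None else dtype
```

Modules such as `nn.Embedding` and `nn.Linear` would then be created with `dtype=resolve_checkpoint_dtype(base_cfg)` so weights are allocated in the checkpoint dtype and `self.model.dtype` stays accurate.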
In `@modelopt/torch/speculative/utils.py`:
- Around line 514-516: The code currently reads attributes like
num_hidden_layers and layer_types directly from model_config (created via
transformers.AutoConfig.from_pretrained), but composite VLM configs nest these
under text_config/llm_config; apply the same unwrapping used in FakeBaseModel by
iterating _VLM_CONFIG_ATTRS (or using the same helper) to drill into the actual
LM config before checking hasattr(model_config, "layer_types") and before
assigning num_orig_hidden_layers, so that you first replace model_config with
the unwrapped sub-config (if present) and then proceed to read num_hidden_layers
and layer_types.
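The unwrapping step might look like the following; the attribute names in `_VLM_CONFIG_ATTRS` are assumptions based on common composite-config layouts, not necessarily the repository's actual constant:

```python
_VLM_CONFIG_ATTRS = ("text_config", "llm_config", "language_config")  # assumed names

def unwrap_lm_config(cfg):
    """Drill through composite VLM configs to the language-model sub-config."""
    for attr in _VLM_CONFIG_ATTRS:
        sub = getattr(cfg, attr, None)
        if sub is not None:
            return unwrap_lm_config(sub)  # handle nested composites
    # Plain LM config: num_hidden_layers / layer_types live here.
    return cfg
```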
- Around line 450-452: In patched_decoder_layer_fwd, avoid clobbering an
existing legacy kwarg by only mapping the new name when present: if
"past_key_values" exists in kwargs, set kwargs["past_key_value"] =
kwargs.pop("past_key_values"); otherwise leave kwargs["past_key_value"]
untouched (do not set it to None). Update the logic in patched_decoder_layer_fwd
(which calls original_decoder_layer_forward) to perform this conditional
translation so callers that pass the old "past_key_value" continue to work.
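A sketch of that conditional translation, with the kwarg dict pulled out into a helper (the function name and dict-based calling convention are illustrative):

```python
def translate_cache_kwarg(kwargs: dict) -> dict:
    # Map the new name onto the legacy one only when present, so an
    # explicit legacy past_key_value is never clobbered with None.
    if "past_key_values" in kwargs:
        kwargs["past_key_value"] = kwargs.pop("past_key_values")
    return kwargs
```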
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d3fe6aa2-e68e-478f-9a0b-7ecfa5ff4b6c
📒 Files selected for processing (4)
- examples/speculative_decoding/launch_train.sh
- modelopt/torch/speculative/plugins/modeling_fakebase.py
- modelopt/torch/speculative/utils.py
- tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/launch_train.sh
- tests/examples/speculative_decoding/test_eagle.py
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?
Adds `FakeBaseModel` for offline EAGLE training and several Kimi-K2.5 compatibility fixes.

- `FakeBaseModel` — lightweight model that loads only `lm_head` and `embed_tokens` from a local checkpoint, avoiding full model weight loading during offline training. Configured via `FakeBaseArguments` and integrated into `load_vlm_or_llm`.
- `_find_base_model_parts` — support Kimi-K2.5 VLM layout (`language_model.model` path)
- Fix `past_key_value`/`past_key_values` argument mismatch
- Use `rglob` for `.pt` discovery in nested offline data dirs; single-node GPU count respects `CUDA_VISIBLE_DEVICES`

Type of change: Bug fix, new feature
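The `rglob`-based `.pt` discovery mentioned above can be sketched as follows; the function name is illustrative:

```python
from pathlib import Path

def find_offline_shards(data_dir) -> list[Path]:
    # Path.rglob walks nested subdirectories, unlike a flat glob("*.pt"),
    # so offline data spread across per-rank folders is still found.
    return sorted(Path(data_dir).rglob("*.pt"))
```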
Testing
Tested offline EAGLE training for Kimi-K2.5 end-to-end.
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information