Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
… calls Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
📝 Walkthrough

Parallelism config creation in the speculative decoding example is now conditional. The Transformers plugin adds NVTX profiling decorators and lazily initializes Llama rotary embeddings.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
```python
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
    training_args.parallelism_config = ParallelismConfig(
        cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
    )
```
Note: this is an unrelated bugfix related to #1045 (it does not fully solve the issue, just a single-GPU workaround).
As discussed in Slack, this issue is due to a transformers version mismatch. It should be fixed after updating transformers.
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `modelopt/torch/speculative/plugins/transformers.py`:
- Around lines 908-910: The code can raise a NameError when `self.eagle_ttt_steps == 0`, because the loop that defines `ttt_step` never runs. Update the logic in `modify()` (or the surrounding block) to handle the zero case explicitly: either assert `self.eagle_ttt_steps >= 1` at the start of `modify()` to make the invariant explicit, or initialize `ttt_step` to a safe default (or skip the code that uses `ttt_step`) when `eagle_ttt_steps == 0`, and ensure `train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)` is still valid.
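The hazard can be reproduced in a few lines. The following is a hypothetical, self-contained sketch (the function name `run_ttt` and the accuracy update are illustrative stand-ins, not the real `modify()` code) showing the safe-default fix: track how many steps actually executed instead of reading the loop variable after the loop.

```python
import torch


def run_ttt(eagle_ttt_steps: int, num_parallel: int = 2) -> list:
    """Illustrates the zero-step hazard: if the loop never runs,
    `ttt_step` is undefined when the final slice executes."""
    num_ttt = max(eagle_ttt_steps, 1)
    train_accs = torch.zeros(num_parallel, num_ttt)
    executed_ttt_steps = 0  # safe default, valid even when the loop is skipped
    for ttt_step in range(eagle_ttt_steps):
        train_accs[:, ttt_step] = 1.0  # stand-in for the real accuracy update
        executed_ttt_steps = ttt_step + 1
    # Slicing with `ttt_step + 1` here would raise NameError
    # when eagle_ttt_steps == 0; the counter avoids that.
    return train_accs[:, :executed_ttt_steps].tolist()
```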
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
909-912: ⚠️ Potential issue | 🟠 Major — Guard zero-step TTT to avoid undefined `ttt_step`.

If `self.eagle_ttt_steps == 0`, the loop at line 931 never runs, and line 989 references `ttt_step` before assignment.

🔧 Proposed fix

```diff
- train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+ train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
+ executed_ttt_steps = 0
@@
  for ttt_step in range(self.eagle_ttt_steps):
@@
-     train_accs[i, ttt_step] = acc
+     train_accs[i, ttt_step] = acc
+     executed_ttt_steps = ttt_step + 1
      if not self.training:
          break
@@
- train_accs = train_accs[:, : ttt_step + 1].tolist()
+ train_accs = train_accs[:, :executed_ttt_steps].tolist()
```

Also applies to: 988-990
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `modelopt/torch/speculative/plugins/transformers.py`:
- Around lines 928-929: The RoPE initializer is only called in `HFEagleModel.forward`, but other entry points like `pseudo_speculative_generate()` call `_eagle_forward()` and can trigger `EagleModule.forward` before `rotary_emb` exists. Make every EAGLE entry path invoke the initializer: call `self.eagle_module._maybe_init_rope()` at the start of `EagleModule.forward`, and also ensure `_eagle_forward()` (and/or `pseudo_speculative_generate()`) invokes `_maybe_init_rope()` before any use of `rotary_emb`, so RoPE is always initialized regardless of which forward path is exercised.
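The lazy-init pattern the comment asks for can be sketched as follows. This is a hypothetical miniature (the class name, dimensions, and rotary-embedding stand-in are all illustrative, not the real EagleModule): every public entry point guards itself with `_maybe_init_rope()` before touching `rotary_emb`.

```python
import torch
from torch import nn


class EagleModuleSketch(nn.Module):
    """Illustrative module whose rotary embedding is built lazily on first use."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.dim = dim
        self.rotary_emb = None  # created lazily

    def _maybe_init_rope(self) -> None:
        if self.rotary_emb is None:
            # Stand-in for constructing the real rotary-embedding module.
            inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim, 2) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.rotary_emb = lambda pos: (
                torch.cos(pos[:, None] * self.inv_freq),
                torch.sin(pos[:, None] * self.inv_freq),
            )

    def forward(self, positions: torch.Tensor):
        self._maybe_init_rope()  # guard in forward ...
        return self.rotary_emb(positions)

    def pseudo_speculative_generate(self, positions: torch.Tensor):
        self._maybe_init_rope()  # ... and in every other entry path
        return self.rotary_emb(positions)
```

Whichever path runs first triggers initialization, so no caller can observe `rotary_emb` as `None`.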
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1044      +/-   ##
==========================================
- Coverage   70.29%   70.26%   -0.03%
==========================================
  Files         227      227
  Lines       25857    25868      +11
==========================================
  Hits        18176    18176
- Misses       7681     7692      +11
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Makes sense to me. Is there any perf comparison before/after the optimizations?
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)
720-725: Consider a size guard before caching the full teacher softmax.

Caching `base_output_softmax_logits` as a full `[B, S, V]` tensor can materially increase peak memory (especially in large-vocab runs). A guarded fallback to per-slice softmax would keep this optimization safer across wider configs.

💡 Example guard pattern

```diff
- base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
- base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()
+ base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
+ cache_softmax = base_model_logits.numel() <= getattr(
+     self.eagle_config, "max_cached_teacher_prob_elems", 0
+ )
+ base_output_softmax_logits = (
+     torch.softmax(base_model_logits, dim=2).detach() if cache_softmax else None
+ )
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/transformers.py` around lines 720 - 725, The code currently caches base_output_softmax_logits as a full [B,S,V] tensor which can blow up memory for large vocabularies; add a size guard using eagle_config.draft_vocab_size and eagle_config.vocab_size (or an explicit max_vocab_for_full_softmax threshold) and only compute/cache full softmax when vocab_size*B*S is below the threshold, otherwise avoid storing base_output_softmax_logits and compute softmax per-slice on demand (or keep only argmax via base_output_predict_tok); update the block around base_model_logits, base_output_predict_tok and base_output_softmax_logits to branch on this guard and ensure downstream users handle the per-slice-compute path.
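The guard can be isolated into a small helper. This is a sketch under stated assumptions: the threshold constant `MAX_CACHED_SOFTMAX_ELEMS` and the helper name are hypothetical, not part of the real `eagle_config`.

```python
import torch

# Hypothetical threshold; the real config knob would live on eagle_config.
MAX_CACHED_SOFTMAX_ELEMS = 1_000_000


def maybe_cache_teacher_softmax(base_model_logits: torch.Tensor):
    """Cache the full [B, S, V] teacher softmax only when it is small enough;
    otherwise return None so callers compute softmax per slice on demand."""
    predict_tok = base_model_logits.argmax(dim=-1).detach()
    if base_model_logits.numel() <= MAX_CACHED_SOFTMAX_ELEMS:
        softmax_logits = torch.softmax(base_model_logits, dim=-1).detach()
    else:
        softmax_logits = None  # downstream code recomputes per slice
    return predict_tok, softmax_logits
```

Callers then branch on `softmax_logits is None`, trading a bit of recompute for bounded peak memory on large-vocab models.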
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
@h-guo18 I isolated the perf using an nsys profile with online training of a config for Llama 3.2 1B with K=3 at ISL 2048. Looking at the EAGLE3 FWD+BWD and excluding the target model forward pass, I get a 1.9x speed improvement (roughly 280ms per batch of 16 requests, down from 540ms on main).
@h-guo18 could you advise on this test failure? It seems like the Windows build doesn't have NVTX available. I'm not sure how modelopt CI works; what do you suggest to fix the CI?
Seems like only the test on Windows fails (link). Instead of installing it in the testing container, I think it's better to make it an optional dependency for minimal impact, e.g. wrap the decorator with a check:

```python
from contextlib import contextmanager

try:
    from torch.cuda import nvtx as torch_nvtx
except Exception:
    torch_nvtx = None


def _nvtx_available() -> bool:
    if torch_nvtx is None:
        return False
    try:
        torch_nvtx.range_push("probe")
        torch_nvtx.range_pop()
        return True
    except Exception:
        return False


_NVTX_ENABLED = _nvtx_available()


@contextmanager
def _null_range(msg: str):
    yield


def nvtx_range(msg: str):
    """Can be used as both a decorator and a context-manager fallback target."""
    if _NVTX_ENABLED:
        return torch_nvtx.range(msg)
    return _null_range(msg)


# To use it:
@nvtx_range("eagle_loss")
def compute_loss(x):
    return x.sum()
```
That's a huge speedup, thanks! I don't see any torch.compile in the current PR code, so I assume most of the speedup comes from precomputing the base_model_logits-related stuff?
Oh oops I forgot to push the torch.compile annotations |
A massive chunk comes from the torch.compile |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Line 41: Make the NVTX import optional: replace the hard import of `nvtx` with a guarded import (try/except ImportError) that sets a module-level `nvtx = None` when unavailable, and add a decorator factory named `nvtx_range` that returns a no-op decorator when `nvtx` is None and `nvtx.range(...)` when it is present. Then replace all uses of the `@nvtx.range(...)` decorator in this module with `@nvtx_range(...)` so the code runs on systems without CUDA, without changing behavior when NVTX is available.
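A minimal version of that guarded-import pattern could look like this. It is a sketch, not the PR's actual code: it assumes the third-party `nvtx` package and uses its `push_range`/`pop_range` calls, whereas the reviewed module may wrap a different NVTX API.

```python
import functools

# Guarded import: fall back to a no-op when NVTX is unavailable (e.g. Windows CI).
try:
    import nvtx
except ImportError:
    nvtx = None


def nvtx_range(msg: str):
    """Decorator factory: a real NVTX range when available, a no-op otherwise."""
    def decorator(fn):
        if nvtx is None:
            return fn  # no NVTX: return the function unchanged

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            nvtx.push_range(msg)
            try:
                return fn(*args, **kwargs)
            finally:
                nvtx.pop_range()
        return wrapped
    return decorator


@nvtx_range("eagle_loss")
def compute_loss(values):
    return sum(values)
```

With this shape, `@nvtx.range(...)` call sites become `@nvtx_range(...)` and behave identically when NVTX is installed.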
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/transformers.py
Makes sense. Besides, …
LGTM. Please run a full test before merging |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…ups-torch-compile Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@h-guo18 torch.compile is now optional, but on by default. I think we can update it later to disable the feature if CP>1, should that become a problem. But given the significant difference in performance, I think we should try to have it on in as many cases as possible (this is vLLM's torch.compile philosophy).

I ran a 1000-step test on 1xB200 with batch size 16 for Llama 3.2-1B. Performance is much better with torch.compile enabled, 1.95 it/s vs. 1.45 it/s, and this is in online mode where the base model cost is unchanged (the base model is not compiled). Accuracy is identical for both loss and AR.

Config: …
Launch command: …
Results — torch.compile ON: … / torch.compile OFF: …
The tests should pass now. |
```python
self._prepare_eagle_inputs = torch.compile(self._prepare_eagle_inputs, dynamic=False)
self._eagle_forward = torch.compile(
    self._eagle_forward, dynamic=False, mode="max-autotune"
)
self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True)
```
How about we add a try-except around these torch.compile calls, so that they fall back to eager whenever compilation fails?
I think there are many other (unknown) cases where torch.compile will fail, other than the Windows CI.
Then we can also separate the compilation of these 3 functions: e.g. even if `_eagle_forward` fails to compile (due to flex attention, perhaps), the other 2 functions can still be optimized.
I like this idea. However, I still think we would like to have full coverage of torch compile in CI if possible
I guess the linux test is passing, so for today we are covered. But it would be nice to have some confidence that the torch compile doesn't break (and just get skipped) in the future
Seems like some error could happen at runtime, even if compile works:
https://github.com/NVIDIA/Model-Optimizer/actions/runs/23258816928/job/67623073586?pr=1044#step:4:1727
We probably need to also set `torch._dynamo.config.suppress_errors = True` for fallback, in addition to the try-except before torch.compile.
Let's see if this one does the trick
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
…VIDIA/Model-Optimizer into bchislett/eagle-speedups-torch-compile
What does this PR do?
Type of change: Optimization
Changes:
Usage
No changes to external interfaces
Testing
Ran training commands for benchmarking. Did not do a full training run, did not validate correctness.
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit
New Features
Improvements