[Feature] Add guided decoding support for speculative decoding#4559
[Feature] Add guided decoding support for speculative decoding#4559windreamer wants to merge 22 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
Adds guided decoding (JSON schema / regex / grammar via xgrammar) support to the PyTorch speculative decoding (MTP) path by propagating GuidedDecodingManager into spec decoding and applying grammar bitmasks during both draft proposal and target verification/rejection sampling.
Changes:
- Propagate
GuidedDecodingManagerintoSpecModelAgentand spec proposers, and apply position-serial grammar masking in spec decode verification. - Add draft-side grammar masking support for proposers that share the target vocab (e.g.,
DeepseekMTP), and asupports_grammar_maskcapability flag. - Add unit/integration/E2E tests and update EN/ZH docs for guided decoding with speculative decoding.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
lmdeploy/pytorch/engine/model_agent/agent.py |
Propagates guided_decoding_manager into the speculative decoding agent and proposer. |
lmdeploy/pytorch/spec_decode/spec_agent.py |
Implements guided masking in spec decode verification, and expands/slices additional SamplingInputs fields. |
lmdeploy/pytorch/spec_decode/proposers/base.py |
Adds supports_grammar_mask and guided_decoding_manager plumb-through; extends get_outputs signature. |
lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py |
Applies grammar bitmask to draft logits (when provided) and advances forked matchers. |
lmdeploy/pytorch/spec_decode/proposers/eagle3.py |
Disables draft-side grammar masking via supports_grammar_mask = False. |
tests/pytorch/spec_decode/test_guided_spec_decode.py |
Unit tests for expand/slice behavior and guided-spec decode grammar mechanics. |
tests/pytorch/spec_decode/test_guided_spec_integration.py |
Higher-level integration tests for guided masking + rejection sampling state consistency. |
tests/test_lmdeploy/test_mtp_guided_decoding.py |
GPU integration tests for pipeline + MTP + guided decoding (schema/regex/json_object + streaming). |
docs/en/advance/spec_decoding.md |
Documents guided decoding usage with speculative decoding (EN). |
docs/zh_cn/advance/spec_decoding.md |
Documents guided decoding usage with speculative decoding (ZH). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
84eac20 to
bb48caf
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 23 out of 23 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- Implement position-serial grammar mask via forked GrammarMatchers - Propagate guided_decoding_manager from ModelAgent to SpecModelAgent - Apply grammar mask in DeepseekMTP proposer before token selection - Advance forked matcher state in Eagle3 proposer (no mask due to vocab mismatch) - Handle grammar state management after rejection sampling - Expand/slice sampling_inputs for non-tensor fields (response_formats, session_ctx, etc.) - Consolidate tests: 7 unit + 2 integration tests, 6 GPU e2e tests - Add MTP + guided decoding usage docs (en/zh_cn)
Eagle3's draft vocabulary differs from the target vocabulary, so a target-vocab grammar mask is inapplicable to draft logits. Add supports_grammar_mask class attribute to BaseSpecProposer (default True); Eagle3 overrides to False. spec_agent now gates the fork on this flag, and Eagle3.get_outputs() no longer accepts or processes guided_processors. Co-authored-by: openhands <openhands@all-hands.dev>
…ation - Eagle3.get_outputs() now applies grammar mask before argmax and accept_token after d2t mapping, matching DeepseekMTP pattern - Add _translate_bitmask() to convert target-vocab bitmask to draft-vocab bitmask via scatter_add_ (vectorized, no loops) - Remove supports_grammar_mask flag; all proposers now support it - Fork guided processors unconditionally in spec_agent._async_model_forward - Move session_to_cleanup handling before get_processors in forward_decode - Bump xgrammar>=0.1.33 (fork() requirement) in all 5 runtime requirements - Add comprehensive tests: bitmask translation, Eagle3 get_outputs, fork independence, multi-step draft loop
… device-to-host sync
In _guided_spec_logits_process, forked matchers were advanced using argmax of the masked target logits. In the non-greedy rejection sampling path, the actually accepted token can differ from argmax, causing subsequent grammar masks (especially the bonus-position mask) to be computed from an incorrect grammar state. Fix: advance forks using the known draft tokens for positions 0..num_spec_tokens-1. Target logits are conditioned on draft tokens, and rejection sampling discards positions after the first rejection, so the draft-token path is the only reachable one. The bonus position needs no advancement — the fork is discarded after the loop.
Extract fill_bitmap/accept_token loops in spec_agent.py into standalone helper functions and wrap them with asyncio.to_thread() to prevent CPU-bound xgrammar operations from blocking the asyncio event loop, which caused streaming token stuttering. - _accept_spec_rejection_tokens: accept tokens on original matchers - _fill_spec_bitmask: fill grammar bitmask for forked matchers - _accept_spec_forked_tokens: advance forked matchers with draft tokens
After moving accept_token out of FusedLogitsProcessor.sampling(), the prefill path in _rejection_sampling() was missing the call to advance the grammar matcher state. This caused guided decoding constraints to be silently ignored after the first prefill step, producing malformed JSON and non-matching regex output.
…date docs - Fix session_ctx incorrectly treated as global in _expand_sampling_inputs and _slice_sampling_inputs. Only session_to_cleanup is global; session_ctx is per-batch and must be expanded/sliced alongside response_formats. - Cache device-specific bitmask translation constants in Eagle3 via _get_bitmask_constants(), eliminating repeated .to(device) calls in _translate_bitmask. Pre-compute _n_draft_words at init time. - Rewrite test_rollback_then_accept_rejection_output as test_fork_strategy_rejection_output: replace rollback+double-accept logic with the production fork strategy (accept exactly the rejection-sampled tokens on the original matcher, no rollback needed). - Add clarifying comments in test_guided_spec_integration.py noting that production code advances forks with draft tokens while the simulation uses argmax as a stand-in. - Update EN/ZH docs with note about vocab translation for Eagle3 (target-vocab bitmask translated to draft-vocab via scatter-add).
…thread - BaseSpecProposer._apply_guided_bitmask and _accept_guided_tokens are now async with asyncio.to_thread wrapping CPU-bound xgrammar ops - DeepseekMTP.get_outputs and Eagle3.get_outputs are now async, calling the shared base methods instead of inline guided decoding - spec_agent.py call sites use await self.proposer.get_outputs(...) - GuidedDecodingManager.processors moved from class-level to instance-level - Move asyncio import to module top-level in base.py
Speculative decoding (ar_spec) requires FlashAttention-3 for the decode kernel. Without FA3, the engine would fail at runtime. Added flash_attn_v3_available() check to the module-level skipif condition alongside the existing GPU check.
Eagle3.get_outputs() is an async method but was called synchronously in TestEagle3GetOutputs, causing TypeError: cannot unpack non-iterable coroutine object. Wrap all 4 call sites with asyncio.run().
- Add guided_decoding_manager=None to SpecModelAgent mock objects - Make _DummyProposer.get_outputs async and accept guided_processors kwarg
…nt, precompute max word idx
… _prepare_guided_bitmask - _apply_guided_bitmask: rename to _prepare_guided_bitmask and update docstring to clarify it only allocates/fills the bitmask; callers (and Eagle3 vocab translation) are responsible for actual application. - _accept_guided_tokens: replace 'original grammar matchers' with 'provided grammar matchers' since they are typically forked copies created by SpecModelAgent.
Extract all guided-decoding logic specific to speculative decoding into a dedicated GuidedSpecHelper class. This replaces the scattered free functions (_accept_spec_rejection_tokens, _fill_spec_bitmask, _accept_spec_forked_tokens) and inline guided logic in spec_agent.py and BaseSpecProposer with a single, well-defined API. Key changes: - New GuidedSpecHelper class (guided_spec_helper.py) encapsulates: - Session lifecycle (cleanup_sessions, get_processors) - Draft-side bitmask (prepare_bitmask, apply_bitmask, accept_draft_tokens) - Target-side serial bitmask (apply_serial_bitmask with forked matchers) - Rejection-sampling-aware token acceptance (accept_rejection_sampled_tokens) - All public methods are null-safe: GuidedSpecHelper(manager=None) is a valid no-op instance, so callers never need to guard with if guided_helper: or if processors:. - Replaced guided_decoding_manager on SpecModelAgent/BaseSpecProposer with guided_helper (a GuidedSpecHelper instance, always set). - Removed _prepare_guided_bitmask and _accept_guided_tokens from BaseSpecProposer (subsumed by helper methods). - Simplified spec_agent.py: removed 3 free functions, removed all if guided_helper: / if guided_processors: guards, delegate to helper.
Avoid CUDA synchronization on the main async event loop by moving .cpu() transfers inside worker closures, as suggested in PR InternLM#4559 review comment r3371701042.
Motivation
Fixes #4551
When speculative decoding and guided decoding (JSON schema / regex / grammar) are both enabled, guided constraints are silently ignored — the
GuidedDecodingManageris never propagated into the speculative decoding path. This is a silent correctness issue: no error, no warning, just unconstrained output.Modification
Core change: propagate & apply grammar mask in spec decode
agent.py— Afterbuild_spec_agent(), propagateGuidedDecodingManagerto bothSpecModelAgentand itsproposer.spec_agent.py— Main integration:_async_model_forward: ForkGrammarMatchers for the draft model from the original guided processors; forked matchers are advanced in-place byget_outputs()at each draft step; originals remain untouched._rejection_sampling:_guided_spec_logits_process()— forked matchers provide per-position bitmasks for allnum_spec_tokens + 1target logits. After rejection sampling, accept the final output tokens on original matchers to advance their state correctly.guided_decoding_managertoFusedLogitsProcessor(standard path already handles it)._guided_spec_logits_process(): New method that (1) runs non-grammar logits processing (temperature, penalties), (2) applies per-position grammar bitmasks using forked matchers, advancing each fork with draft tokens (not argmax — target logits are conditioned on draft tokens, so the grammar state must follow the draft-token path), (3) returns processed logits for rejection sampling.deepseek_mtp.py— Acceptguided_processorsinget_outputs(). Apply grammar bitmask to draft logits beforeargmax, thenaccept_tokenon each forked matcher to advance its state for the next draft position.base.py— Addguided_decoding_managerattribute toBaseSpecProposer(set bySpecModelAgentafter construction). Addguided_processorsparameter toget_outputs()signature.eagle3.py— Support guided decoding via draft-to-target bitmask translation. Since Eagle3 draft vocabulary differs from the target vocabulary, a target-vocab grammar mask cannot be applied directly. Instead,_translate_bitmask()converts the target-vocab bitmask into a draft-vocab bitmask using thedraft_id_to_target_idmapping, then applies it to draft logits. Afterargmax+ token mapping,accept_tokenadvances each forked matcher's state.attention/__init__.py,configurations/utils.py,graph_runner.py,attention/fa3.py— Fix speculative decoding on non-SM90 CUDA GPUs: extend FA3 capability check from== 9(SM90 only) to>= 8(SM80+, Ampere and above) so that speculative decoding can use FA3's multi-token decode path. Add an early check inCUDAGraphRunner.__init__that raises a clearRuntimeErrorwhen speculative decoding is requested but FA3 is unavailable, instead of crashing in the Triton paged attention kernel.Streaming performance fix:
asyncio.to_threadfor CPU-bound xgrammar opsCPU-heavy xgrammar operations (
fill_bitmap,accept_token) were blocking the asyncio event loop during guided decoding, causing tokens to be returned in stuttered batches rather than smoothly streamed. Fix: extract these loops into standalone sync helpers and wrap calls withasyncio.to_thread()so they run off the event loop.logits_process.py— Extract_fill_guided_bitmask()and_accept_guided_tokens()as sync helpers; wrap calls withasyncio.to_thread().agent.py— Wrap_accept_guided_tokenscall withasyncio.to_thread().spec_agent.py— Extract_fill_spec_bitmask(),_accept_spec_forked_tokens(),_accept_spec_rejection_tokens()as sync helpers; wrap calls withasyncio.to_thread().spec_agent.py— Batch GPU→CPU tensor syncs before loops to avoid per-iteration device-to-host synchronization stalls.Helper changes
_expand_sampling_inputs/_slice_sampling_inputs: Handle additionalSamplingInputsfields (response_formats,session_ctx, etc.) so that guided-decoding–related inputs survive the expand/slice round-trip during rejection sampling.Tests
test_guided_spec_decode.py— Unit tests for_expand_sampling_inputs/_slice_sampling_inputswith guided fields,_guided_spec_logits_processbitmask application, andaccept_tokenstate advancement.test_guided_spec_integration.py— Integration tests (require xgrammar + GPU).test_mtp_guided_decoding.py— End-to-end pipeline tests (require xgrammar + GPU).Docs
spec_decoding.md(EN & ZH) with guided decoding usage notes.BC-breaking (Optional)
None. The
guided_processorsparameter inget_outputs()defaults toNone, so existing proposers that don't override it are unaffected.Checklist
_guided_spec_logits_process, expand/slice with guided fields).