[OMNIML-3994] Add SharedQuantState#1605
Conversation
- Add SharedQuantState: sibling weight quantizers (q/k/v, gate/up,
per-expert w1/w3) calibrate to a single NVFP4 global_amax (the group max).
- Discover groups via input-sharing hooks or regex patterns
(MaxCalibConfig.shared_patterns); max_calibrate attaches the state on the
group parent, populates, and promotes.
- Extract out the collect_shared_input_modules helper and reuse it for this and export
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds regex-driven grouping and a SharedQuantState lifecycle to aggregate and sync weight amax across fusible sibling quantizers; exposes ChangesShared NVFP4 Weight Quantization State
Sequence Diagram(s)sequenceDiagram
participant Model
participant Finder as find_shared_input_groups
participant Attacher as attach_shared_quant_states
participant Populator as populate_shared_state
participant Promoter as promote_nvfp4_static_quantizers
Model->>Finder: scan module FQNs with regex patterns
Finder->>Attacher: return (parent, members) groups
Attacher->>Model: attach SharedQuantState on parent, set _shared_quant_state_ref on members
Model->>Populator: aggregate members' _amax into SharedQuantState.weight_global_amax
Populator->>Populator: sync weight_global_amax across processes
Populator->>Promoter: write synced global_amax into NVFP4StaticQuantizers
Promoter->>Promoter: prefer shared weight_global_amax when valid during promotion
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/__init__.py (1)
21-27: ⚡ Quick winUse a star re-export for
shared_input.This package already defines the public surface via
shared_input.__all__plus this module’s__all__. Re-exporting withfrom .shared_input import *keeps it aligned with the repository’s package-export convention and avoids future drift between the submodule and package surface.As per coding guidelines, "Define the public API with
__all__at the top of each Python module and re-export submodules in__init__.pyfiles usingfrom .module import *to keep the public API explicit and make star-imports safe".🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/utils/__init__.py` around lines 21 - 27, The module currently imports specific symbols from shared_input (DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState, attach_shared_quant_states, find_shared_input_groups, populate_shared_state) but should re-export the entire public API defined by shared_input.__all__; replace the explicit imports in this __init__.py with a star re-export (use from .shared_input import *) and ensure this package's __all__ still combines or respects shared_input.__all__ so the package export surface remains consistent with shared_input.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/config.py`:
- Around line 718-734: The shared_patterns field currently accepts arbitrary
keys/strings; add validation on the Model config that runs when the config is
parsed to (1) restrict keys to the supported set {'weight','input'} and raise a
clear error on unknown keys, and (2) pre-compile every regex string in
shared_patterns (using re.compile) and raise a ValidationError/ValueError for
any invalid pattern so bad configs fail fast; implement this as a pydantic
`@validator` (or root_validator) for shared_patterns in the same class that
replaces the dict[str,list[str]] with dict[str,list[re.Pattern]] (or stores
compiled patterns alongside) and reference the shared_patterns symbol in your
changes.
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 237-238: The current expression for weight_patterns treats an
explicit empty list as falsy and substitutes DEFAULT_WEIGHT_SHARED_PATTERNS,
preventing callers from specifying {"weight": []}; change the logic in the block
that computes weight_patterns so it checks for key-presence instead of
truthiness: if shared_patterns is not None and "weight" in shared_patterns then
use shared_patterns["weight"] (even if empty), otherwise fall back to
DEFAULT_WEIGHT_SHARED_PATTERNS before calling attach_shared_quant_states(model,
patterns=weight_patterns); ensure this preserves behavior when shared_patterns
is None.
In `@modelopt/torch/quantization/utils/core_utils.py`:
- Around line 978-982: The code currently reuses module._shared_quant_state_ref
unconditionally which can be stale when promote_nvfp4_static_quantizers(model)
is run on submodules; update the branch so you only accept
shared.weight_global_amax if the shared state actually belongs to the current
promotion root/subtree: inspect shared for an ownership marker (e.g.
shared._owning_promotion_root or shared._owning_module/_owner) and compare it to
the current promotion root or module before using shared.weight_global_amax,
otherwise fall back to computing global_amax via
reduce_amax(amax.clone().detach(), axis=None); keep references to module,
_shared_quant_state_ref, weight_global_amax, reduce_amax and amax in your check
so the fallback behavior is unchanged when ownership does not match.
In `@modelopt/torch/quantization/utils/shared_input.py`:
- Around line 112-119: The all-reduce failure for syncing shared state should
fail fast instead of warning: in the block that calls dist.all_reduce (the one
operating on self.weight_global_amax with group=group.group), catch the
RuntimeError only to log/annotate if desired and then raise (re-raise the caught
exception or raise a new RuntimeError with context) so the process stops and
inconsistent weight_global_amax cannot be used; update the exception handling in
the method where self.weight_global_amax is synced accordingly.
---
Nitpick comments:
In `@modelopt/torch/quantization/utils/__init__.py`:
- Around line 21-27: The module currently imports specific symbols from
shared_input (DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState,
attach_shared_quant_states, find_shared_input_groups, populate_shared_state) but
should re-export the entire public API defined by shared_input.__all__; replace
the explicit imports in this __init__.py with a star re-export (use from
.shared_input import *) and ensure this package's __all__ still combines or
respects shared_input.__all__ so the package export surface remains consistent
with shared_input.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: bce3b754-50a1-4bf6-afa4-5dc0bd5196d4
📒 Files selected for processing (7)
modelopt/torch/quantization/config.pymodelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/utils/__init__.pymodelopt/torch/quantization/utils/core_utils.pymodelopt/torch/quantization/utils/shared_input.pytests/unit/torch/quantization/test_mse_calibrator.pytests/unit/torch/quantization/test_shared_input.py
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
5f77e1a to
4b9fca1
Compare
|
/claude review |
There was a problem hiding this comment.
Review summary — CRITICAL: 1, IMPORTANT: 0, SUGGESTION: 1
The pattern-based grouping, regex/key validation in MaxCalibConfig, parent-LCA attachment, EP+TP all-reduce sync, and the _shared_quant_state_ref-via-object.__setattr__ indirection (so member quantizers don't re-register the buffer per-child) are nicely thought through. The unit tests exercise singletons, the shared_expert_gate non-fusion case, MoE per-expert grouping, and the empty-patterns escape hatch directly.
Blocking issue:
- [CRITICAL ModeState]
_shared_quant_state.weight_global_amaxentersstate_dictbut the restore path never re-creates the submodule.mto.restoredoesrestore_from_modelopt_state(which only recreatesTensorQuantizers) and thenmodel_restored.load_state_dict(...)with defaultstrict=True, so any real model with q/k/v or gate/up siblings calibrated post-PR fails to round-trip throughmto.save/mto.restore. Megatron'squant_module_set_extra_statehas the same gap. The existing state_dict test only saves/loads on the same already-calibrated model and so does not catch this. See the inline comment onSharedQuantState.__init__for fix options (the cleanest ispersistent=Falseon the buffer, since each promotedNVFP4StaticQuantizeralready serializes its own_global_amax). Please pair the fix with a regression test that uses the existingtf_modelopt_state_and_output_testerharness on a model with q/k/v siblings.
Smaller item:
- [SUGGESTION]
_climb_past_modulelistshould also climb pastnn.ModuleDictso customshared_patternsagainst dict-shaped MoE layouts don't end up registering_shared_quant_stateinto aModuleDict's_modulesand corrupting its iteration. Default patterns are safe today.
Holding off on approval until the restore path is fixed and covered by a test.
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
4b9fca1 to
85fe560
Compare
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/torch/quantization/test_shared_input.py`:
- Line 189: Hoist the in-function imports to module top: move "from
modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers",
"import modelopt.torch.opt as mto", and "from modelopt.torch.quantization.config
import MaxCalibConfig" out of their tests and place them with the other
top-level imports; then remove the corresponding in-test import lines inside
test_promote_ignores_shared_state_outside_root,
test_modelopt_save_restore_with_shared_state, and
test_config_rejects_invalid_shared_patterns so import errors surface at
collection time and no circular/optional-import comments are needed.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 94e0730f-2271-4e49-a7bf-3b8209c961ed
📒 Files selected for processing (7)
modelopt/torch/quantization/config.pymodelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/nn/modules/tensor_quantizer.pymodelopt/torch/quantization/utils/__init__.pymodelopt/torch/quantization/utils/core_utils.pymodelopt/torch/quantization/utils/shared_input.pytests/unit/torch/quantization/test_shared_input.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/quantization/utils/core_utils.py
- modelopt/torch/quantization/utils/init.py
- modelopt/torch/quantization/utils/shared_input.py
- modelopt/torch/quantization/model_calib.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1605 +/- ##
===========================================
- Coverage 77.41% 56.00% -21.42%
===========================================
Files 480 481 +1
Lines 52499 52965 +466
===========================================
- Hits 40642 29662 -10980
- Misses 11857 23303 +11446
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
New SharedQuantState subsystem (363-line shared_input.py + plumbing in model_calib.py, core_utils.py, tensor_quantizer.py) replaces the earlier _sync_grouped_weight_global_amax (which delegated to preprocess_linear_fusion) with a regex-pattern–based framework for grouping fusible siblings. Tests cover per-group attach, MoE per-expert grouping, the save/restore round-trip (non-persistent buffer + ref filtered from get_modelopt_state), empty/invalid patterns, and stale-ref gating in promote_nvfp4_static_quantizers. Code quality is good and inline comments are unusually thorough.
Asking a human to sign off because:
- Design alternatives not addressed in PR body. The repo already has
preprocess_linear_fusion(modelopt/torch/export/quant_utils.py) doing the sameglobal_amaxunification at export time, and the legacy_sync_grouped_weight_global_amaxalready wired this intomax_calibrate. The PR body does not explain why a new ~360-line subsystem is preferred over extending those (e.g. replacing the hardcoded_GATE_UP_PAIRS+ Q/K/V tuple with a configurable list, or moving thepreprocess_linear_fusioncall earlier inmax_calibrate). The forward-looking "we'll add input scales / LoRA factors to it later" rationale appears in code comments but isn't justified against the simpler path of extending the existing helper. The "why pattern-based vs input-hook-based" question is addressed (shared_expert_gate counterexample) — but that's the easy half of the design question. - Semantic shift in
max_calibrateordering. Previously_promote_nvfp4_static_quantizers_with_global_amax_syncran before thedistributed_syncearly return; in the new flow promotion happens after TP/DP/EP_amaxsync (viapopulate_shared_state→promote_nvfp4_static_quantizers). For NVFP4 static-block weights this changes whichglobal_amaxvalue MSE calibration and downstream consumers observe (now the cross-rank-unified group max, vs. the pre-sync per-quantizer value). The PR description doesn't call this out, and there's no GPU/multi-rank test covering the EP/TP path that the newSharedQuantState.sync_weight_global_amaxis responsible for. SharedQuantStateis attached for every enabled weight quantizer matching the patterns, regardless of format (FP8, INT8, etc.), not just NVFP4-static. The only consumer (promote_nvfp4_static_quantizers) is gated byis_nvfp4_static, so non-NVFP4 runs just carry a dangling state on every Q/K/V parent. Functionally benign today, but worth confirming this is intentional and that the dangling submodule (with a non-persistent buffer) doesn't trip up other paths (e.g.replace_quant_module, FSDP wrapping, sharded_state_dict) — none of the new tests exercise this case.- Size / scope. 1021-line PR, single new ~360-line framework file plus a 457-line test file. Borderline against the soft size guideline; could plausibly be split into "introduce SharedQuantState container" and "wire it into max_calibrate / promote", but the changes are cohesive enough that splitting isn't strictly required.
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
0cd6157 to
c149968
Compare
What does this PR do?
Type of change: refactor
Usage
same as before
Testing
added unittests.
run MSE before and after this change on Qwen3 and Qwen2 models.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: / N/AAdditional Information
Summary by CodeRabbit
New Features
Improvements
Tests