
[OMNIML-3994] Add SharedQuantState #1605

Open
sychen52 wants to merge 5 commits into NVIDIA:main from sychen52:shared_quant_state

Conversation

@sychen52
Contributor

@sychen52 sychen52 commented Jun 2, 2026

What does this PR do?

Type of change: refactor

  • Add SharedQuantState: sibling weight quantizers (q/k/v, gate/up, per-expert w1/w3) calibrate to a single NVFP4 global_amax (the group max).
  • Discover groups via regex patterns (MaxCalibConfig.shared_patterns); max_calibrate attaches the state on the group parent, populates, and promotes.
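
The "group max" rule in the first bullet can be illustrated with a minimal sketch. The function name, sibling names, and amax values below are made up for illustration; the PR's `SharedQuantState` works on tensors and is not reproduced here.

```python
# Hypothetical stand-in for per-quantizer calibration results. The names and
# values are illustrative assumptions, not taken from the PR.
def shared_global_amax(per_quantizer_amax: dict[str, float]) -> float:
    """Return the single amax that every sibling in a group would share."""
    if not per_quantizer_amax:
        raise ValueError("a group needs at least one calibrated quantizer")
    return max(per_quantizer_amax.values())


# Each sibling calibrated on its own weights first.
qkv_amax = {"q_proj": 2.5, "k_proj": 4.0, "v_proj": 1.25}

# After sharing, every sibling uses the group maximum.
group_amax = shared_global_amax(qkv_amax)
synced = dict.fromkeys(qkv_amax, group_amax)
```

Using the maximum keeps the largest sibling's range representable once the weights are fused into one tensor, at the cost of a coarser scale for the smaller siblings.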

Usage

Same as before.

Testing

Added unit tests.
Ran MSE before and after this change on Qwen3 and Qwen2 models.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: / N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Configurable regex patterns to group weight quantizers and control shared quantization state
    • Calibration API accepts shared_patterns to enable group-based shared-state flows
  • Improvements

    • Centralized shared-state lifecycle for consistent amax aggregation and promotion across siblings and ranks
    • Runtime shared-state back-references are non-serializable (won’t persist in saved models)
  • Tests

    • Added/updated tests for grouping, validation, and end-to-end calibration behavior

  - Add SharedQuantState: sibling weight quantizers (q/k/v, gate/up,
    per-expert w1/w3) calibrate to a single NVFP4 global_amax (the group max).
  - Discover groups via input-sharing hooks or regex patterns
    (MaxCalibConfig.shared_patterns); max_calibrate attaches the state on the
    group parent, populates, and promotes.
  - Extract out the collect_shared_input_modules helper and reuse it for this and export

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 requested a review from a team as a code owner June 2, 2026 16:37
@sychen52 sychen52 requested a review from mxinO June 2, 2026 16:37
@coderabbitai
Contributor

coderabbitai Bot commented Jun 2, 2026


Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds regex-driven grouping and a SharedQuantState lifecycle to aggregate and sync weight amax across fusible sibling quantizers; exposes shared_patterns in config and max_calibrate, integrates shared-state attach/populate into NVFP4-static promotion, and adds unit tests for grouping, MoE, and config validation.
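
As a rough illustration of regex-driven grouping, the sketch below groups quantizer names that match the same pattern under the same parent module. The patterns, the `group_by_parent` helper, and the module names are assumptions for this example; the PR's `DEFAULT_WEIGHT_SHARED_PATTERNS` and `find_shared_input_groups` may differ, including in how lone matches are handled.

```python
import re
from collections import defaultdict

# Illustrative patterns only; not the PR's DEFAULT_WEIGHT_SHARED_PATTERNS.
EXAMPLE_PATTERNS = [
    r"\.(q_proj|k_proj|v_proj)\.weight_quantizer$",
    r"\.(gate_proj|up_proj)\.weight_quantizer$",
]


def group_by_parent(fqns, patterns):
    """Group names that match the same pattern under the same parent."""
    groups = defaultdict(list)
    for fqn in fqns:
        for idx, pattern in enumerate(patterns):
            match = re.search(pattern, fqn)
            if match:
                # Everything before the matched sibling name is the parent.
                parent = fqn[: match.start()]
                groups[(parent, idx)].append(fqn)
                break
    # Assumption made here: a lone match has nothing to share with, so drop it.
    return {key: members for key, members in groups.items() if len(members) > 1}


names = [
    "layers.0.self_attn.q_proj.weight_quantizer",
    "layers.0.self_attn.k_proj.weight_quantizer",
    "layers.0.self_attn.v_proj.weight_quantizer",
    "layers.0.mlp.gate_proj.weight_quantizer",
    "layers.0.mlp.up_proj.weight_quantizer",
    "layers.0.mlp.down_proj.weight_quantizer",
    "layers.1.self_attn.q_proj.weight_quantizer",
]
groups = group_by_parent(names, EXAMPLE_PATTERNS)
```

Keying on the parent keeps q/k/v from different layers in separate groups even though they match the same pattern.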

Changes

Shared NVFP4 Weight Quantization State

Layer: Config and calibration API contract
File(s): modelopt/torch/quantization/config.py, modelopt/torch/quantization/model_calib.py
Summary: MaxCalibConfig adds `shared_patterns: dict[str, list[str]]

Layer: Shared quantization state infrastructure
File(s): modelopt/torch/quantization/utils/shared_input.py
Summary: Adds DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState(nn.Module), find_shared_input_groups, attach_shared_quant_states, and populate_shared_state to discover fusible weight quantizers by regex, attach shared state on group parents, aggregate _amax, sync weight_global_amax, and write synced global_amax to NVFP4 static quantizers.

Layer: Integration into promotion and calibration flows
File(s): modelopt/torch/quantization/utils/core_utils.py, modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/utils/__init__.py, modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Summary: promote_nvfp4_static_quantizers now prefers a valid _shared_quant_state_ref.weight_global_amax when present; max_calibrate attaches and populates shared state before/after distributed sync and uses DEFAULT_WEIGHT_SHARED_PATTERNS as fallback; utils __init__ re-exports shared-input utilities; TensorQuantizer marks _shared_quant_state_ref as runtime-only for save/restore.

Layer: Test coverage for shared state
File(s): tests/unit/torch/quantization/test_shared_input.py, tests/unit/torch/quantization/test_mse_calibrator.py
Summary: New tests validate SharedQuantState basics, MoE expert-local grouping, end-to-end max_calibrate behavior including disabling grouping via {"weight": []}, config validation; mse_calibrator tests updated to call attach/populate and use promote_nvfp4_static_quantizers.
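
The `{"weight": []}` opt-out mentioned in the test summary only works if the fallback checks key presence rather than truthiness. A minimal sketch of that distinction follows; `DEFAULT_WEIGHT_PATTERNS` and both function names are placeholders for illustration, not the PR's code.

```python
# Placeholder default; not the PR's DEFAULT_WEIGHT_SHARED_PATTERNS.
DEFAULT_WEIGHT_PATTERNS = [r"\.(q_proj|k_proj|v_proj)$"]


def resolve_weight_patterns(shared_patterns):
    """Use the caller's list whenever the key is present, even if it is empty."""
    if shared_patterns is not None and "weight" in shared_patterns:
        return shared_patterns["weight"]
    return DEFAULT_WEIGHT_PATTERNS


def resolve_weight_patterns_truthy(shared_patterns):
    """Buggy variant: an empty list is falsy, so the default silently wins."""
    return (shared_patterns or {}).get("weight") or DEFAULT_WEIGHT_PATTERNS


disabled = resolve_weight_patterns({"weight": []})
not_disabled = resolve_weight_patterns_truthy({"weight": []})
```

With the key-presence check, an explicit empty list disables grouping, while `None` or a dict without a `"weight"` key still falls back to the defaults.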

Sequence Diagram(s)

sequenceDiagram
  participant Model
  participant Finder as find_shared_input_groups
  participant Attacher as attach_shared_quant_states
  participant Populator as populate_shared_state
  participant Promoter as promote_nvfp4_static_quantizers
  Model->>Finder: scan module FQNs with regex patterns
  Finder->>Attacher: return (parent, members) groups
  Attacher->>Model: attach SharedQuantState on parent, set _shared_quant_state_ref on members
  Model->>Populator: aggregate members' _amax into SharedQuantState.weight_global_amax
  Populator->>Populator: sync weight_global_amax across processes
  Populator->>Promoter: write synced global_amax into NVFP4StaticQuantizers
  Promoter->>Promoter: prefer shared weight_global_amax when valid during promotion
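
The last step of the diagram, preferring the shared value when it is valid, might look roughly like the sketch below. The dataclasses are simplified stand-ins for the PR's modules, and "valid" is assumed here to mean "populated and positive"; the actual check in `promote_nvfp4_static_quantizers` is not shown in this conversation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SharedState:
    # None until the populate step has run (assumption for this sketch).
    weight_global_amax: Optional[float] = None


@dataclass
class Quantizer:
    block_amax: list[float]
    shared_ref: Optional[SharedState] = None


def pick_global_amax(q: Quantizer) -> float:
    """Prefer the group's shared amax; otherwise reduce the quantizer's own."""
    shared = q.shared_ref
    if (
        shared is not None
        and shared.weight_global_amax is not None
        and shared.weight_global_amax > 0
    ):
        return shared.weight_global_amax
    # Fallback: reduce this quantizer's own per-block amax values.
    return max(q.block_amax)


shared = SharedState()
member = Quantizer(block_amax=[1.0, 3.0], shared_ref=shared)
before_populate = pick_global_amax(member)
shared.weight_global_amax = 5.0
after_populate = pick_global_amax(member)
```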

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1560: Related changes to max_calibrate ensuring weight quantizers have _amax populated prior to promotion/sync.

Suggested reviewers

  • jenchen13
  • kinjalpatel27
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 60.87% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title '[OMNIML-3994] Add SharedQuantState' is clear and directly describes the primary change: introducing a new SharedQuantState feature.
Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns | ✅ Passed | No security anti-patterns found: no unsafe torch.load, numpy.load, hardcoded trust_remote_code, eval/exec, nosec comments, or new dependencies. User regexes validated at config boundary.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai Bot left a comment


Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.


Actionable comments posted: 4

🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/__init__.py (1)

21-27: ⚡ Quick win

Use a star re-export for shared_input.

This package already defines the public surface via shared_input.__all__ plus this module’s __all__. Re-exporting with from .shared_input import * keeps it aligned with the repository’s package-export convention and avoids future drift between the submodule and package surface.

As per coding guidelines, "Define the public API with __all__ at the top of each Python module and re-export submodules in __init__.py files using from .module import * to keep the public API explicit and make star-imports safe".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/utils/__init__.py` around lines 21 - 27, The
module currently imports specific symbols from shared_input
(DEFAULT_WEIGHT_SHARED_PATTERNS, SharedQuantState, attach_shared_quant_states,
find_shared_input_groups, populate_shared_state) but should re-export the entire
public API defined by shared_input.__all__; replace the explicit imports in this
__init__.py with a star re-export (use from .shared_input import *) and ensure
this package's __all__ still combines or respects shared_input.__all__ so the
package export surface remains consistent with shared_input.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/config.py`:
- Around line 718-734: The shared_patterns field currently accepts arbitrary
keys/strings; add validation on the Model config that runs when the config is
parsed to (1) restrict keys to the supported set {'weight','input'} and raise a
clear error on unknown keys, and (2) pre-compile every regex string in
shared_patterns (using re.compile) and raise a ValidationError/ValueError for
any invalid pattern so bad configs fail fast; implement this as a pydantic
`@validator` (or root_validator) for shared_patterns in the same class that
replaces the dict[str,list[str]] with dict[str,list[re.Pattern]] (or stores
compiled patterns alongside) and reference the shared_patterns symbol in your
changes.

In `@modelopt/torch/quantization/model_calib.py`:
- Around line 237-238: The current expression for weight_patterns treats an
explicit empty list as falsy and substitutes DEFAULT_WEIGHT_SHARED_PATTERNS,
preventing callers from specifying {"weight": []}; change the logic in the block
that computes weight_patterns so it checks for key-presence instead of
truthiness: if shared_patterns is not None and "weight" in shared_patterns then
use shared_patterns["weight"] (even if empty), otherwise fall back to
DEFAULT_WEIGHT_SHARED_PATTERNS before calling attach_shared_quant_states(model,
patterns=weight_patterns); ensure this preserves behavior when shared_patterns
is None.

In `@modelopt/torch/quantization/utils/core_utils.py`:
- Around line 978-982: The code currently reuses module._shared_quant_state_ref
unconditionally which can be stale when promote_nvfp4_static_quantizers(model)
is run on submodules; update the branch so you only accept
shared.weight_global_amax if the shared state actually belongs to the current
promotion root/subtree: inspect shared for an ownership marker (e.g.
shared._owning_promotion_root or shared._owning_module/_owner) and compare it to
the current promotion root or module before using shared.weight_global_amax,
otherwise fall back to computing global_amax via
reduce_amax(amax.clone().detach(), axis=None); keep references to module,
_shared_quant_state_ref, weight_global_amax, reduce_amax and amax in your check
so the fallback behavior is unchanged when ownership does not match.

In `@modelopt/torch/quantization/utils/shared_input.py`:
- Around line 112-119: The all-reduce failure for syncing shared state should
fail fast instead of warning: in the block that calls dist.all_reduce (the one
operating on self.weight_global_amax with group=group.group), catch the
RuntimeError only to log/annotate if desired and then raise (re-raise the caught
exception or raise a new RuntimeError with context) so the process stops and
inconsistent weight_global_amax cannot be used; update the exception handling in
the method where self.weight_global_amax is synced accordingly.
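
The config.py finding above asks for fail-fast validation of `shared_patterns`. One way to sketch it without pydantic is a plain function; the supported key set comes from the review comment, while the function name and error messages are assumptions for illustration and the PR's validator may be shaped differently.

```python
import re

# Key set named in the review comment above.
SUPPORTED_KEYS = {"weight", "input"}


def validate_shared_patterns(shared_patterns):
    """Return compiled patterns, raising ValueError on unknown keys or bad regexes."""
    if shared_patterns is None:
        return None
    unknown = set(shared_patterns) - SUPPORTED_KEYS
    if unknown:
        raise ValueError(f"unsupported shared_patterns keys: {sorted(unknown)}")
    compiled = {}
    for key, patterns in shared_patterns.items():
        compiled[key] = []
        for pattern in patterns:
            try:
                compiled[key].append(re.compile(pattern))
            except re.error as exc:
                # Surface the bad pattern at config time, not mid-calibration.
                raise ValueError(
                    f"invalid regex {pattern!r} under {key!r}: {exc}"
                ) from exc
    return compiled


ok = validate_shared_patterns({"weight": [r"\.(gate_proj|up_proj)$"]})
```

Inside a pydantic model the same body could sit in a field validator, so a malformed config is rejected when it is parsed.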


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: bce3b754-50a1-4bf6-afa4-5dc0bd5196d4

📥 Commits

Reviewing files that changed from the base of the PR and between 8f96832 and 5f77e1a.

📒 Files selected for processing (7)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/shared_input.py
  • tests/unit/torch/quantization/test_mse_calibrator.py
  • tests/unit/torch/quantization/test_shared_input.py

Comment thread modelopt/torch/quantization/config.py
Comment thread modelopt/torch/quantization/model_calib.py Outdated
Comment thread modelopt/torch/quantization/utils/core_utils.py
Comment thread modelopt/torch/quantization/utils/shared_input.py Outdated
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 force-pushed the shared_quant_state branch from 5f77e1a to 4b9fca1 Compare June 2, 2026 19:55
@sychen52
Contributor Author

sychen52 commented Jun 2, 2026

/claude review


@claude claude Bot left a comment


Review summary — CRITICAL: 1, IMPORTANT: 0, SUGGESTION: 1

The pattern-based grouping, regex/key validation in MaxCalibConfig, parent-LCA attachment, EP+TP all-reduce sync, and the _shared_quant_state_ref-via-object.__setattr__ indirection (so member quantizers don't re-register the buffer per-child) are nicely thought through. The unit tests exercise singletons, the shared_expert_gate non-fusion case, MoE per-expert grouping, and the empty-patterns escape hatch directly.

Blocking issue:

  • [CRITICAL ModeState] _shared_quant_state.weight_global_amax enters state_dict but the restore path never re-creates the submodule. mto.restore does restore_from_modelopt_state (which only recreates TensorQuantizers) and then model_restored.load_state_dict(...) with default strict=True, so any real model with q/k/v or gate/up siblings calibrated post-PR fails to round-trip through mto.save/mto.restore. Megatron's quant_module_set_extra_state has the same gap. The existing state_dict test only saves/loads on the same already-calibrated model and so does not catch this. See the inline comment on SharedQuantState.__init__ for fix options (the cleanest is persistent=False on the buffer, since each promoted NVFP4StaticQuantizer already serializes its own _global_amax). Please pair the fix with a regression test that uses the existing tf_modelopt_state_and_output_tester harness on a model with q/k/v siblings.

Smaller item:

  • [SUGGESTION] _climb_past_modulelist should also climb past nn.ModuleDict so custom shared_patterns against dict-shaped MoE layouts don't end up registering _shared_quant_state into a ModuleDict's _modules and corrupting its iteration. Default patterns are safe today.

Holding off on approval until the restore path is fixed and covered by a test.

Comment thread modelopt/torch/quantization/utils/shared_input.py
Comment thread modelopt/torch/quantization/utils/shared_input.py
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 force-pushed the shared_quant_state branch from 4b9fca1 to 85fe560 Compare June 2, 2026 20:40
Contributor

@coderabbitai coderabbitai Bot left a comment



Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/unit/torch/quantization/test_shared_input.py`:
- Line 189: Hoist the in-function imports to module top: move "from
modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers",
"import modelopt.torch.opt as mto", and "from modelopt.torch.quantization.config
import MaxCalibConfig" out of their tests and place them with the other
top-level imports; then remove the corresponding in-test import lines inside
test_promote_ignores_shared_state_outside_root,
test_modelopt_save_restore_with_shared_state, and
test_config_rejects_invalid_shared_patterns so import errors surface at
collection time and no circular/optional-import comments are needed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 94e0730f-2271-4e49-a7bf-3b8209c961ed

📥 Commits

Reviewing files that changed from the base of the PR and between 4b9fca1 and 85fe560.

📒 Files selected for processing (7)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/shared_input.py
  • tests/unit/torch/quantization/test_shared_input.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/shared_input.py
  • modelopt/torch/quantization/model_calib.py

Comment thread tests/unit/torch/quantization/test_shared_input.py Outdated
@codecov

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 91.22807% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 56.00%. Comparing base (8f96832) to head (c149968).
⚠️ Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
modelopt/torch/quantization/utils/shared_input.py | 89.05% | 15 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1605       +/-   ##
===========================================
- Coverage   77.41%   56.00%   -21.42%     
===========================================
  Files         480      481        +1     
  Lines       52499    52965      +466     
===========================================
- Hits        40642    29662    -10980     
- Misses      11857    23303    +11446     
Flag | Coverage Δ
examples | 15.14% <15.78%> (-25.67%) ⬇️
gpu | 16.32% <40.35%> (-44.07%) ⬇️
regression | 15.23% <15.78%> (+0.10%) ⬆️
unit | 53.99% <91.22%> (+0.24%) ⬆️


Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

New SharedQuantState subsystem (363-line shared_input.py + plumbing in model_calib.py, core_utils.py, tensor_quantizer.py) replaces the earlier _sync_grouped_weight_global_amax (which delegated to preprocess_linear_fusion) with a regex-pattern–based framework for grouping fusible siblings. Tests cover per-group attach, MoE per-expert grouping, the save/restore round-trip (non-persistent buffer + ref filtered from get_modelopt_state), empty/invalid patterns, and stale-ref gating in promote_nvfp4_static_quantizers. Code quality is good and inline comments are unusually thorough.

Asking a human to sign off because:

  • Design alternatives not addressed in PR body. The repo already has preprocess_linear_fusion (modelopt/torch/export/quant_utils.py) doing the same global_amax unification at export time, and the legacy _sync_grouped_weight_global_amax already wired this into max_calibrate. The PR body does not explain why a new ~360-line subsystem is preferred over extending those (e.g. replacing the hardcoded _GATE_UP_PAIRS + Q/K/V tuple with a configurable list, or moving the preprocess_linear_fusion call earlier in max_calibrate). The forward-looking "we'll add input scales / LoRA factors to it later" rationale appears in code comments but isn't justified against the simpler path of extending the existing helper. The "why pattern-based vs input-hook-based" question is addressed (shared_expert_gate counterexample) — but that's the easy half of the design question.
  • Semantic shift in max_calibrate ordering. Previously _promote_nvfp4_static_quantizers_with_global_amax_sync ran before the distributed_sync early return; in the new flow promotion happens after TP/DP/EP _amax sync (via populate_shared_state → promote_nvfp4_static_quantizers). For NVFP4 static-block weights this changes which global_amax value MSE calibration and downstream consumers observe (now the cross-rank-unified group max, vs. the pre-sync per-quantizer value). The PR description doesn't call this out, and there's no GPU/multi-rank test covering the EP/TP path that the new SharedQuantState.sync_weight_global_amax is responsible for.
  • SharedQuantState is attached for every enabled weight quantizer matching the patterns, regardless of format (FP8, INT8, etc.), not just NVFP4-static. The only consumer (promote_nvfp4_static_quantizers) is gated by is_nvfp4_static, so non-NVFP4 runs just carry a dangling state on every Q/K/V parent. Functionally benign today, but worth confirming this is intentional and that the dangling submodule (with a non-persistent buffer) doesn't trip up other paths (e.g. replace_quant_module, FSDP wrapping, sharded_state_dict) — none of the new tests exercise this case.
  • Size / scope. 1021-line PR, single new ~360-line framework file plus a 457-line test file. Borderline against the soft size guideline; could plausibly be split into "introduce SharedQuantState container" and "wire it into max_calibrate / promote", but the changes are cohesive enough that splitting isn't strictly required.

Comment thread tests/unit/torch/quantization/test_shared_input.py Outdated
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 force-pushed the shared_quant_state branch from 0cd6157 to c149968 Compare June 2, 2026 22:26


3 participants