Skip to content

[OMNIML-3994] Make sure all weight quantizers have _amax#1560

Open
sychen52 wants to merge 1 commit into
NVIDIA:mainfrom
sychen52:weight_only_quantize
Open

[OMNIML-3994] Make sure all weight quantizers have _amax#1560
sychen52 wants to merge 1 commit into
NVIDIA:mainfrom
sychen52:weight_only_quantize

Conversation

@sychen52
Copy link
Copy Markdown
Contributor

@sychen52 sychen52 commented May 28, 2026

What does this PR do?

  • max_calibrate now always runs weight_only_quantize before the optional forward_loop, so every weight quantizer has _amax populated regardless of MoE routing.
  • awq_lite search loop explicitly deletes _amax around the alpha sweep to keep using the dynamic-amax path; no restore needed (postprocess overwrites for enabled modules, disabled modules early-return).
  • Removed three band-aids:
    • _bootstrap_uncalibrated_weight_quantizers (mse_calibrate)
    • per-module max_calibrate fallback in awq_lite postprocess
    • _ensure_weight_quantizer_calibrated + helpers (export/quant_utils)
  • Renamed test_bootstrap_populates_dead_expert_quantizers -> test_max_calibrate_populates_dead_expert_quantizers; deleted the GPU lazy-calibration test that no longer has a behavior to exercise.
    Type of change: refactor

Usage

same as before.

Testing

run unittests locally

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ❌
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Refactor

    • Simplified the weight quantization calibration process by consolidating logic and removing redundant internal functions.
  • Tests

    • Removed a redundant quantization format-specific export test.

Review Change Stack

@sychen52 sychen52 requested review from a team as code owners May 28, 2026 20:10
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 28, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR removes the lazy weight calibration system for NVFP4 quantizers by deleting the bootstrap helper, refactoring max_calibrate to always run weight-only quantization unconditionally, removing lazy-calibration utilities from weight scaling computation, and cleaning up related tests.

Changes

Lazy weight calibration removal

Layer / File(s) Summary
Remove bootstrap helper and refactor calibration pipeline
modelopt/torch/quantization/model_calib.py
_bootstrap_uncalibrated_weight_quantizers helper is deleted; max_calibrate is refactored to explicitly call enable_stats_collection, unconditionally run weight_only_quantize, conditionally run the provided forward_loop, then call finish_stats_collection, replacing the prior _run_and_load_max_stats path; the bootstrap call after NVFP4 static validation is removed.
Remove lazy calibration from weight scaling computation
modelopt/torch/export/quant_utils.py
reduce_block_amax import is removed; internal NVFP4 calibration helpers (_get_nvfp4_block_size, _set_amax_from_tensor, _ensure_weight_quantizer_calibrated) are deleted; lazy-calibration calls in get_weight_scaling_factor and get_weight_scaling_factor_2 NVFP4 branches are removed, eliminating on-demand amax computation during export.
Update tests for lazy-calibration removal
tests/gpu/torch/export/test_export_weight_gpu.py
NVFP4-specific imports (partial_nvfp4_config, NVFP4StaticQuantizer, reduce_block_amax) are replaced with shared weight-export utilities; the NVFP4 dynamic-vs-static export test is deleted along with its CUDA/NVFP4 quantization setup and amax buffer assertions.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1536: Modifies NVFP4 static weight calibration flow in max_calibrate/MSE calibration for bootstrapping and promoting static NVFP4 quantizers.
  • NVIDIA/Model-Optimizer#1501: Adds Megatron calibration forward_loop helper passed into calibration routines, aligning with max_calibrate's explicit forward_loop invocation.

Suggested reviewers

  • jenchen13
  • realAsma
  • Fridah-nv
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly references the main objective of this PR - ensuring all weight quantizers have the _amax field populated. This directly relates to the core changes which refactor the calibration pipeline to run weight_only_quantize unconditionally and remove legacy workarounds.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns detected. PR removes legacy helpers without introducing unsafe torch.load, numpy.load, hardcoded trust_remote_code, eval/exec, nosec comments, or new unsafe dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@sychen52
Copy link
Copy Markdown
Contributor Author

/claude review

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)

1328-1333: ⚡ Quick win

Remove the unreachable else path in postprocess.

At Line 1389-Line 1391, postprocess() is only called when module.awq_lite.is_enabled is true, so the else block at Line 1328-Line 1333 is dead code.

♻️ Proposed simplification
-        if module.awq_lite.is_enabled:
-            apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)
-        else:
-            # ``max_calibrate`` already set ``_amax`` for this module via the
-            # always-on ``weight_only_quantize`` step; AWQ is just disabled
-            # because no cache tokens flowed through it. Falling back to
-            # neutral per-tensor max calibration is the right thing — and
-            # already done by ``max_calibrate`` — so just warn.
-            warnings.warn(f"awq_lite: Disabling for {name}, quantizing with max calibration.")
+        apply_pre_quant_scale_and_smooth(module, 1.0 / module.awq_lite.best_scale)

As per coding guidelines "Remove dead code including unused imports, unreachable branches, and obsolete helpers".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 1328 - 1333, The
else branch in postprocess is unreachable because postprocess() is only invoked
when module.awq_lite.is_enabled is true; remove the entire else block (the
warnings.warn call and its comment) from the postprocess function to eliminate
dead code and keep only the enabled-path logic (references: postprocess,
module.awq_lite.is_enabled, the warnings.warn line that mentions "awq_lite:
Disabling for {name}...").
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1328-1333: The else branch in postprocess is unreachable because
postprocess() is only invoked when module.awq_lite.is_enabled is true; remove
the entire else block (the warnings.warn call and its comment) from the
postprocess function to eliminate dead code and keep only the enabled-path logic
(references: postprocess, module.awq_lite.is_enabled, the warnings.warn line
that mentions "awq_lite: Disabling for {name}...").

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 70dbcbb8-9cdb-4aec-aa70-cee556e618f6

📥 Commits

Reviewing files that changed from the base of the PR and between a2c496a and 60737c0.

📒 Files selected for processing (4)
  • modelopt/torch/export/quant_utils.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/export/test_export_weight_gpu.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/export/quant_utils.py

@sychen52 sychen52 self-assigned this May 28, 2026
Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Review Summary

Findings: 0 CRITICAL, 0 IMPORTANT, 1 SUGGESTION

Overall assessment

This is a clean, well-scoped refactor. The core change — making max_calibrate always run weight_only_quantize before the optional forward_loop — establishes a stronger invariant (every weight quantizer has _amax after max_calibrate, regardless of MoE routing) and lets several downstream band-aids be deleted:

  • _bootstrap_uncalibrated_weight_quantizers in mse_calibrate
  • per-module max_calibrate fallback in awq_lite.postprocess's disabled branch
  • _ensure_weight_quantizer_calibrated lazy calibration during HF export

The change holds together end-to-end:

  • For mse_calibrate / local_hessian_calibrate / gptq / svdquant, the pre-existing call to max_calibrate now naturally covers dead experts.
  • The new delattr(self.weight_quantizer, "_amax") inside awq_lite.forward correctly reverts to the dynamic _get_amax path so the per-alpha sweep recomputes max(|weight * pre_quant_scale|). For NVFP4StaticQuantizer, _fake_quantize falls back to super()._fake_quantize when self.amax is None, so dynamic recomputation works there too.
  • apply_pre_quant_scale_and_smooth (called in the enabled-postprocess path) and the disabled-module branch of the outer postprocess loop both re-populate _amax afterward.
  • Export call sites (get_weight_scaling_factor, get_weight_scaling_factor_2) rely on the new invariant; this is safe because all mtq.quantize algorithms route through max_calibrate.

Test changes are consistent with the refactor (renamed dead-expert test + removal of the lazy-calibration test that no longer has a behavior to exercise).

Notable observation

  • One SUGGESTION (inline): the else arm in awq_lite.postprocess is now unreachable — the outer loop only calls postprocess when module.awq_lite.is_enabled is True. The arm could be deleted to remove a misleading comment and the duplicated warning.

Risk

Low. Backward-compatible per the PR description, and the new invariant (_amax always populated by max_calibrate) is strictly stronger than before.

LGTM.

Comment thread modelopt/torch/quantization/model_calib.py
@codecov
Copy link
Copy Markdown

codecov Bot commented May 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.17%. Comparing base (d7e72f4) to head (311f3a5).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1560      +/-   ##
==========================================
+ Coverage   76.65%   77.17%   +0.51%     
==========================================
  Files         478      478              
  Lines       52408    52388      -20     
==========================================
+ Hits        40172    40428     +256     
+ Misses      12236    11960     -276     
Flag Coverage Δ
examples 41.63% <100.00%> (+8.72%) ⬆️
gpu 59.76% <100.00%> (-0.27%) ⬇️
regression 15.19% <0.00%> (-0.03%) ⬇️
unit 53.62% <100.00%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@sychen52 sychen52 force-pushed the weight_only_quantize branch from 60737c0 to 238546a Compare May 28, 2026 22:02
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 28, 2026

Actionable comments posted: 0

Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Nice simplification — removing the three band-aid paths in favor of populating _amax once in max_calibrate is the right direction, and the test rename + assertion update on _TinyMoEModel look correct. Flagging for human sign-off on a few points:

  • AWQ lite behavior change is the most fragile piece and lacks a dedicated test. The new delattr(self.weight_quantizer, "_amax") inside awq_lite's forward (model_calib.py L1209-1215) is the load-bearing change that keeps the per-alpha dynamic-amax path working now that weight_only_quantize populates _amax up front. The PR rationale ("postprocess overwrites for enabled modules, disabled modules early-return") is plausible — apply_pre_quant_scale_and_smooth_apply_weight_pre_quant_scale repopulates for enabled modules, and the new assert module.awq_lite.is_enabled plus inline max_calibrate fallback handles disabled ones — but there is no new test exercising AWQ with dead/uncalibrated experts. The existing renamed test only covers plain max_calibrate. Worth either adding an AWQ-MoE-with-dead-experts test or explicitly listing which existing GPU test exercises this combination.

  • Export-time safety net removed without a deprecation path. _ensure_weight_quantizer_calibrated previously warned and re-derived _amax from the weight if export saw an uncalibrated NVFP4 quantizer. With this PR, the same situation will dereference weight_quantizer._amax.float() / 448.0 (quant_utils.py L292, L324) and crash with AttributeError: 'NoneType' object has no attribute 'float'. Anyone loading a partial checkpoint and exporting without re-running mtq.quantize/max_calibrate will hit this. The PR claims "backward compatible" — that's true for the documented happy path, but the failure mode for misuse changes from "warn + recover" to "hard crash". Worth confirming this is intentional and ideally raising a clearer error than AttributeError in get_weight_scaling_factor[_2] when _amax is None.

  • Deleted test_export_nvfp4_static_weight_dynamic_vs_static_match was checking two things, not just lazy calibration: (a) dynamic vs static NVFP4 export produce matching weight/scales, and (b) lazy fill from weights when _amax is cleared. (b) is dead, agreed — but (a) is still a useful invariant. Consider keeping a trimmed version that exercises (a) without the manual delattr setup.

  • weight_only_quantize now runs on every max_calibrate call, including the recursive ones in _apply_weight_pre_quant_scale, smoothquant.postprocess, the awq_lite disabled-fallback, and awq SequentialQuantizer recalibration. For these, the lambda already calibrates from module.weight, so we now do that work twice on the same tensor. Functionally a no-op (idempotent), but the PR doesn't note the duplicated work and large-model calibration time may tick up slightly.

None of these are blockers — the structural cleanup is good and the diff is net -200 lines. Asking a human to confirm the AWQ test coverage and the export-error UX before approval.


Additional comments (outside the PR diff):

  • modelopt/torch/export/quant_utils.py:292 — > Bot comment.

With _ensure_weight_quantizer_calibrated gone, this weight_quantizer._amax.float() / 448.0 (and the analogous one in get_weight_scaling_factor_2) will raise AttributeError: 'NoneType' object has no attribute 'float' instead of warning + recomputing if a user exports a checkpoint that wasn't fully calibrated. Worth at least adding an explicit if weight_quantizer._amax is None: raise RuntimeError("Weight quantizer ... has no _amax; run max_calibrate before export") so the failure points at the cause, not at a generic NoneType deref.

  • tests/gpu/torch/export/test_export_weight_gpu.py:124 — > Bot comment.

The deleted test was checking two things: (a) dynamic vs static NVFP4 export produces matching weight + scales, and (b) lazy fill from _ensure_weight_quantizer_calibrated. (b) is dead, agreed. But (a) — that the static-quantizer export path matches the dynamic one when both are properly calibrated — is still a useful invariant and isn't covered elsewhere as far as I can see. Consider keeping a trimmed version of the test that drops the delattr setup but still asserts the dynamic-vs-static export equivalence.

Comment thread modelopt/torch/quantization/model_calib.py Outdated
Comment thread modelopt/torch/quantization/model_calib.py
Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Re-review: the structural cleanup is the right direction and the unreachable else in awq_lite.postprocess plus the docstring nit on max_calibrate look addressed (:speech_balloon: author replied "Done" on the duplicated-calibration note, and the else arm has been replaced with an assert module.awq_lite.is_enabled). However, several substantive points from cjluo-nv's previous human review still appear unresolved — flagging for human sign-off:

  • AWQ dead-expert test still missing. The delattr(self.weight_quantizer, "_amax") inside awq_lite.forward (model_calib.py L1209-1215) is the load-bearing change that keeps the per-alpha dynamic-amax path working now that weight_only_quantize populates _amax up front. The renamed test_max_calibrate_populates_dead_expert_quantizers only exercises plain max_calibrate, not awq_lite with dead/uncalibrated experts. If anyone refactors forward and forgets the delattr, per-alpha amax will be silently wrong with no test to catch it. Author hasn't responded to this point.

  • Export-time AttributeError UX not fixed. With _ensure_weight_quantizer_calibrated gone, weight_quantizer._amax.float() / 448.0 in get_weight_scaling_factor[_2] (quant_utils.py:288, :324 in the post-diff file) will now raise AttributeError: 'NoneType' object has no attribute 'float' instead of the previous warn+recover for a checkpoint exported without re-running max_calibrate. The previous review asked for an explicit if weight_quantizer._amax is None: raise RuntimeError("... run max_calibrate before export") so the failure points at the cause rather than a generic NoneType deref. Not addressed in the current diff.

  • Dropped dynamic-vs-static NVFP4 export equivalence assertion. The deleted test_export_nvfp4_static_weight_dynamic_vs_static_match was checking two things: (a) dynamic vs static NVFP4 export produce matching weight/scales; (b) lazy fill from weights when _amax is cleared. (b) is dead code now, agreed — but (a) is still a useful invariant and isn't covered elsewhere as far as I can see. Previous review suggested keeping a trimmed version that drops the delattr setup but still asserts (a). Not addressed.

Net direction is good (-200 lines, simpler invariant), but please confirm test coverage for AWQ dead experts and decide whether to (1) add an explicit error in get_weight_scaling_factor[_2] and (2) keep a trimmed dynamic-vs-static equivalence test before approving.

@sychen52 sychen52 force-pushed the weight_only_quantize branch from 238546a to e10bfce Compare May 28, 2026 23:54

# For dead experts, bootstrap reads max(|weight|). Sanity-check it matches
# the actual weight tensor's per-row max (axis=0 reduces over hidden_dim).
# For dead experts, ``_amax`` comes purely from ``weight_only_quantize``
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious, what is a dead expert?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead means there is no calibration data pass through it. So its weight_quantizer is also not run, and there is not _amax in the weight_quantizer.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guessed, but this term is not correct. They are not dead, just not getting routed for the calibration tokens. But anyway, it's just a test, not a big deal.

@sychen52 sychen52 changed the title [Quantization] Make sure all weight quantizers have _amax [OMNIML-3994] Make sure all weight quantizers have _amax May 29, 2026
@sychen52 sychen52 force-pushed the weight_only_quantize branch from e10bfce to 3b89895 Compare May 29, 2026 17:41
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_calib.py (1)

419-420: ⚡ Quick win

Avoid .item() in tensor assertions for torch test paths.

Use tensor-native comparison to avoid Python scalar extraction in this path.

Suggested change
-    assert dead_q._amax.abs().max().item() == pytest.approx(
-        original_uncalibrated_weight.abs().max().item(), abs=1e-6
-    )
+    torch.testing.assert_close(
+        dead_q._amax.abs().max(),
+        original_uncalibrated_weight.abs().max(),
+        atol=1e-6,
+        rtol=0.0,
+    )

As per coding guidelines **/torch/**/*.py: "Keep tensor work on the GPU and avoid unnecessary CPU-GPU syncs; avoid Python scalar extraction operators like tensor.item(), float(tensor), or min(tensor)".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/test_calib.py` around lines 419 - 420, Replace
the scalar .item()-based assertion with a tensor-native comparison to avoid
CPU/GPU sync: compare the two tensors dead_q._amax.abs().max() and
original_uncalibrated_weight.abs().max() directly using a tensor-aware equality
check (e.g. torch.allclose or torch.testing.assert_allclose) with atol=1e-6 so
the test keeps tensors on device and does not extract Python scalars.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/unit/torch/quantization/test_calib.py`:
- Around line 419-420: Replace the scalar .item()-based assertion with a
tensor-native comparison to avoid CPU/GPU sync: compare the two tensors
dead_q._amax.abs().max() and original_uncalibrated_weight.abs().max() directly
using a tensor-aware equality check (e.g. torch.allclose or
torch.testing.assert_allclose) with atol=1e-6 so the test keeps tensors on
device and does not extract Python scalars.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: cbd6496d-92ab-42b9-8bfb-2b8cdc12157c

📥 Commits

Reviewing files that changed from the base of the PR and between e10bfce and 3b89895.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/export/test_export_weight_gpu.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
  • tests/unit/torch/quantization/test_calib.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py

max_calibrate now always runs weight_only_quantize before the optional
forward_loop, so every weight quantizer gets _amax regardless of MoE
routing. Weight quantizers disabled by the caller (e.g. awq_lite, which
runs max_calibrate with weight quantizers disabled) are skipped by
weight_only_quantize, so the AWQ dynamic-amax path is unaffected.

With _amax guaranteed after calibration, remove two now-redundant
band-aids:
  - _bootstrap_uncalibrated_weight_quantizers (re-ran weight calibration
    for experts skipped by partial MoE routing); superseded by the
    always-on weight_only_quantize.
  - _ensure_weight_quantizer_calibrated and its helpers in export (lazy
    weight calibration at scale-factor extraction time), plus the GPU
    test that only exercised that lazy path.

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
@sychen52 sychen52 force-pushed the weight_only_quantize branch from 3b89895 to 311f3a5 Compare May 29, 2026 22:03
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
modelopt/torch/quantization/model_calib.py (3)

109-110: 💤 Low value

Consider inlining this trivial helper.

_collect_weight_stats wraps a single call to quantizer(weight) and is used only once (line 543). Inlining partial(TensorQuantizer.__call__, weight_quantizer, weight) or a lambda directly at the call site would reduce indirection without losing clarity.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 109 - 110, The
helper function _collect_weight_stats is a trivial wrapper that only calls
quantizer(weight); remove the _collect_weight_stats definition and replace its
single use with an inline call (e.g., a lambda or partial of
TensorQuantizer.__call__ bound to weight_quantizer and weight) at the call site
where weight_quantizer and weight are available so there is no indirection;
ensure you update any imports/usages referencing _collect_weight_stats to use
the inline callable (e.g., lambda w: weight_quantizer(w) or
partial(TensorQuantizer.__call__, weight_quantizer, weight)).

183-183: 💤 Low value

Parameter name model is misleading.

This function is invoked at line 542 with a TensorQuantizer, not a full model. Renaming the parameter to module would better reflect that it accepts any nn.Module (including individual quantizers), improving readability.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` at line 183, The parameter name
in _run_and_load_max_stats is misleading: rename the parameter from model to
module in the function signature, update its type hint (keep as nn.Module) and
all internal references, and update every call site (including the call that
passes a TensorQuantizer) to use the new name; also update any
docstring/comments referencing the old parameter name to reflect that this
accepts any nn.Module (e.g., individual quantizers).

1263-1270: ⚡ Quick win

Clarify the double finish_stats_collection pattern.

max_calibrate internally calls finish_stats_collection (line 277), and line 1270 calls it again. The intent appears to be processing input quantizers inside the context (while weight quantizers are disabled), then processing weight quantizers outside the context — but this is subtle and undocumented. Consider adding a brief comment explaining why finish_stats_collection is invoked twice and what each call processes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 1263 - 1270, The
code calls finish_stats_collection twice: once indirectly via max_calibrate
(which performs a finish_stats_collection for the quantizers active inside the
set_quantizer_by_cfg_context where weight quantizers are disabled) and then
again explicitly after exiting the context to finalize the weight quantizers;
add a short clarifying comment above this block (referencing
set_quantizer_by_cfg_context, max_calibrate, and finish_stats_collection) that
states the first finish_stats_collection handles input-quantizer stats and
distributed sync while weight quantizers are disabled, and the second call
finalizes weight-quantizer stats after the context is exited.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 109-110: The helper function _collect_weight_stats is a trivial
wrapper that only calls quantizer(weight); remove the _collect_weight_stats
definition and replace its single use with an inline call (e.g., a lambda or
partial of TensorQuantizer.__call__ bound to weight_quantizer and weight) at the
call site where weight_quantizer and weight are available so there is no
indirection; ensure you update any imports/usages referencing
_collect_weight_stats to use the inline callable (e.g., lambda w:
weight_quantizer(w) or partial(TensorQuantizer.__call__, weight_quantizer,
weight)).
- Line 183: The parameter name in _run_and_load_max_stats is misleading: rename
the parameter from model to module in the function signature, update its type
hint (keep as nn.Module) and all internal references, and update every call site
(including the call that passes a TensorQuantizer) to use the new name; also
update any docstring/comments referencing the old parameter name to reflect that
this accepts any nn.Module (e.g., individual quantizers).
- Around line 1263-1270: The code calls finish_stats_collection twice: once
indirectly via max_calibrate (which performs a finish_stats_collection for the
quantizers active inside the set_quantizer_by_cfg_context where weight
quantizers are disabled) and then again explicitly after exiting the context to
finalize the weight quantizers; add a short clarifying comment above this block
(referencing set_quantizer_by_cfg_context, max_calibrate, and
finish_stats_collection) that states the first finish_stats_collection handles
input-quantizer stats and distributed sync while weight quantizers are disabled,
and the second call finalizes weight-quantizer stats after the context is
exited.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3f52bb8e-a584-4c45-90be-0c8b12b4daac

📥 Commits

Reviewing files that changed from the base of the PR and between 3b89895 and 311f3a5.

📒 Files selected for processing (3)
  • modelopt/torch/export/quant_utils.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/export/test_export_weight_gpu.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/export/quant_utils.py

Copy link
Copy Markdown
Collaborator

@shengliangxu shengliangxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic makes sense to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants