Refactor local_hessian onto shared MSE flow + fused-MoE expert support#1578
Refactor local_hessian onto shared MSE flow + fused-MoE expert support#1578Fridah-nv wants to merge 4 commits into
Conversation
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughRefactors local-Hessian-weighted MSE calibration: adds activation-capture hooks, threads per-quantizer error functions into MSE calibrators, implements per-block Hessian accumulation during forward capture, and runs shared MSE refinement; includes fused-experts hook wiring and unit tests. ChangesLocal Hessian Calibration Enhancement
Sequence DiagramsequenceDiagram
participant Calibration as local_hessian_calibrate
participant Model as Model
participant Quantizer as WeightQuantizer
participant Accumulator as _LocalHessianAccumulator
participant ForwardLoop as forward_loop
participant MSE as _mse_calibrate_weights
Calibration->>Quantizer: silence enabled weight quantizers
Calibration->>Model: register hooks via register_calibration_input_hooks
Accumulator->>Accumulator: init per-quantizer accumulators
Calibration->>ForwardLoop: execute forward pass
ForwardLoop->>Model: forward(input)
Model->>Accumulator: invoke hook (weight_quantizer, weight, input)
Accumulator->>Accumulator: accumulate XᵀX per block
Calibration->>Quantizer: restore weight quantizers
Calibration->>Accumulator: build error_func from Hessians
Calibration->>MSE: _mse_calibrate_weights(error_func_for)
MSE->>MSE: refine amax using Hessian-weighted loss
🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1578 +/- ##
==========================================
- Coverage 77.33% 76.81% -0.53%
==========================================
Files 478 479 +1
Lines 52421 52754 +333
==========================================
- Hits 40541 40524 -17
- Misses 11880 12230 +350
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
|
/claude review |
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Solid refactor that consolidates local_hessian_calibrate onto the shared MSE flow and extends Hessian capture to HF fused-MoE experts via a new register_calibration_input_hooks extension point on QuantModule. Code is well-structured, tests are meaningful (accumulator math, dense end-to-end, conv fallback, block-size mismatch, registered-backend skip, per-expert MoE capture + refinement), and bit-identical preservation on Qwen3-8B is claimed.
Design check passed: the new extension point is a small activation-side counterpart to the existing iter_weights_for_calibration, not a parallel subsystem; this is the right shape rather than introducing a new registry/loader. The PR explicitly documents fallback paths (never-routed experts, non-eager kernels, cin % block_size != 0, registered backends) and makes them warn rather than silently miscalibrate.
Nudging (not approving) for two reasons that warrant a human signoff:
-
Distributed correctness gap is acknowledged but not fixed. The new code emits a
warn_rank_0and a TODO that "the per-block Hessian is not synced across TP/DP ranks ... refined amaxes can diverge under tensor/data parallelism. Treat local_hessian as single-rank for now." The previous (LocalHessianHelper) implementation had the same gap (also TODO'd), so this is no regression — but the change still ships a feature whose multi-rank behavior is known-divergent. A maintainer should confirm this is acceptable for the intended use cases (and whetherlocal_hessian_calibrateshould hard-error rather than warn under TP/DP > 1, asmax_calibratedoes for NVFP4-static + TP). -
Bit-equivalence claim is human-verified outside CI. The PR body says the refactor produces "bit-identical dense weight scales to main (216/216 tensors) on Qwen3-8B with the fp32 accumulation neutralized to isolate the structural change." This is exactly the kind of claim that needs human confirmation since it can't be re-run in CI; please have the original reviewer/runner sign off that the comparison was done with
fp32 accumulation neutralized(i.e. bf16 GEMM as in the old code) and that the deliberate fp32-accumulation switch in the new path is the only intended numerical change.
Minor (non-blocking) observations:
from modelopt.torch.utils.distributed import ...is split across three statements; the two new ones (is_initialized as dist_is_initialized,size as dist_size) could be folded into the existing line._warn_local_hessian_fallbackis only called whencapturesis non-empty, which means acin % block_size != 0warning is suppressed for non-pairable modules (where it wouldn't matter anyway). Consider a brief comment so a future reader doesn't think the warning is missing.SequentialQuantizerweight quantizers are silently skipped (matchesmse_calibrate), but this is invisible to users runninglocal_hessianon a SequentialQuantizer config — they'll see no fallback warning sinceregister_calibration_input_hooksdoes pair them. Worth a one-line comment inlocal_hessian_calibrateor a sibling warning path.
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)
497-520:⚠️ Potential issue | 🟠 Major | 🏗️ Heavy liftFallback weights are skipped instead of getting plain MSE.
With
fp8_scale_sweep=True, this helper returnsNonefor every unregistered backend and every non-static-NVFP4 quantizer. Inlocal_hessian_calibrate(), that is also the path used for "no Hessian available" fallbacks, so those weights stay at max calibration instead of receiving the documented plain-MSE refinement.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 497 - 520, The fp8_scale_sweep branch currently returns None for unregistered backends and non-static NVFP4 quantizers, causing those weights to skip plain-MSE refinement; change that fallback to return the same plain-MSE calibrator used by local_hessian_calibrate instead of None. Specifically, in the block after _uses_modelopt_fp8_weight_scales(...) where it now returns None, instantiate and return the plain MSE calibrator (the same class/factory local_hessian_calibrate uses for "no Hessian available") with the same parameters you pass to NVFP4MSECalibrator (initial_amax, axis, quant_func, error_func and global_amax from weight_quantizer if needed) so unregistered/non-static backends receive plain-MSE calibration rather than None.
🧹 Nitpick comments (1)
modelopt/torch/quantization/model_calib.py (1)
503-509: ⚡ Quick winUse
warn_rank_0for this backend-skip warning.This can fire once per quantizer on every rank, so
warnings.warnwill get noisy in distributed calibration.Suggested fix
- warnings.warn( + warn_rank_0( f"local_hessian: backend '{backend}' does not support a custom error " "function; skipping Hessian-weighted calibration for this quantizer." )As per coding guidelines, "Develop with distributed processing in mind by using
print_rank_0orwarn_rank_0to avoid noisy logs and guarding shared side effects such as file writes or shared state updates against race conditions between ranks".🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 503 - 509, The warning for skipping Hessian refinement should use the distributed-safe helper warn_rank_0 instead of warnings.warn to avoid noisy warnings on every rank; in the block that checks error_func (the local_hessian/backend skip branch) replace warnings.warn(...) with warn_rank_0(...) (and add an import for warn_rank_0 if missing) so the message about "backend '{backend}' does not support a custom error function; skipping Hessian-weighted calibration for this quantizer." only prints on rank 0.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 657-665: build_error_func currently clears self.hessian_per_block
before model._local_hessian_accumulators can be populated, so debug mode can't
inspect the accumulated Hessian; fix by preserving the raw accumulator until
after any debug assignment: when building hessian = self.hessian_per_block /
self.num_samples, first, if debug inspection is enabled (check
model._local_hessian_accumulators or a debug flag), assign or copy
self.hessian_per_block (or hessian) into model._local_hessian_accumulators, then
only set self.hessian_per_block = None; alternatively, conditionally only clear
self.hessian_per_block when debug is false—apply the same change at the other
clearing site referenced (the similar clear around lines 820-821).
- Around line 763-797: During the local Hessian capture pass we only silence
weight quantizers; also collect and temporarily disable activation/input
quantizers so QuantInputBase.forward() does not quantize activations during
forward_loop. Add a silenced_input_quantizers list (analogous to
silenced_weight_quantizers), and inside the loop over name_to_module when you
register captures check modules that expose an input quantizer (e.g.,
module.input_quantizer or instances of QuantInputBase) and if that quantizer is
a TensorQuantizer (or has is_enabled and _if_quant) append it to
silenced_input_quantizers and disable it; then in the finally block re-enable
each quantizer in silenced_input_quantizers (and still re-enable
silenced_weight_quantizers and remove handles) so the forward pass uses
full-precision activations for ΣXᵀX collection.
In `@modelopt/torch/quantization/nn/modules/quant_module.py`:
- Around line 262-272: The hook is currently registered for any enabled
weight_quantizer (including SequentialQuantizer) which contaminates captures;
change register_calibration_input_hooks to only register when
self.weight_quantizer is an instance of TensorQuantizer (in addition to existing
checks on weight presence/dim and is_enabled). Update the condition in
register_calibration_input_hooks to check isinstance(self.weight_quantizer,
TensorQuantizer) so only TensorQuantizer-backed weights get the forward-pre
hook; keep the _pre_hook and registration logic unchanged.
In `@tests/unit/torch/quantization/plugins/test_fused_experts.py`:
- Around line 662-664: The test currently imports modelopt.torch.quantization as
mtq and local_hessian_calibrate inside the test function; move these imports to
the module top so import errors surface at collection time—add top-level imports
for mtq (modelopt.torch.quantization) and local_hessian_calibrate
(modelopt.torch.quantization.model_calib) near other test imports and remove the
in-function import statements in test_fused_experts.py so the test uses the
module-scope mtq and local_hessian_calibrate symbols.
---
Outside diff comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 497-520: The fp8_scale_sweep branch currently returns None for
unregistered backends and non-static NVFP4 quantizers, causing those weights to
skip plain-MSE refinement; change that fallback to return the same plain-MSE
calibrator used by local_hessian_calibrate instead of None. Specifically, in the
block after _uses_modelopt_fp8_weight_scales(...) where it now returns None,
instantiate and return the plain MSE calibrator (the same class/factory
local_hessian_calibrate uses for "no Hessian available") with the same
parameters you pass to NVFP4MSECalibrator (initial_amax, axis, quant_func,
error_func and global_amax from weight_quantizer if needed) so
unregistered/non-static backends receive plain-MSE calibration rather than None.
---
Nitpick comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 503-509: The warning for skipping Hessian refinement should use
the distributed-safe helper warn_rank_0 instead of warnings.warn to avoid noisy
warnings on every rank; in the block that checks error_func (the
local_hessian/backend skip branch) replace warnings.warn(...) with
warn_rank_0(...) (and add an import for warn_rank_0 if missing) so the message
about "backend '{backend}' does not support a custom error function; skipping
Hessian-weighted calibration for this quantizer." only prints on rank 0.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 967322de-f723-4696-8e51-a3a3a8baf41c
📒 Files selected for processing (5)
modelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/nn/modules/quant_module.pymodelopt/torch/quantization/plugins/huggingface.pytests/unit/torch/quantization/plugins/test_fused_experts.pytests/unit/torch/quantization/test_local_hessian.py
There was a problem hiding this comment.
Claude review summary
Findings: CRITICAL: 0, IMPORTANT: 1, SUGGESTION: 1
The refactor is structurally clean — the new _LocalHessianAccumulator, the shared _mse_calibrate_weights helper, and the register_calibration_input_hooks extension point all match well with the existing MSE flow, and the per-expert capture via _current_expert_idx is a nice fit for the fused-MoE plugin. Tests cover the accumulator math, the dense end-to-end path, the fallback warnings, and the per-expert MoE path.
Most impactful finding
[IMPORTANT Compatibility] During the activation-capture forward, SequentialQuantizer-wrapped weight quantizers are no longer silenced. The isinstance(weight_quantizer, TensorQuantizer) filter on the silencing list excludes them (SequentialQuantizer is an nn.Sequential), whereas the previous implementation disabled them via the delegated SequentialQuantizer.disable()/enable(). For layers that use SequentialQuantizer for weights (e.g. INT4 weights + FP8 scale, or downgraded restore paths), the Hessian-capture forward now propagates quantized weights, subtly corrupting the captured Hessian for downstream TensorQuantizer layers. The PR's bit-identical Qwen3-8B check uses pure TensorQuantizer paths and wouldn't catch this. See inline comment at model_calib.py:776–786 for a suggested fix (iterate SequentialQuantizer leaves when building the silencing list).
Nit
[SUGGESTION] With debug=True, model._local_hessian_accumulators retains accumulators whose hessian_per_block was already freed by build_error_func. The docstring's "retain … for inspection" implies the Hessian tensor is inspectable. Either snapshot before build_error_func, or tighten the docstring.
Risk assessment
Low–medium. The common path (TensorQuantizer-only weight quantizers, e.g. INT8 / FP8 / static NVFP4) is preserved bit-identically per the author's own check. The SequentialQuantizer regression affects only INT4-with-FP8-scale-style configurations and only manifests as a small Hessian-quality drift for downstream TensorQuantizer refinements. Worth fixing before merge but doesn't break the headline use case.
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/torch/quantization/test_local_hessian.py`:
- Line 184: Move the in-test import of SequentialQuantizer to the module top:
remove the inline "from modelopt.torch.quantization.nn import
SequentialQuantizer" inside the test and add it to the file's top-level imports
so import errors surface at collection time; keep the inline import only if
there is a documented circular/optional-dependency reason and add a comment
explaining that. Ensure the test references the top-level SequentialQuantizer
symbol and run tests to confirm no circular import was introduced.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b1029f88-ec3a-452d-8f44-3f0fd96d0a83
📒 Files selected for processing (4)
modelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/nn/modules/quant_module.pytests/unit/torch/quantization/plugins/test_fused_experts.pytests/unit/torch/quantization/test_local_hessian.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unit/torch/quantization/plugins/test_fused_experts.py
- modelopt/torch/quantization/model_calib.py
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
|
Can we test doing this on a model and runs some simple evals and compare that with the base MSE without hessian? |
|
/claude review |
I plan to do a larger scale experiments after this PR lands, I'd like to merge it first to make sure basic correctness and code quality. |
What does this PR do?
Type of change: Bug fix + new feature (fused-MoE coverage)
Dusted off and refactored
local_hessian_calibrateto align with the new MSE calibration flow, fix latentdrift, extend coverage to fused-MoE experts, and decouple module-specific handling behind a clean extension point.
Core changes (
modelopt/torch/quantization/model_calib.py):_mse_calibrate_weightshelper now used by bothmse_calibrateandlocal_hessian_calibrateLocalHessianHelper+ bespoke per-weight loop with a small_LocalHessianAccumulator(lazy fp32 buffer, freed after building the error func), removing ~200 lines of duplicated scale-search logic and all manualcuda.synchronize/empty_cachebookkeeping. TheXᵀXGEMM accumulates in fp32 to avoid bf16/fp16 precision loss.max_calibrate).id(weight_quantizer)so dense and per-expert paths share one calibration loop. Never-routed experts / non-eager kernels /cinnot divisible byblock_size/ registered backends fall back to plain MSE (with an eager, module-named warning).Decoupling (zero module-type-specific code in
model_calib.py):QuantModule.register_calibration_input_hooks(callback)— the activation-side counterpart toiter_weights_for_calibration. Base default is a no-op;QuantLinearConvBasepairs the weight quantizer with theforward input (linear only), and
_QuantFusedExperts(inplugins/huggingface.py) owns the per-expert pairing via_current_expert_idx. Any future module type gains local-Hessian support by implementing this one method.Usage
# Add a code snippet demonstrating how to use thisTesting
tests/unit/torch/quantization/test_local_hessian.py(accumulator math/shape/dtype, dense end-to-end, backend-skip, block-size guard) and a per-expert MoE test intest_fused_experts.py.main(216/216 tensors) on Qwen3-8B with the fp32 accumulation neutralized to isolate the structural change.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅Additional Information
Summary by CodeRabbit
New Features
Behavior
Tests