-
Notifications
You must be signed in to change notification settings - Fork 423
Refine layerwise non-mutating calibration #1592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: fridah/layerwise-config
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,7 @@ | |
| _CheckpointState, | ||
| ) | ||
| from modelopt.torch.utils import print_rank_0 | ||
| from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState | ||
| from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master | ||
| from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method | ||
|
|
||
| from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator | ||
|
|
@@ -1761,7 +1761,7 @@ def layerwise_calibrate( | |
| checkpoint_dir = calib_kwargs.pop("checkpoint_dir", None) | ||
| qdq_from_prev = calib_kwargs.pop("get_qdq_activations_from_prev_layer", False) | ||
| save_every = calib_kwargs.pop("save_every", 1) | ||
| save_quantizers_only = calib_kwargs.pop("save_quantizers_only", False) | ||
| calib_mutates_weights = calib_kwargs.pop("calib_mutates_weights", True) | ||
|
|
||
| if forward_loop is None: | ||
| raise ValueError( | ||
|
|
@@ -1783,16 +1783,27 @@ def layerwise_calibrate( | |
| checkpoint_dir, | ||
| num_layers, | ||
| save_every=save_every, | ||
| save_quantizers_only=save_quantizers_only, | ||
| calib_mutates_weights=calib_mutates_weights, | ||
| ) | ||
| start_layer = ckpt.start_layer if ckpt else 0 | ||
|
|
||
| input_getter = LayerActivationCollector(model) | ||
| input_getter._patch_all_layers(decoder_layers=transformer_layers) | ||
| layer_pbar = tqdm( | ||
| total=num_layers, | ||
| initial=start_layer, | ||
| desc="Layerwise calibration", | ||
| disable=not is_master(), | ||
| dynamic_ncols=True, | ||
| ) | ||
|
|
||
| def _set_layer_status(status: str): | ||
| layer_pbar.set_postfix_str(status, refresh=True) | ||
|
|
||
| resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None | ||
| input_getter = LayerActivationCollector(model, status_callback=_set_layer_status) | ||
|
|
||
| try: | ||
| input_getter._patch_all_layers(decoder_layers=transformer_layers) | ||
| resumed_inputs = ckpt.setup_resume(transformer_layers) if ckpt and start_layer > 0 else None | ||
|
|
||
| # Bootstrap: get first layer's inputs (or use resumed inputs). | ||
| layer_inputs = input_getter.get_first_layer_inputs( | ||
| start_layer, resumed_inputs, forward_loop | ||
|
|
@@ -1822,39 +1833,43 @@ def _layer_forward_loop(m, _inputs=layer_inputs): | |
|
|
||
| is_last = layer_idx + 1 >= num_layers | ||
|
|
||
| # qdq_from_prev=False: capture before calib_func so the forward | ||
| # replay uses the original FP weights. Disable quantizers too in | ||
| # case any pre-calibration observer behavior would perturb the | ||
| # captured activations. | ||
| if not is_last and not qdq_from_prev: | ||
| with set_quantizer_by_cfg_context( | ||
| layer, [{"quantizer_name": "*", "enable": False}] | ||
| ): | ||
| next_inputs = input_getter.cache_outputs_for_next_layer_calib( | ||
| layer, forward_loop | ||
| ) | ||
| # cache_outputs left this layer in "run" mode with an empty | ||
| # deque; reset so calib_func's replay hits the real forward. | ||
| layer._layerwise_calib.mode = "original" | ||
| with persistent_materialization(layer, writeback=calib_mutates_weights): | ||
| # qdq_from_prev=False: capture before calib_func so the forward | ||
| # replay uses the original FP weights. Disable quantizers too in | ||
| # case any pre-calibration observer behavior would perturb the | ||
| # captured activations. | ||
| if not is_last and not qdq_from_prev: | ||
| with set_quantizer_by_cfg_context( | ||
| layer, [{"quantizer_name": "*", "enable": False}] | ||
| ): | ||
| next_inputs = input_getter.cache_outputs_for_next_layer_calib( | ||
| layer, forward_loop | ||
| ) | ||
| # cache_outputs left this layer in "run" mode with an empty | ||
| # deque; reset so calib_func's replay hits the real forward. | ||
| layer._layerwise_calib.mode = "original" | ||
|
|
||
| with persistent_materialization(layer): | ||
| calib_func(layer, _layer_forward_loop, **calib_kwargs) | ||
|
|
||
| # qdq_from_prev=True: capture after calib_func so the next layer | ||
| # sees QDQ error and any in-place weight updates from this layer. | ||
| if not is_last and qdq_from_prev: | ||
| next_inputs = input_getter.cache_outputs_for_next_layer_calib(layer, forward_loop) | ||
| elif is_last: | ||
| next_inputs = None | ||
| # qdq_from_prev=True: capture after calib_func so the next layer | ||
| # sees QDQ error and any in-place weight updates from this layer. | ||
| if not is_last and qdq_from_prev: | ||
| next_inputs = input_getter.cache_outputs_for_next_layer_calib( | ||
| layer, forward_loop | ||
| ) | ||
| elif is_last: | ||
| next_inputs = None | ||
|
|
||
| if ckpt: | ||
| ckpt.save(layer_idx, model, transformer_layers, next_inputs) | ||
| if ckpt: | ||
| ckpt.save(layer_idx, model, transformer_layers, next_inputs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you explain why we move |
||
|
|
||
| layer_pbar.update(1) | ||
| del layer_inputs | ||
| torch.cuda.empty_cache() | ||
| layer_inputs = next_inputs # noqa: F841 (used in next iteration's closure) | ||
| finally: | ||
| input_getter._unpatch_all_layers() | ||
| layer_pbar.close() | ||
|
|
||
| if ckpt: | ||
| ckpt.full_restore(transformer_layers, model) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
|
|
||
| from modelopt.torch.quantization.config import QuantizerCfgEntry | ||
| from modelopt.torch.utils import get_unwrapped_name, print_rank_0 | ||
| from modelopt.torch.utils.network import temporarily_remove_accelerate_hook | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Generator | ||
|
|
@@ -471,7 +472,24 @@ def _set_parameter(module: nn.Module, name: str, value: nn.Parameter): | |
|
|
||
|
|
||
| @contextmanager | ||
| def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn.Module): | ||
| def _fsdp2_unshard_context(fsdp_module: FSDPModule): | ||
| """Unshard an FSDP2 module without replacing individual DTensor parameters.""" | ||
| fsdp_param_group = fully_shard.state(fsdp_module)._fsdp_param_group | ||
| was_sharded = fsdp_param_group.is_sharded | ||
| if was_sharded: | ||
| fsdp_module.unshard() | ||
| try: | ||
| with _disable_fsdp_unshard_reshard(fsdp_module): | ||
| yield | ||
| finally: | ||
| if was_sharded: | ||
| fsdp_module.reshard() | ||
|
|
||
|
|
||
| @contextmanager | ||
| def fsdp2_weight_access_and_writeback_context( | ||
| module: nn.Module, root_model: nn.Module, writeback: bool = True | ||
| ): | ||
| """Context manager for FSDP2 weight access and writeback. | ||
|
|
||
| Gathers sharded DTensor parameters across FSDP/HSDP shards so they can be | ||
|
|
@@ -486,6 +504,11 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. | |
| assert not hasattr(module, "_hf_hook"), "We dont support FSDP2 with HF accelerate hooks" | ||
| fsdp_module = _get_enclosing_fsdp_module(module, root_model) | ||
| assert fsdp_module is not None, "Module is not wrapped by FSDP" | ||
| if not writeback: | ||
| with _fsdp2_unshard_context(fsdp_module): | ||
| yield | ||
| return | ||
|
|
||
|
Comment on lines
+507
to
+511
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sugunav14 here is an easy perf improvement for layerwise FSDP2
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @realAsma Claude claims a correctness issue on this branch, it makes sense to me, please check from your side: Bug: FSDP2
|
||
| fsdp_device_mesh = _get_fsdp2_mesh(fsdp_module) | ||
| fsdp_dim = fsdp_device_mesh.ndim | ||
|
|
||
|
|
@@ -525,7 +548,9 @@ def fsdp2_weight_access_and_writeback_context(module: nn.Module, root_model: nn. | |
|
|
||
|
|
||
| @contextmanager | ||
| def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | None = None): | ||
| def enable_weight_access_and_writeback( | ||
| module, root_model, name_to_module: dict | None = None, writeback: bool = True | ||
| ): | ||
| """Enable weight access and writeback for a module. | ||
|
|
||
| Useful for modules with weight not intact such as Linear layer in FSDP wrapped model or | ||
|
|
@@ -539,16 +564,18 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | |
| total cost when called in a loop. This causes significant CPU overhead on large | ||
| models, particularly Sparse MoE architectures where each expert is typically | ||
| implemented as its own module. | ||
| writeback: Whether modified weights must be written back to the owning sharded/offload | ||
| representation when exiting the context. | ||
| """ | ||
| if _get_enclosing_fsdp_module(module, root_model, name_to_module) is not None: | ||
| context = fsdp2_weight_access_and_writeback_context(module, root_model) | ||
| context = fsdp2_weight_access_and_writeback_context(module, root_model, writeback) | ||
| elif is_quantized_parallel_linear(module) and hasattr(module, "_hf_tp_plan"): | ||
| # HF transformers TP sharded linear layer | ||
| context = module.enable_weight_access_and_writeback() | ||
| elif hasattr(module, "_hf_hook"): | ||
| from ..plugins.accelerate import weight_access_and_writeback_context | ||
|
|
||
| context = weight_access_and_writeback_context(module) | ||
| context = weight_access_and_writeback_context(module, writeback) | ||
| else: | ||
| context = nullcontext() | ||
|
|
||
|
|
@@ -557,18 +584,23 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict | |
|
|
||
|
|
||
| @contextmanager | ||
| def persistent_materialization(layer): | ||
| def persistent_materialization(layer, writeback: bool = True): | ||
| """Keep all layer weights materialized on GPU for the duration. | ||
|
|
||
| Suppresses per-forward weight transfers so that N calibration batches | ||
| pay the cost of one load/unload instead of N. | ||
|
|
||
| - **FSDP2**: patches ``FSDPParamGroup.unshard/reshard`` to no-ops, then | ||
| gathers weights once via ``enable_weight_access_and_writeback``. | ||
| - **Accelerate**: materializes weights and sets ``hook.offload = False`` | ||
| so per-forward hooks skip materialization/offloading. | ||
| - **Accelerate**: materializes weights, sets ``hook.offload = False``, | ||
| and bypasses the layer's top-level accelerate hook while the weights are | ||
| materialized. | ||
| """ | ||
| with _disable_fsdp_unshard_reshard(layer), enable_weight_access_and_writeback(layer, layer): | ||
| with ( | ||
| _disable_fsdp_unshard_reshard(layer), | ||
| enable_weight_access_and_writeback(layer, layer, writeback=writeback), | ||
| temporarily_remove_accelerate_hook(layer), | ||
| ): | ||
| yield | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we can rename
_supports_save_quantizers_onlyto names like_calib_is_amax_onlyto align with new vocabulary, otherwise a reader has to learn that_supports_save_quantizers_only=Truemeans "amax-only algorithm, socalib_mutates_weights=Falseis allowed"