Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/guides/_pytorch_quantization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -365,19 +365,19 @@ Here is an example of creating custom calibration mode:
.. code-block:: python

from modelopt.torch.opt.config import ModeloptField
from modelopt.torch.quantization.config import QuantizeAlgorithmConfig
from modelopt.torch.quantization.config import CalibrationConfig
from modelopt.torch.quantization.mode import CalibrateModeRegistry, BaseCalibrateModeDescriptor
# custom configuration comprising of method name and
# any other parameters required by custom calibration function
class CustomConfig(QuantizeAlgorithmConfig):
class CustomConfig(CalibrationConfig):
method: Literal["custom_calib"] = ModeloptField("custom_calib")
...

# custom calibration mode class to register to base calibrator
@CalibrateModeRegistry.register_mode
class CustomCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
@property
def config_class(self) -> QuantizeAlgorithmConfig:
def config_class(self) -> CalibrationConfig:
"""Specifies the config class."""
return CustomConfig

Expand Down
53 changes: 43 additions & 10 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def validate_calibrator(cls, v, info: ValidationInfo):
)


class QuantizeAlgorithmConfig(ModeloptBaseConfig):
class CalibrationConfig(ModeloptBaseConfig):
"""Calibration algorithm config base."""
Comment thread
Fridah-nv marked this conversation as resolved.

method: Literal[None] = ModeloptField(
Expand Down Expand Up @@ -1043,8 +1043,41 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

include_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to include in calibration.",
description=(
"If provided, only modules whose names match at least one of the fnmatch patterns are "
"calibrated. Modules that do not match any pattern are skipped and retain their "
"pre-existing calibration state. "
"Mutually exclusive with ``exclude_modules``; specifying both raises an error. "
),
)

exclude_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to exclude from calibration.",
description=(
"If provided, modules whose names match at least one of the fnmatch patterns are "
"skipped during calibration and retain their pre-existing calibration state. "
"Mutually exclusive with ``include_modules``; specifying both raises an error. "
),
)

@model_validator(mode="after")
def _check_include_exclude_mutually_exclusive(self) -> "CalibrationConfig":
    """Reject configs that set both module-filter fields.

    ``include_modules`` and ``exclude_modules`` are alternative ways of
    narrowing which modules get calibrated, so supplying both is ambiguous
    and treated as a configuration error.
    """
    both_filters_set = (
        self.include_modules is not None and self.exclude_modules is not None
    )
    if both_filters_set:
        raise ValueError(
            "include_modules and exclude_modules are mutually exclusive; specify only one."
        )
    return self


# Backward-compatible alias — deprecated, will be removed in a future release.
QuantizeAlgorithmConfig = CalibrationConfig


class MaxCalibConfig(QuantizeAlgorithmConfig):
class MaxCalibConfig(CalibrationConfig):
"""The config for max calibration algorithm.

Max calibration estimates max values of activations or weights and use this max values
Expand All @@ -1061,7 +1094,7 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
)


class MseCalibConfig(QuantizeAlgorithmConfig):
class MseCalibConfig(CalibrationConfig):
"""Configuration for per-tensor MSE calibration.

Finds a scale s (via amax a, with s = a / q_max) that minimizes the
Expand Down Expand Up @@ -1111,7 +1144,7 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
)


class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
class LocalHessianCalibConfig(CalibrationConfig):
"""Configuration for local Hessian-weighted MSE calibration.

This algorithm uses activation information to optimize per-block scales for weight
Expand Down Expand Up @@ -1178,7 +1211,7 @@ class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
)


class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
class SmoothQuantCalibConfig(CalibrationConfig):
"""The config for ``smoothquant`` algorithm (SmoothQuant).

SmoothQuant applies a smoothing factor which balances the scale of outliers in weights and activations.
Expand All @@ -1200,7 +1233,7 @@ class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
)


class AWQLiteCalibConfig(QuantizeAlgorithmConfig):
class AWQLiteCalibConfig(CalibrationConfig):
"""The config for ``awq_lite`` (AWQ lite) algorithm.

AWQ lite applies a channel-wise scaling factor which minimizes the output difference after quantization.
Expand All @@ -1224,7 +1257,7 @@ class AWQLiteCalibConfig(QuantizeAlgorithmConfig):
)


class AWQClipCalibConfig(QuantizeAlgorithmConfig):
class AWQClipCalibConfig(CalibrationConfig):
"""The config for ``awq_clip`` (AWQ clip) algorithm.

AWQ clip searches clipped amax for per-group quantization, This search requires much more compute
Expand Down Expand Up @@ -1290,7 +1323,7 @@ class AWQFullCalibConfig(AWQLiteCalibConfig, AWQClipCalibConfig):
)


class SVDQuantConfig(QuantizeAlgorithmConfig):
class SVDQuantConfig(CalibrationConfig):
"""The config for SVDQuant.

Refer to the `SVDQuant paper <https://arxiv.org/pdf/2411.05007>`_ for more details.
Expand All @@ -1308,7 +1341,7 @@ class SVDQuantConfig(QuantizeAlgorithmConfig):
)


class GPTQLiteConfig(QuantizeAlgorithmConfig):
class GPTQLiteConfig(CalibrationConfig):
"""The config for GPTQ lite.

GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation.
Expand Down Expand Up @@ -1353,7 +1386,7 @@ class GPTQLiteConfig(QuantizeAlgorithmConfig):
| dict[str | Callable, QuantizerAttributeConfig | list[QuantizerAttributeConfig]],
]

_QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
_QuantizeAlgoCfgType = str | dict | CalibrationConfig | None

QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None

Expand Down
70 changes: 37 additions & 33 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
AWQClipCalibConfig,
AWQFullCalibConfig,
AWQLiteCalibConfig,
CalibrationConfig,
CompressConfig,
GPTQLiteConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
MseCalibConfig,
QuantizeAlgoCfgType,
QuantizeAlgorithmConfig,
QuantizeConfig,
SmoothQuantCalibConfig,
SVDQuantConfig,
Expand All @@ -59,6 +59,7 @@
)
from .model_calib import (
awq,
filter_calib_modules,
gptq_lite,
local_hessian_calibrate,
max_calibrate,
Expand Down Expand Up @@ -210,7 +211,7 @@ def name(self) -> str:

def wrapped_calib_func(
model: ModelLikeModule,
config: QuantizeAlgorithmConfig,
config: CalibrationConfig,
forward_loop: ForwardLoop | None = None,
func: Callable | None = None,
) -> ConvertReturnType:
Expand All @@ -223,6 +224,8 @@ def wrapped_calib_func(
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)
include_modules = kwargs.pop("include_modules", None)
exclude_modules = kwargs.pop("exclude_modules", None)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand All @@ -237,22 +240,23 @@ def wrapped_calib_func(
module._moe_calib_experts_ratio = moe_calib_experts_ratio

if func is not None:
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)
with filter_calib_modules(model, include_modules, exclude_modules):
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
Comment thread
Fridah-nv marked this conversation as resolved.
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)

# Lets get the latest metadata for the quantizer states
metadata = {}
Expand All @@ -264,7 +268,7 @@ class BaseCalibrateModeDescriptor(ModeDescriptor):
"""Base class for quantization calibration algorithm modes.

All calibration algorithm modes must be derived from this base class.
In addition, the `config_class` for the mode must return a subclass of :class:`QuantizeAlgorithmConfig`.
In addition, the `config_class` for the mode must return a subclass of :class:`CalibrationConfig`.

This base class also provides some convenient wrappers/utilities for calibration algorithms to be
translated into ModelOpt mode.
Expand All @@ -283,8 +287,8 @@ class BaseCalibrateModeDescriptor(ModeDescriptor):

def __init__(self, *args, **kwargs):
"""Initialize Base calibrate mode descriptor."""
assert issubclass(self.config_class, QuantizeAlgorithmConfig), (
f"`config_class` of {self.__class__} must be a subclass of `QuantizeAlgorithmConfig`!, "
assert issubclass(self.config_class, CalibrationConfig), (
f"`config_class` of {self.__class__} must be a subclass of `CalibrationConfig`!, "
f"got {self.config_class}!"
)
super().__init__(*args, **kwargs)
Expand All @@ -305,7 +309,7 @@ def name(self) -> str:

@property
@abstractmethod
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""

@property
Expand Down Expand Up @@ -380,9 +384,9 @@ class NoneCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for no calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return QuantizeAlgorithmConfig
return CalibrationConfig

_calib_func = None

Expand All @@ -392,7 +396,7 @@ class MaxCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for max calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return MaxCalibConfig

Expand All @@ -404,7 +408,7 @@ class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for mse calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return MseCalibConfig

Expand All @@ -420,7 +424,7 @@ class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
"""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return LocalHessianCalibConfig

Expand All @@ -432,7 +436,7 @@ class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for smoothquant calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return SmoothQuantCalibConfig

Expand All @@ -444,7 +448,7 @@ class AWQLiteModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ lite calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQLiteCalibConfig

Expand All @@ -456,7 +460,7 @@ class AWQClipModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ clip calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQClipCalibConfig

Expand All @@ -468,7 +472,7 @@ class AWQFullModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for AWQ full calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return AWQFullCalibConfig

Expand All @@ -480,7 +484,7 @@ class SVDQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for SVDQuant calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return SVDQuantConfig

Expand All @@ -497,7 +501,7 @@ class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for GPTQ calibration algorithm."""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
def config_class(self) -> type[CalibrationConfig]:
"""Specifies the config class for the mode."""
return GPTQLiteConfig

Expand Down
Loading
Loading