From 737a65a1ff71f442bec05976f9f887d9339183e9 Mon Sep 17 00:00:00 2001
From: adityasingh2400
Date: Thu, 21 May 2026 04:22:07 -0700
Subject: [PATCH] Point users at AOBaseConfig replacement when passing string
 quant_type to TorchAoConfig

PR #13291 removed the string-based quant_type path from TorchAoConfig in
favour of passing an AOBaseConfig subclass instance from
torchao.quantization directly. The remaining TypeError ("quant_type must
be an AOBaseConfig instance, got str") does not name a replacement, so
users running existing code, especially against torchao >= 0.16 where the
legacy lowercase factories (float8_weight_only, int8_weight_only,
float8_dynamic_activation_float8_weight, ...) were removed upstream, hit
the rename without a concrete migration path.

Map the common legacy strings to their Config-class replacements and
surface that mapping in the TypeError so the error itself tells the user
which import to add and how to instantiate it. Strings outside the
mapping still raise but point at the torchao quantization docs. The
non-string branch is unchanged.

Fixes #13286
Fixes #13266
---
 .../quantizers/quantization_config.py         | 58 +++++++++++++++++++
 tests/quantization/torchao/test_torchao.py    | 27 +++++++++
 2 files changed, 85 insertions(+)

diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index c3d829fde8cf..1a45a31cd750 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -476,8 +476,66 @@ def post_init(self):
         from torchao.quantization.quant_api import AOBaseConfig
 
         if not isinstance(self.quant_type, AOBaseConfig):
+            if isinstance(self.quant_type, str):
+                raise TypeError(self._build_string_quant_type_error(self.quant_type))
             raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
 
+    @staticmethod
+    def _build_string_quant_type_error(quant_type: str) -> str:
+        """Build a migration-guidance error for legacy string ``quant_type`` values.
+
+        Older diffusers releases accepted lowercase strings such as ``"int8_weight_only"`` or ``"float8dq_e4m3_row"``.
+        That path was removed (see PR #13291) in favour of passing an ``AOBaseConfig`` subclass instance from
+        ``torchao.quantization`` directly. Users on torchao >= 0.16 also hit this because the legacy lowercase
+        factories were removed upstream. This helper surfaces the rename so users can self-migrate.
+
+        Args:
+            quant_type (`str`):
+                The legacy string that was passed as ``quant_type``.
+
+        Returns:
+            `str`: The message for the ``TypeError`` raised by ``post_init``. It names the replacement Config class
+            when ``quant_type`` is a known legacy string and always links to the torchao quantization docs.
+        """
+        # Map common legacy strings to their torchao Config-class replacements. We deliberately
+        # do not auto-instantiate the Config; the new API exposes options (granularity, dtype,
+        # version, ...) the legacy strings hard-coded, and silent defaults would be surprising.
+        # For the same reason the hint shows ``Config(...)`` rather than ``Config()``: legacy strings
+        # that differed in dtype/granularity share one class, so a bare default is not equivalent.
+        legacy_to_config = {
+            "int4wo": "Int4WeightOnlyConfig",
+            "int4_weight_only": "Int4WeightOnlyConfig",
+            "int8wo": "Int8WeightOnlyConfig",
+            "int8_weight_only": "Int8WeightOnlyConfig",
+            "int8dq": "Int8DynamicActivationInt8WeightConfig",
+            "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
+            "float8wo": "Float8WeightOnlyConfig",
+            "float8wo_e4m3": "Float8WeightOnlyConfig",
+            "float8wo_e5m2": "Float8WeightOnlyConfig",
+            "float8_weight_only": "Float8WeightOnlyConfig",
+            "float8dq": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
+            "float8dq_e4m3_tensor": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
+        }
+        suggestion = legacy_to_config.get(quant_type)
+        message = (
+            f"TorchAoConfig no longer accepts string quant_type values (got {quant_type!r}); "
+            "pass an AOBaseConfig instance from torchao.quantization instead."
+        )
+        if suggestion is not None:
+            message += (
+                f" For {quant_type!r}, use `from torchao.quantization import {suggestion}` and pass "
+                f"`TorchAoConfig({suggestion}(...))` with its options (dtype, granularity, ...) set explicitly."
+            )
+        message += (
+            " See https://huggingface.co/docs/diffusers/main/en/quantization/torchao for "
+            "the full list of supported AOBaseConfig classes."
+        )
+        return message
+
     def to_dict(self):
         """Convert configuration to a dictionary."""
         d = super().to_dict()
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 8a811cfc1c73..842b3c9e10d7 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -112,6 +112,33 @@ def test_post_init_check(self):
         with self.assertRaises(TypeError):
             _ = TorchAoConfig(42)
 
+    def test_string_quant_type_error_includes_migration_hint(self):
+        """
+        Passing a legacy string quant_type should raise TypeError and the message should name the
+        replacement AOBaseConfig class so users on torchao >= 0.16 (where the legacy lowercase
+        factories were removed) can self-migrate. See issues #13286 and #13266.
+        """
+        legacy_to_config = {
+            "int8_weight_only": "Int8WeightOnlyConfig",
+            "int8wo": "Int8WeightOnlyConfig",
+            "float8_weight_only": "Float8WeightOnlyConfig",
+            "float8dq_e4m3_row": "Float8DynamicActivationFloat8WeightConfig",
+            "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
+        }
+        for legacy, config_name in legacy_to_config.items():
+            with self.assertRaises(TypeError) as cm:
+                TorchAoConfig(legacy)
+            message = str(cm.exception)
+            self.assertIn(repr(legacy), message)
+            self.assertIn(config_name, message)
+            # The hint must not advertise a bare default constructor as a drop-in replacement.
+            self.assertNotIn(f"{config_name}()", message)
+
+        # Strings without a known mapping should still raise TypeError and point at the docs.
+        with self.assertRaises(TypeError) as cm:
+            TorchAoConfig("not_a_real_quant_type")
+        self.assertIn("AOBaseConfig", str(cm.exception))
+
     def test_repr(self):
         """
         Check that there is no error in the repr