From cf73491fb0046d2e8643c7ad07cfb62f722467b1 Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:46:58 -0700 Subject: [PATCH 1/5] =?UTF-8?q?fix:=20address=20issue=20#1041=20=E2=80=94?= =?UTF-8?q?=20Feature:=20Add=20validation=20for=20loaded=20modelopt=20stat?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modelopt/torch/opt/conversion.py | 55 +++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 6ec7a1729..693ab1796 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -522,12 +522,65 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic Returns: A modelopt state dictionary describing the modifications to the model. + + Raises: + TypeError: If the loaded object is not a dictionary. + ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file. """ # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input kwargs.setdefault("weights_only", False) kwargs.setdefault("map_location", "cpu") - # TODO: Add some validation to ensure the file is a valid modelopt state file. modelopt_state = torch.load(modelopt_state_path, **kwargs) + + # Validate that the loaded object is a dictionary + if not isinstance(modelopt_state, dict): + raise TypeError( + f"Expected loaded modelopt state to be a dictionary, " + f"but got {type(modelopt_state).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that the dictionary has the expected keys + required_keys = {"modelopt_state_dict", "modelopt_version"} + missing_keys = required_keys - set(modelopt_state.keys()) + if missing_keys: + raise ValueError( + f"The loaded modelopt state is missing required keys: {missing_keys}. " + f"Expected keys: {required_keys}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that modelopt_state_dict is a list + state_dict = modelopt_state["modelopt_state_dict"] + if not isinstance(state_dict, list): + raise TypeError( + f"Expected 'modelopt_state_dict' to be a list, " + f"but got {type(state_dict).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that each entry in the state_dict is a tuple with 2 elements + for i, entry in enumerate(state_dict): + if not isinstance(entry, tuple) or len(entry) != 2: + raise ValueError( + f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, " + f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. " + f"The file may not be a valid modelopt state file." + ) + mode_name, mode_state = entry + if not isinstance(mode_name, str): + raise TypeError( + f"Expected mode name (first element of tuple) to be a string, " + f"but got {type(mode_name).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + if not isinstance(mode_state, dict): + raise TypeError( + f"Expected mode state (second element of tuple) to be a dictionary, " + f"but got {type(mode_state).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + return modelopt_state From 0b0d65f4acc6b0d037aa7f7e21747f47b080cda9 Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:56:43 -0700 Subject: [PATCH 2/5] address review feedback on #1074 --- modelopt/torch/opt/conversion.py | 110 +++++++++++++++++-------------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 693ab1796..fdb9ae962 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -311,6 +311,66 @@ def update_last_state_before_save(self, model: nn.Module) -> None: last_mode.update_for_save(model, last_config, self._last_metadata) self._last_config = last_config + @staticmethod + def validate_modelopt_state(modelopt_state: Any) -> None: + """Validate that the loaded object is a valid modelopt state file. + + Args: + modelopt_state: The loaded object to validate. + + Raises: + TypeError: If the loaded object is not a dictionary or has invalid types for nested fields. + ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file. + """ + # Validate that the loaded object is a dictionary + if not isinstance(modelopt_state, dict): + raise TypeError( + f"Expected loaded modelopt state to be a dictionary, " + f"but got {type(modelopt_state).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that the dictionary has the expected keys + required_keys = {"modelopt_state_dict", "modelopt_version"} + missing_keys = required_keys - set(modelopt_state.keys()) + if missing_keys: + raise ValueError( + f"The loaded modelopt state is missing required keys: {missing_keys}. " + f"Expected keys: {required_keys}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that modelopt_state_dict is a list + state_dict = modelopt_state["modelopt_state_dict"] + if not isinstance(state_dict, list): + raise TypeError( + f"Expected 'modelopt_state_dict' to be a list, " + f"but got {type(state_dict).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that each entry in the state_dict is a tuple with 2 elements + for i, entry in enumerate(state_dict): + if not isinstance(entry, tuple) or len(entry) != 2: + raise ValueError( + f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, " + f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. " + f"The file may not be a valid modelopt state file." + ) + mode_name, mode_state = entry + if not isinstance(mode_name, str): + raise TypeError( + f"Expected mode name (first element of tuple) to be a string, " + f"but got {type(mode_name).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + if not isinstance(mode_state, dict): + raise TypeError( + f"Expected mode state (second element of tuple) to be a dictionary, " + f"but got {type(mode_state).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + class ApplyModeError(RuntimeError): """Error raised when applying a mode to a model fails.""" @@ -532,54 +592,8 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic kwargs.setdefault("map_location", "cpu") modelopt_state = torch.load(modelopt_state_path, **kwargs) - # Validate that the loaded object is a dictionary - if not isinstance(modelopt_state, dict): - raise TypeError( - f"Expected loaded modelopt state to be a dictionary, " - f"but got {type(modelopt_state).__name__}. " - f"The file may not be a valid modelopt state file." - ) - - # Validate that the dictionary has the expected keys - required_keys = {"modelopt_state_dict", "modelopt_version"} - missing_keys = required_keys - set(modelopt_state.keys()) - if missing_keys: - raise ValueError( - f"The loaded modelopt state is missing required keys: {missing_keys}. " - f"Expected keys: {required_keys}. " - f"The file may not be a valid modelopt state file." - ) - - # Validate that modelopt_state_dict is a list - state_dict = modelopt_state["modelopt_state_dict"] - if not isinstance(state_dict, list): - raise TypeError( - f"Expected 'modelopt_state_dict' to be a list, " - f"but got {type(state_dict).__name__}. " - f"The file may not be a valid modelopt state file." - ) - - # Validate that each entry in the state_dict is a tuple with 2 elements - for i, entry in enumerate(state_dict): - if not isinstance(entry, tuple) or len(entry) != 2: - raise ValueError( - f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, " - f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. " - f"The file may not be a valid modelopt state file." - ) - mode_name, mode_state = entry - if not isinstance(mode_name, str): - raise TypeError( - f"Expected mode name (first element of tuple) to be a string, " - f"but got {type(mode_name).__name__} at entry {i}. " - f"The file may not be a valid modelopt state file." - ) - if not isinstance(mode_state, dict): - raise TypeError( - f"Expected mode state (second element of tuple) to be a dictionary, " - f"but got {type(mode_state).__name__} at entry {i}. " - f"The file may not be a valid modelopt state file." - ) + # Validate the loaded modelopt state + ModeloptStateManager.validate_modelopt_state(modelopt_state) return modelopt_state From 87e6e8f1ec9b2f7917e05548876e076dd6c9b236 Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:57:32 -0700 Subject: [PATCH 3/5] address review feedback on #1074 Signed-off-by: Pensieve Bot --- .../opt/test_modelopt_state_validation.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 tests/unit/torch/opt/test_modelopt_state_validation.py diff --git a/tests/unit/torch/opt/test_modelopt_state_validation.py b/tests/unit/torch/opt/test_modelopt_state_validation.py new file mode 100644 index 000000000..1ecf65afb --- /dev/null +++ b/tests/unit/torch/opt/test_modelopt_state_validation.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import pytest +import torch + +import modelopt.torch.opt as mto + + +class TestModeloptStateValidation: + """Test suite for modelopt state validation.""" + + def test_validate_modelopt_state_valid(self): + """Test validation of a valid modelopt state.""" + valid_state = { + "modelopt_state_dict": [], + "modelopt_version": "0.1.0", + } + # Should not raise any exception + mto.ModeloptStateManager.validate_modelopt_state(valid_state) + + def test_validate_modelopt_state_not_dict(self): + """Test validation fails when state is not a dictionary.""" + with pytest.raises(TypeError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state([1, 2, 3]) + assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value) + + def test_validate_modelopt_state_missing_keys(self): + """Test validation fails when required keys are missing.""" + with pytest.raises(ValueError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({"modelopt_state_dict": []}) + assert "missing required keys" in str(exc_info.value) + assert "modelopt_version" in str(exc_info.value) + + def test_validate_modelopt_state_invalid_state_dict_type(self): + """Test validation fails when modelopt_state_dict is not a list.""" + with pytest.raises(TypeError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": "not a list", + "modelopt_version": "0.1.0", + }) + assert "modelopt_state_dict" in str(exc_info.value) + + def test_validate_modelopt_state_invalid_entry_not_tuple(self): + """Test validation fails when state_dict entry is not a tuple.""" + with pytest.raises(ValueError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": [{"mode": "quantize"}], + "modelopt_version": "0.1.0", + }) + assert "tuple of length 2" in str(exc_info.value) + + def test_validate_modelopt_state_invalid_entry_wrong_length(self): + """Test validation fails when tuple has wrong length.""" + with pytest.raises(ValueError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": [("quantize",)], + "modelopt_version": "0.1.0", + }) + assert "tuple of length 2" in str(exc_info.value) + + def test_validate_modelopt_state_invalid_mode_name_type(self): + """Test validation fails when mode name is not a string.""" + with pytest.raises(TypeError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": [(123, {})], + "modelopt_version": "0.1.0", + }) + assert "mode name" in str(exc_info.value) + assert "string" in str(exc_info.value) + + def test_validate_modelopt_state_invalid_mode_state_type(self): + """Test validation fails when mode state is not a dictionary.""" + with pytest.raises(TypeError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": [("quantize", "not a dict")], + "modelopt_version": "0.1.0", + }) + assert "mode state" in str(exc_info.value) + assert "dictionary" in str(exc_info.value) + + def test_load_modelopt_state_valid_file(self): + """Test loading a valid modelopt state from file.""" + valid_state = { + "modelopt_state_dict": [], + "modelopt_version": "0.1.0", + } + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + temp_file = f.name + try: + torch.save(valid_state, temp_file) + loaded_state = mto.load_modelopt_state(temp_file) + assert loaded_state == valid_state + finally: + os.remove(temp_file) + + def test_load_modelopt_state_invalid_file(self): + """Test loading an invalid modelopt state from file.""" + invalid_state = [1, 2, 3] + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + temp_file = f.name + try: + torch.save(invalid_state, temp_file) + with pytest.raises(TypeError) as exc_info: + mto.load_modelopt_state(temp_file) + assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value) + finally: + os.remove(temp_file) + + def test_load_modelopt_state_with_valid_entries(self): + """Test loading modelopt state with valid mode entries.""" + valid_state = { + "modelopt_state_dict": [ + ("quantize", {"config": {}, "metadata": {}}), + ], + "modelopt_version": "0.1.0", + } + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + temp_file = f.name + try: + torch.save(valid_state, temp_file) + loaded_state = mto.load_modelopt_state(temp_file) + assert loaded_state == valid_state + finally: + os.remove(temp_file) From e13187442afe07b7cf68d05623758832cb2422f7 Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 19 Mar 2026 18:04:56 -0700 Subject: [PATCH 4/5] address review feedback on #1074 Signed-off-by: Pensieve Bot --- modelopt/torch/opt/conversion.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index fdb9ae962..5efaa7a0c 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -340,6 +340,15 @@ def validate_modelopt_state(modelopt_state: Any) -> None: f"The file may not be a valid modelopt state file." ) + # Validate that modelopt_version is a string + version = modelopt_state["modelopt_version"] + if not isinstance(version, str): + raise TypeError( + f"Expected 'modelopt_version' to be a string, " + f"but got {type(version).__name__}. " + f"The file may not be a valid modelopt state file." + ) + # Validate that modelopt_state_dict is a list state_dict = modelopt_state["modelopt_state_dict"] if not isinstance(state_dict, list): @@ -352,11 +361,19 @@ def validate_modelopt_state(modelopt_state: Any) -> None: # Validate that each entry in the state_dict is a tuple with 2 elements for i, entry in enumerate(state_dict): if not isinstance(entry, tuple) or len(entry) != 2: - raise ValueError( - f"Expected each entry in 'modelopt_state_dict' to be a tuple of length 2, " - f"but entry {i} is {type(entry).__name__} with length {len(entry) if isinstance(entry, (tuple, list)) else 'N/A'}. " - f"The file may not be a valid modelopt state file." + entry_type = type(entry).__name__ + entry_len = ( + len(entry) + if isinstance(entry, (tuple, list)) + else "N/A" + ) + msg = ( + f"Expected each entry in 'modelopt_state_dict' to be " + f"a tuple of length 2, but entry {i} is {entry_type} " + f"with length {entry_len}. The file may not be a " + f"valid modelopt state file." ) + raise ValueError(msg) mode_name, mode_state = entry if not isinstance(mode_name, str): raise TypeError( From 7289799c4f9313546be18671af1acb99bcc8ee3f Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 19 Mar 2026 18:04:57 -0700 Subject: [PATCH 5/5] address review feedback on #1074 Signed-off-by: Pensieve Bot --- tests/unit/torch/opt/test_modelopt_state_validation.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit/torch/opt/test_modelopt_state_validation.py b/tests/unit/torch/opt/test_modelopt_state_validation.py index 1ecf65afb..af270fa8d 100644 --- a/tests/unit/torch/opt/test_modelopt_state_validation.py +++ b/tests/unit/torch/opt/test_modelopt_state_validation.py @@ -47,6 +47,16 @@ def test_validate_modelopt_state_missing_keys(self): assert "missing required keys" in str(exc_info.value) assert "modelopt_version" in str(exc_info.value) + def test_validate_modelopt_state_invalid_version_type(self): + """Test validation fails when modelopt_version is not a string.""" + with pytest.raises(TypeError) as exc_info: + mto.ModeloptStateManager.validate_modelopt_state({ + "modelopt_state_dict": [], + "modelopt_version": 123, + }) + assert "modelopt_version" in str(exc_info.value) + assert "string" in str(exc_info.value) + def test_validate_modelopt_state_invalid_state_dict_type(self): """Test validation fails when modelopt_state_dict is not a list.""" with pytest.raises(TypeError) as exc_info: