diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py index 6ec7a1729..5efaa7a0c 100644 --- a/modelopt/torch/opt/conversion.py +++ b/modelopt/torch/opt/conversion.py @@ -311,6 +311,83 @@ def update_last_state_before_save(self, model: nn.Module) -> None: last_mode.update_for_save(model, last_config, self._last_metadata) self._last_config = last_config + @staticmethod + def validate_modelopt_state(modelopt_state: Any) -> None: + """Validate that the loaded object is a valid modelopt state file. + + Args: + modelopt_state: The loaded object to validate. + + Raises: + TypeError: If the loaded object is not a dictionary or has invalid types for nested fields. + ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file. + """ + # Validate that the loaded object is a dictionary + if not isinstance(modelopt_state, dict): + raise TypeError( + f"Expected loaded modelopt state to be a dictionary, " + f"but got {type(modelopt_state).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that the dictionary has the expected keys + required_keys = {"modelopt_state_dict", "modelopt_version"} + missing_keys = required_keys - set(modelopt_state.keys()) + if missing_keys: + raise ValueError( + f"The loaded modelopt state is missing required keys: {missing_keys}. " + f"Expected keys: {required_keys}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that modelopt_version is a string + version = modelopt_state["modelopt_version"] + if not isinstance(version, str): + raise TypeError( + f"Expected 'modelopt_version' to be a string, " + f"but got {type(version).__name__}. " + f"The file may not be a valid modelopt state file." 
+ ) + + # Validate that modelopt_state_dict is a list + state_dict = modelopt_state["modelopt_state_dict"] + if not isinstance(state_dict, list): + raise TypeError( + f"Expected 'modelopt_state_dict' to be a list, " + f"but got {type(state_dict).__name__}. " + f"The file may not be a valid modelopt state file." + ) + + # Validate that each entry in the state_dict is a tuple with 2 elements + for i, entry in enumerate(state_dict): + if not isinstance(entry, tuple) or len(entry) != 2: + entry_type = type(entry).__name__ + entry_len = ( + len(entry) + if isinstance(entry, (tuple, list)) + else "N/A" + ) + msg = ( + f"Expected each entry in 'modelopt_state_dict' to be " + f"a tuple of length 2, but entry {i} is {entry_type} " + f"with length {entry_len}. The file may not be a " + f"valid modelopt state file." + ) + raise ValueError(msg) + mode_name, mode_state = entry + if not isinstance(mode_name, str): + raise TypeError( + f"Expected mode name (first element of tuple) to be a string, " + f"but got {type(mode_name).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + if not isinstance(mode_state, dict): + raise TypeError( + f"Expected mode state (second element of tuple) to be a dictionary, " + f"but got {type(mode_state).__name__} at entry {i}. " + f"The file may not be a valid modelopt state file." + ) + class ApplyModeError(RuntimeError): """Error raised when applying a mode to a model fails.""" @@ -522,12 +599,19 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic Returns: A modelopt state dictionary describing the modifications to the model. + + Raises: + TypeError: If the loaded object is not a dictionary. + ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file. 
""" # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input kwargs.setdefault("weights_only", False) kwargs.setdefault("map_location", "cpu") - # TODO: Add some validation to ensure the file is a valid modelopt state file. modelopt_state = torch.load(modelopt_state_path, **kwargs) + + # Validate the loaded modelopt state + ModeloptStateManager.validate_modelopt_state(modelopt_state) + return modelopt_state diff --git a/tests/unit/torch/opt/test_modelopt_state_validation.py b/tests/unit/torch/opt/test_modelopt_state_validation.py new file mode 100644 index 000000000..af270fa8d --- /dev/null +++ b/tests/unit/torch/opt/test_modelopt_state_validation.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestModeloptStateValidation:
    """Unit tests for modelopt state schema validation.

    Exercises ``ModeloptStateManager.validate_modelopt_state`` directly and
    through ``load_modelopt_state`` on files written with ``torch.save``.
    """

    @staticmethod
    def _save_and_load(state):
        """Write ``state`` with ``torch.save`` to a temporary file and load it back.

        The temporary directory, and the file inside it, is removed when the
        context exits, whether loading succeeds or raises.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "modelopt_state.pt")
            torch.save(state, path)
            return mto.load_modelopt_state(path)

    def test_validate_modelopt_state_valid(self):
        """Test validation of a valid modelopt state."""
        valid_state = {
            "modelopt_state_dict": [],
            "modelopt_version": "0.1.0",
        }
        # Should not raise any exception
        mto.ModeloptStateManager.validate_modelopt_state(valid_state)

    def test_validate_modelopt_state_not_dict(self):
        """Test validation fails when state is not a dictionary."""
        with pytest.raises(TypeError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state([1, 2, 3])
        assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)

    def test_validate_modelopt_state_missing_keys(self):
        """Test validation fails when required keys are missing."""
        with pytest.raises(ValueError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({"modelopt_state_dict": []})
        assert "missing required keys" in str(exc_info.value)
        assert "modelopt_version" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_version_type(self):
        """Test validation fails when modelopt_version is not a string."""
        with pytest.raises(TypeError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": [],
                "modelopt_version": 123,
            })
        assert "modelopt_version" in str(exc_info.value)
        assert "string" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_state_dict_type(self):
        """Test validation fails when modelopt_state_dict is not a list."""
        with pytest.raises(TypeError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": "not a list",
                "modelopt_version": "0.1.0",
            })
        assert "modelopt_state_dict" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_entry_not_tuple(self):
        """Test validation fails when state_dict entry is not a tuple."""
        with pytest.raises(ValueError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": [{"mode": "quantize"}],
                "modelopt_version": "0.1.0",
            })
        assert "tuple of length 2" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_entry_wrong_length(self):
        """Test validation fails when tuple has wrong length."""
        with pytest.raises(ValueError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": [("quantize",)],
                "modelopt_version": "0.1.0",
            })
        assert "tuple of length 2" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_mode_name_type(self):
        """Test validation fails when mode name is not a string."""
        with pytest.raises(TypeError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": [(123, {})],
                "modelopt_version": "0.1.0",
            })
        assert "mode name" in str(exc_info.value)
        assert "string" in str(exc_info.value)

    def test_validate_modelopt_state_invalid_mode_state_type(self):
        """Test validation fails when mode state is not a dictionary."""
        with pytest.raises(TypeError) as exc_info:
            mto.ModeloptStateManager.validate_modelopt_state({
                "modelopt_state_dict": [("quantize", "not a dict")],
                "modelopt_version": "0.1.0",
            })
        assert "mode state" in str(exc_info.value)
        assert "dictionary" in str(exc_info.value)

    def test_load_modelopt_state_valid_file(self):
        """Test loading a valid modelopt state from file."""
        valid_state = {
            "modelopt_state_dict": [],
            "modelopt_version": "0.1.0",
        }
        assert self._save_and_load(valid_state) == valid_state

    def test_load_modelopt_state_invalid_file(self):
        """Test loading an invalid modelopt state from file."""
        with pytest.raises(TypeError) as exc_info:
            self._save_and_load([1, 2, 3])
        assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)

    def test_load_modelopt_state_with_valid_entries(self):
        """Test loading modelopt state with valid mode entries."""
        valid_state = {
            "modelopt_state_dict": [
                ("quantize", {"config": {}, "metadata": {}}),
            ],
            "modelopt_version": "0.1.0",
        }
        assert self._save_and_load(valid_state) == valid_state