Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion modelopt/torch/opt/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,83 @@ def update_last_state_before_save(self, model: nn.Module) -> None:
last_mode.update_for_save(model, last_config, self._last_metadata)
self._last_config = last_config

@staticmethod
def validate_modelopt_state(modelopt_state: Any) -> None:
"""Validate that the loaded object is a valid modelopt state file.

Args:
modelopt_state: The loaded object to validate.

Raises:
TypeError: If the loaded object is not a dictionary or has invalid types for nested fields.
ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file.
"""
# Validate that the loaded object is a dictionary
if not isinstance(modelopt_state, dict):
raise TypeError(
f"Expected loaded modelopt state to be a dictionary, "
f"but got {type(modelopt_state).__name__}. "
f"The file may not be a valid modelopt state file."
)

# Validate that the dictionary has the expected keys
required_keys = {"modelopt_state_dict", "modelopt_version"}
missing_keys = required_keys - set(modelopt_state.keys())
if missing_keys:
raise ValueError(
f"The loaded modelopt state is missing required keys: {missing_keys}. "
f"Expected keys: {required_keys}. "
f"The file may not be a valid modelopt state file."
)

# Validate that modelopt_version is a string
version = modelopt_state["modelopt_version"]
if not isinstance(version, str):
raise TypeError(
f"Expected 'modelopt_version' to be a string, "
f"but got {type(version).__name__}. "
f"The file may not be a valid modelopt state file."
)

# Validate that modelopt_state_dict is a list
state_dict = modelopt_state["modelopt_state_dict"]
if not isinstance(state_dict, list):
raise TypeError(
f"Expected 'modelopt_state_dict' to be a list, "
f"but got {type(state_dict).__name__}. "
f"The file may not be a valid modelopt state file."
)

# Validate that each entry in the state_dict is a tuple with 2 elements
for i, entry in enumerate(state_dict):
if not isinstance(entry, tuple) or len(entry) != 2:
entry_type = type(entry).__name__
entry_len = (
len(entry)
if isinstance(entry, (tuple, list))
else "N/A"
)
msg = (
f"Expected each entry in 'modelopt_state_dict' to be "
f"a tuple of length 2, but entry {i} is {entry_type} "
f"with length {entry_len}. The file may not be a "
f"valid modelopt state file."
)
raise ValueError(msg)
mode_name, mode_state = entry
if not isinstance(mode_name, str):
raise TypeError(
f"Expected mode name (first element of tuple) to be a string, "
f"but got {type(mode_name).__name__} at entry {i}. "
f"The file may not be a valid modelopt state file."
)
if not isinstance(mode_state, dict):
raise TypeError(
f"Expected mode state (second element of tuple) to be a dictionary, "
f"but got {type(mode_state).__name__} at entry {i}. "
f"The file may not be a valid modelopt state file."
)


class ApplyModeError(RuntimeError):
"""Error raised when applying a mode to a model fails."""
Expand Down Expand Up @@ -522,12 +599,19 @@ def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dic

Returns:
A modelopt state dictionary describing the modifications to the model.

Raises:
TypeError: If the loaded object is not a dictionary.
ValueError: If the loaded dictionary doesn't have the expected schema for a modelopt state file.
"""
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
kwargs.setdefault("weights_only", False)
kwargs.setdefault("map_location", "cpu")
# TODO: Add some validation to ensure the file is a valid modelopt state file.
modelopt_state = torch.load(modelopt_state_path, **kwargs)

# Validate the loaded modelopt state
ModeloptStateManager.validate_modelopt_state(modelopt_state)

return modelopt_state


Expand Down
150 changes: 150 additions & 0 deletions tests/unit/torch/opt/test_modelopt_state_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import pytest
import torch

import modelopt.torch.opt as mto


class TestModeloptStateValidation:
"""Test suite for modelopt state validation."""

def test_validate_modelopt_state_valid(self):
"""Test validation of a valid modelopt state."""
valid_state = {
"modelopt_state_dict": [],
"modelopt_version": "0.1.0",
}
# Should not raise any exception
mto.ModeloptStateManager.validate_modelopt_state(valid_state)

def test_validate_modelopt_state_not_dict(self):
"""Test validation fails when state is not a dictionary."""
with pytest.raises(TypeError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state([1, 2, 3])
assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)

def test_validate_modelopt_state_missing_keys(self):
"""Test validation fails when required keys are missing."""
with pytest.raises(ValueError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({"modelopt_state_dict": []})
assert "missing required keys" in str(exc_info.value)
assert "modelopt_version" in str(exc_info.value)

def test_validate_modelopt_state_invalid_version_type(self):
"""Test validation fails when modelopt_version is not a string."""
with pytest.raises(TypeError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": [],
"modelopt_version": 123,
})
assert "modelopt_version" in str(exc_info.value)
assert "string" in str(exc_info.value)

def test_validate_modelopt_state_invalid_state_dict_type(self):
"""Test validation fails when modelopt_state_dict is not a list."""
with pytest.raises(TypeError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": "not a list",
"modelopt_version": "0.1.0",
})
assert "modelopt_state_dict" in str(exc_info.value)

def test_validate_modelopt_state_invalid_entry_not_tuple(self):
"""Test validation fails when state_dict entry is not a tuple."""
with pytest.raises(ValueError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": [{"mode": "quantize"}],
"modelopt_version": "0.1.0",
})
assert "tuple of length 2" in str(exc_info.value)

def test_validate_modelopt_state_invalid_entry_wrong_length(self):
"""Test validation fails when tuple has wrong length."""
with pytest.raises(ValueError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": [("quantize",)],
"modelopt_version": "0.1.0",
})
assert "tuple of length 2" in str(exc_info.value)

def test_validate_modelopt_state_invalid_mode_name_type(self):
"""Test validation fails when mode name is not a string."""
with pytest.raises(TypeError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": [(123, {})],
"modelopt_version": "0.1.0",
})
assert "mode name" in str(exc_info.value)
assert "string" in str(exc_info.value)

def test_validate_modelopt_state_invalid_mode_state_type(self):
"""Test validation fails when mode state is not a dictionary."""
with pytest.raises(TypeError) as exc_info:
mto.ModeloptStateManager.validate_modelopt_state({
"modelopt_state_dict": [("quantize", "not a dict")],
"modelopt_version": "0.1.0",
})
assert "mode state" in str(exc_info.value)
assert "dictionary" in str(exc_info.value)

def test_load_modelopt_state_valid_file(self):
"""Test loading a valid modelopt state from file."""
valid_state = {
"modelopt_state_dict": [],
"modelopt_version": "0.1.0",
}
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
temp_file = f.name
try:
torch.save(valid_state, temp_file)
loaded_state = mto.load_modelopt_state(temp_file)
assert loaded_state == valid_state
finally:
os.remove(temp_file)

def test_load_modelopt_state_invalid_file(self):
"""Test loading an invalid modelopt state from file."""
invalid_state = [1, 2, 3]
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
temp_file = f.name
try:
torch.save(invalid_state, temp_file)
with pytest.raises(TypeError) as exc_info:
mto.load_modelopt_state(temp_file)
assert "Expected loaded modelopt state to be a dictionary" in str(exc_info.value)
finally:
os.remove(temp_file)

def test_load_modelopt_state_with_valid_entries(self):
"""Test loading modelopt state with valid mode entries."""
valid_state = {
"modelopt_state_dict": [
("quantize", {"config": {}, "metadata": {}}),
],
"modelopt_version": "0.1.0",
}
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
temp_file = f.name
try:
torch.save(valid_state, temp_file)
loaded_state = mto.load_modelopt_state(temp_file)
assert loaded_state == valid_state
finally:
os.remove(temp_file)
Loading