diff --git a/CHANGELOG.md b/CHANGELOG.md index b1c1fd19..35ff89f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ All notable changes to this project will be documented in this file. - Weight upload support for yolo26-sem semantic segmentation models via `version.deploy()` and `workspace.deploy_model()` +- `ROBOFLOW_DISABLE_CLASS_SORTING` environment variable to preserve custom + class ordering during YOLO model deployment (opt-in, defaults to false) ## 1.3.9 diff --git a/roboflow/config.py b/roboflow/config.py index 0d569e95..e975d38c 100644 --- a/roboflow/config.py +++ b/roboflow/config.py @@ -86,6 +86,7 @@ def get_conditional_configuration_variable(key, default): RF_WORKSPACES = get_conditional_configuration_variable("workspaces", default={}) TQDM_DISABLE = os.getenv("TQDM_DISABLE", None) +DISABLE_CLASS_SORTING = os.getenv("ROBOFLOW_DISABLE_CLASS_SORTING", "false").lower() == "true" def load_roboflow_api_key(workspace_url=None): diff --git a/roboflow/util/model_processor.py b/roboflow/util/model_processor.py index 0022c11a..8adcb5f4 100644 --- a/roboflow/util/model_processor.py +++ b/roboflow/util/model_processor.py @@ -7,6 +7,7 @@ import yaml from roboflow.config import ( + DISABLE_CLASS_SORTING, TASK_CLS, TASK_DET, TASK_OBB, @@ -224,7 +225,11 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> tuple[str, class_names = [] for i, val in enumerate(model_instance.names): class_names.append((val, model_instance.names[val])) - class_names.sort(key=lambda x: x[0]) + # NOTE: When DISABLE_CLASS_SORTING is enabled, users are responsible for ensuring + # their model's names dict has properly ordered/sequential keys. Non-sequential keys + # may result in incorrect class-to-index mappings. + if not DISABLE_CLASS_SORTING: + class_names.sort(key=lambda x: x[0]) class_names = [x[1] for x in class_names] if ( diff --git a/tests/util/test_class_sorting.py b/tests/util/test_class_sorting.py new file mode 100644 index 00000000..ad4e930f --- /dev/null +++ b/tests/util/test_class_sorting.py @@ -0,0 +1,64 @@ +import importlib.util +import os +import unittest +from unittest.mock import patch + +# Get the path to config.py directly +config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "roboflow", "config.py") + + +class TestClassSorting(unittest.TestCase): + """Test DISABLE_CLASS_SORTING configuration for class ordering.""" + + def test_disable_class_sorting_default_value(self): + """Test that DISABLE_CLASS_SORTING defaults to False.""" + # Load config module directly without triggering __init__.py + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + self.assertFalse(config.DISABLE_CLASS_SORTING) + + def test_disable_class_sorting_env_var_true(self): + """Test that ROBOFLOW_DISABLE_CLASS_SORTING=true sets config to True.""" + with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": "true"}): + # Reload config to pick up env var + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + # After reload, should be True + self.assertTrue(config.DISABLE_CLASS_SORTING) + + def test_disable_class_sorting_env_var_false(self): + """Test that ROBOFLOW_DISABLE_CLASS_SORTING=false keeps config as False.""" + with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": "false"}): + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + self.assertFalse(config.DISABLE_CLASS_SORTING) + + def test_disable_class_sorting_env_var_case_insensitive(self): + """Test that env var is case-insensitive.""" + for value in ["True", "TRUE", "tRuE"]: + with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": value}): + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + self.assertTrue(config.DISABLE_CLASS_SORTING) + + def test_config_import(self): + """Test that DISABLE_CLASS_SORTING exists and is a boolean.""" + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + + config_value = config.DISABLE_CLASS_SORTING + self.assertIsNotNone(config_value) + self.assertIsInstance(config_value, bool) + + +if __name__ == "__main__": + unittest.main()