Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ All notable changes to this project will be documented in this file.

- Weight upload support for yolo26-sem semantic segmentation models via
`version.deploy()` and `workspace.deploy_model()`
- `ROBOFLOW_DISABLE_CLASS_SORTING` environment variable to preserve custom
class ordering during YOLO model deployment (opt-in, defaults to false)

## 1.3.9

Expand Down
1 change: 1 addition & 0 deletions roboflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def get_conditional_configuration_variable(key, default):

RF_WORKSPACES = get_conditional_configuration_variable("workspaces", default={})
TQDM_DISABLE = os.getenv("TQDM_DISABLE", None)
DISABLE_CLASS_SORTING = os.getenv("ROBOFLOW_DISABLE_CLASS_SORTING", "false").lower() == "true"


def load_roboflow_api_key(workspace_url=None):
Expand Down
7 changes: 6 additions & 1 deletion roboflow/util/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import yaml

from roboflow.config import (
DISABLE_CLASS_SORTING,
TASK_CLS,
TASK_DET,
TASK_OBB,
Expand Down Expand Up @@ -224,7 +225,11 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> tuple[str,
class_names = []
for i, val in enumerate(model_instance.names):
class_names.append((val, model_instance.names[val]))
class_names.sort(key=lambda x: x[0])
# NOTE: When DISABLE_CLASS_SORTING is enabled, users are responsible for ensuring
# their model's names dict has properly ordered/sequential keys. Non-sequential keys
# may result in incorrect class-to-index mappings.
if not DISABLE_CLASS_SORTING:
class_names.sort(key=lambda x: x[0])
class_names = [x[1] for x in class_names]

if (
Expand Down
64 changes: 64 additions & 0 deletions tests/util/test_class_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import importlib.util
import os
import unittest
from unittest.mock import patch

# Get the path to config.py directly
config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "roboflow", "config.py")


class TestClassSorting(unittest.TestCase):
"""Test DISABLE_CLASS_SORTING configuration for class ordering."""

def test_disable_class_sorting_default_value(self):
"""Test that DISABLE_CLASS_SORTING defaults to False."""
# Load config module directly without triggering __init__.py
spec = importlib.util.spec_from_file_location("config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

self.assertFalse(config.DISABLE_CLASS_SORTING)

def test_disable_class_sorting_env_var_true(self):
"""Test that ROBOFLOW_DISABLE_CLASS_SORTING=true sets config to True."""
with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": "true"}):
# Reload config to pick up env var
spec = importlib.util.spec_from_file_location("config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

# After reload, should be True
self.assertTrue(config.DISABLE_CLASS_SORTING)

def test_disable_class_sorting_env_var_false(self):
"""Test that ROBOFLOW_DISABLE_CLASS_SORTING=false keeps config as False."""
with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": "false"}):
spec = importlib.util.spec_from_file_location("config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

self.assertFalse(config.DISABLE_CLASS_SORTING)

def test_disable_class_sorting_env_var_case_insensitive(self):
"""Test that env var is case-insensitive."""
for value in ["True", "TRUE", "tRuE"]:
with patch.dict(os.environ, {"ROBOFLOW_DISABLE_CLASS_SORTING": value}):
spec = importlib.util.spec_from_file_location("config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

self.assertTrue(config.DISABLE_CLASS_SORTING)

def test_config_import(self):
"""Test that DISABLE_CLASS_SORTING exists and is a boolean."""
spec = importlib.util.spec_from_file_location("config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

config_value = config.DISABLE_CLASS_SORTING
self.assertIsNotNone(config_value)
self.assertIsInstance(config_value, bool)


if __name__ == "__main__":
unittest.main()