From d2bb3326f8f19da70acf9d38a4c47b7be21f693f Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Tue, 10 Mar 2026 11:31:14 +0000 Subject: [PATCH 1/8] make docker registry independent of deployed environment --- .../model_engine_server/api/dependencies.py | 28 +++- .../model_engine_server/core/config.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 4 - .../entrypoints/k8s_cache.py | 17 ++- .../infra/repositories/__init__.py | 2 + .../repositories/generic_docker_repository.py | 99 ++++++++++++ .../test_generic_docker_repository.py | 141 ++++++++++++++++++ 7 files changed, 277 insertions(+), 15 deletions(-) create mode 100644 model-engine/model_engine_server/infra/repositories/generic_docker_repository.py create mode 100644 model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index a2865e78a..19b5c1ad0 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -122,6 +122,7 @@ ECRDockerRepository, FakeDockerRepository, GARDockerRepository, + GenericDockerRepository, GCSFileLLMFineTuneEventsRepository, GCSFileLLMFineTuneRepository, LiveTokenizerRepository, @@ -146,6 +147,18 @@ logger = make_logger(logger_name()) + +def _infer_registry_type(prefix: str) -> str: + """Infer docker registry type from docker_repo_prefix.""" + if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix: + return "ecr" + if prefix.endswith(".azurecr.io"): + return "acr" + if "-docker.pkg.dev" in prefix: + return "gar" + return "generic" + + basic_auth = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) @@ -383,16 +396,21 @@ def _get_external_interfaces( file_storage_gateway = S3FileStorageGateway() docker_repository: DockerRepository + registry_type = infra_config().docker_registry_type or _infer_registry_type( + infra_config().docker_repo_prefix + ) if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().cloud_provider == "onprem": - docker_repository = OnPremDockerRepository() - elif infra_config().cloud_provider == "azure": + elif registry_type == "ecr": + docker_repository = ECRDockerRepository() + elif registry_type == "acr": docker_repository = ACRDockerRepository() - elif infra_config().cloud_provider == "gcp": + elif registry_type == "gar": docker_repository = GARDockerRepository() + elif registry_type == "onprem": + docker_repository = OnPremDockerRepository() else: - docker_repository = ECRDockerRepository() + docker_repository = GenericDockerRepository() tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index e721d3c5a..c689d8890 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -51,6 +51,7 @@ class _InfraConfig: prometheus_server_address: Optional[str] = None celery_broker_type_redis: Optional[bool] = None celery_enable_sha256: Optional[bool] = None + docker_registry_type: Optional[str] = None debug_mode: Optional[bool] = None diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 849d93566..fe0631d7f 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -387,10 +387,6 @@ def __init__( def check_docker_image_exists_for_image_tag( self, framework_image_tag: str, repository_name: str ): - # Skip ECR validation for on-prem deployments - images are in local registry - if infra_config().cloud_provider == "onprem": - return - if not self.docker_repository.image_exists( image_tag=framework_image_tag, repository_name=repository_name, diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index f39d304b8..f83dc81c8 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -40,10 +40,12 @@ from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( SQSQueueEndpointResourceDelegate, ) +from model_engine_server.api.dependencies import _infer_registry_type from model_engine_server.infra.repositories import ( ACRDockerRepository, ECRDockerRepository, FakeDockerRepository, + GenericDockerRepository, ) from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, @@ -128,16 +130,19 @@ async def main(args: Any): ) image_cache_gateway = ImageCacheGateway() docker_repo: DockerRepository + registry_type = infra_config().docker_registry_type or _infer_registry_type( + infra_config().docker_repo_prefix + ) if CIRCLECI: docker_repo = FakeDockerRepository() - elif infra_config().cloud_provider == "onprem": - docker_repo = OnPremDockerRepository() - elif infra_config().cloud_provider == "azure" or infra_config().docker_repo_prefix.endswith( - "azurecr.io" - ): + elif registry_type == "ecr": + docker_repo = ECRDockerRepository() + elif registry_type == "acr": docker_repo = ACRDockerRepository() + elif registry_type == "onprem": + docker_repo = OnPremDockerRepository() else: - docker_repo = ECRDockerRepository() + docker_repo = GenericDockerRepository() while True: loop_start = time.time() await loop_iteration( diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index 5a34f7ec5..9682146cb 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -13,6 +13,7 @@ from .fake_docker_repository import FakeDockerRepository from .feature_flag_repository import FeatureFlagRepository from .gar_docker_repository import GARDockerRepository +from .generic_docker_repository import GenericDockerRepository from .gcs_file_llm_fine_tune_events_repository import GCSFileLLMFineTuneEventsRepository from .gcs_file_llm_fine_tune_repository import GCSFileLLMFineTuneRepository from .live_tokenizer_repository import LiveTokenizerRepository @@ -38,6 +39,7 @@ "ECRDockerRepository", "FakeDockerRepository", "GARDockerRepository", + "GenericDockerRepository", "FeatureFlagRepository", "GCSFileLLMFineTuneEventsRepository", "GCSFileLLMFineTuneRepository", diff --git a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py new file mode 100644 index 000000000..97c758b91 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py @@ -0,0 +1,99 @@ +import re +from typing import Optional +from urllib.parse import urlencode + +import requests +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + +_REQUEST_TIMEOUT = 10 + + +def _parse_www_authenticate(header: str) -> Optional[dict]: + """Parse a Www-Authenticate Bearer header into realm, service, and scope.""" + match = re.match(r'Bearer\s+(.*)', header, re.IGNORECASE) + if not match: + return None + params = {} + for m in re.finditer(r'(\w+)="([^"]*)"', match.group(1)): + params[m.group(1)] = m.group(2) + return params if "realm" in params else None + + +def _get_token(realm: str, service: Optional[str], scope: Optional[str]) -> Optional[str]: + """Fetch a bearer token from the registry's token endpoint.""" + query = {} + if service: + query["service"] = service + if scope: + query["scope"] = scope + url = f"{realm}?{urlencode(query)}" if query else realm + try: + resp = requests.get(url, timeout=_REQUEST_TIMEOUT) + if resp.status_code == 200: + data = resp.json() + return data.get("token") or data.get("access_token") + except (requests.RequestException, ValueError): + pass + return None + + +class GenericDockerRepository(DockerRepository): + """Registry-agnostic Docker repository using the OCI Distribution / Docker Registry V2 HTTP API.""" + + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + prefix = infra_config().docker_repo_prefix + registry = prefix.rstrip("/") + manifest_url = f"https://{registry}/v2/{repository_name}/manifests/{image_tag}" + headers = { + "Accept": ", ".join([ + "application/vnd.docker.distribution.manifest.v2+json", + "application/vnd.oci.image.manifest.v1+json", + "application/vnd.docker.distribution.manifest.list.v2+json", + "application/vnd.oci.image.index.v1+json", + ]) + } + + try: + resp = requests.head(manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT) + + if resp.status_code == 200: + return True + + if resp.status_code == 401: + www_auth = resp.headers.get("Www-Authenticate", "") + auth_params = _parse_www_authenticate(www_auth) + if auth_params: + token = _get_token( + realm=auth_params["realm"], + service=auth_params.get("service"), + scope=auth_params.get("scope"), + ) + if token: + headers["Authorization"] = f"Bearer {token}" + resp = requests.head( + manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT + ) + return resp.status_code == 200 + + return False + except requests.RequestException as e: + logger.warning(f"Failed to check image existence at {manifest_url}: {e}") + return False + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + if self.is_repo_name(repository_name): + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + return f"{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("GenericDockerRepository does not support building images") + + def get_latest_image_tag(self, repository_name: str) -> str: + raise NotImplementedError("GenericDockerRepository does not support querying latest tags") diff --git a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py new file mode 100644 index 000000000..b6a51d0b6 --- /dev/null +++ b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py @@ -0,0 +1,141 @@ +from unittest import mock + +import pytest +import requests +from model_engine_server.infra.repositories.generic_docker_repository import ( + GenericDockerRepository, + _parse_www_authenticate, +) + + +@pytest.fixture +def generic_docker_repo(): + return GenericDockerRepository() + + +@pytest.fixture +def mock_infra_config(): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.infra_config" + ) as mock_config: + mock_config.return_value.docker_repo_prefix = "public.ecr.aws/b2z8n5q1" + yield mock_config + + +class TestParseWwwAuthenticate: + def test_parses_bearer_header(self): + header = 'Bearer realm="https://auth.example.com/token",service="registry.example.com",scope="repository:myrepo:pull"' + result = _parse_www_authenticate(header) + assert result == { + "realm": "https://auth.example.com/token", + "service": "registry.example.com", + "scope": "repository:myrepo:pull", + } + + def test_returns_none_for_basic_auth(self): + assert _parse_www_authenticate("Basic realm=\"registry\"") is None + + def test_returns_none_for_missing_realm(self): + assert _parse_www_authenticate('Bearer service="foo"') is None + + def test_returns_none_for_empty_string(self): + assert _parse_www_authenticate("") is None + + +class TestImageExists: + def test_returns_true_on_200(self, generic_docker_repo, mock_infra_config): + with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + mock_resp = mock.Mock() + mock_resp.status_code = 200 + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("v0.4.0", "model-engine/vllm") + + assert result is True + mock_requests.head.assert_called_once() + call_url = mock_requests.head.call_args[0][0] + assert call_url == "https://public.ecr.aws/b2z8n5q1/v2/model-engine/vllm/manifests/v0.4.0" + + def test_returns_false_on_404(self, generic_docker_repo, mock_infra_config): + with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + mock_resp = mock.Mock() + mock_resp.status_code = 404 + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("nonexistent", "vllm") + + assert result is False + + def test_returns_false_on_connection_error(self, generic_docker_repo, mock_infra_config): + with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + mock_requests.head.side_effect = requests.ConnectionError("unreachable") + mock_requests.ConnectionError = requests.ConnectionError + mock_requests.RequestException = requests.RequestException + + result = generic_docker_repo.image_exists("v1.0", "vllm") + + assert result is False + + def test_token_auth_on_401(self, generic_docker_repo, mock_infra_config): + with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + mock_requests.RequestException = requests.RequestException + + # First HEAD returns 401 with Www-Authenticate + unauthed_resp = mock.Mock() + unauthed_resp.status_code = 401 + unauthed_resp.headers = { + "Www-Authenticate": 'Bearer realm="https://public.ecr.aws/token",service="public.ecr.aws",scope="repository:b2z8n5q1/vllm:pull"' + } + + # Second HEAD (with token) returns 200 + authed_resp = mock.Mock() + authed_resp.status_code = 200 + + mock_requests.head.side_effect = [unauthed_resp, authed_resp] + + # Token endpoint returns a token + token_resp = mock.Mock() + token_resp.status_code = 200 + token_resp.json.return_value = {"token": "test-token-123"} + mock_requests.get.return_value = token_resp + + result = generic_docker_repo.image_exists("v0.4.0", "vllm") + + assert result is True + assert mock_requests.head.call_count == 2 + # Verify the second HEAD had the Authorization header + second_call_headers = mock_requests.head.call_args_list[1][1]["headers"] + assert second_call_headers["Authorization"] == "Bearer test-token-123" + + def test_returns_false_on_401_without_www_authenticate(self, generic_docker_repo, mock_infra_config): + with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + mock_requests.RequestException = requests.RequestException + + mock_resp = mock.Mock() + mock_resp.status_code = 401 + mock_resp.headers = {} + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("v1.0", "vllm") + + assert result is False + + +class TestGetImageUrl: + def test_prepends_prefix_for_simple_repo_name(self, generic_docker_repo, mock_infra_config): + result = generic_docker_repo.get_image_url("v1.0", "vllm") + assert result == "public.ecr.aws/b2z8n5q1/vllm:v1.0" + + def test_no_prefix_for_full_url(self, generic_docker_repo, mock_infra_config): + result = generic_docker_repo.get_image_url("v1.0", "docker.io/library/nginx") + assert result == "docker.io/library/nginx:v1.0" + + +class TestNotImplemented: + def test_build_image_raises(self, generic_docker_repo): + with pytest.raises(NotImplementedError): + generic_docker_repo.build_image(None) + + def test_get_latest_image_tag_raises(self, generic_docker_repo): + with pytest.raises(NotImplementedError): + generic_docker_repo.get_latest_image_tag("vllm") From 3f384bf363e6e4dd86fcfbdf084ec7ee5ba85e40 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Tue, 10 Mar 2026 11:57:55 +0000 Subject: [PATCH 2/8] address greptile comments --- .../model_engine_server/api/dependencies.py | 15 ++------------- model-engine/model_engine_server/core/config.py | 11 +++++++++++ .../model_engine_server/entrypoints/k8s_cache.py | 7 +++++-- .../repositories/generic_docker_repository.py | 9 ++++++--- .../test_generic_docker_repository.py | 2 +- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 19b5c1ad0..181bb4567 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -14,7 +14,7 @@ from model_engine_server.core.auth.fake_authentication_repository import ( FakeAuthenticationRepository, ) -from model_engine_server.core.config import infra_config +from model_engine_server.core.config import infer_registry_type, infra_config from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, @@ -148,17 +148,6 @@ logger = make_logger(logger_name()) -def _infer_registry_type(prefix: str) -> str: - """Infer docker registry type from docker_repo_prefix.""" - if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix: - return "ecr" - if prefix.endswith(".azurecr.io"): - return "acr" - if "-docker.pkg.dev" in prefix: - return "gar" - return "generic" - - basic_auth = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) @@ -396,7 +385,7 @@ def _get_external_interfaces( file_storage_gateway = S3FileStorageGateway() docker_repository: DockerRepository - registry_type = infra_config().docker_registry_type or _infer_registry_type( + registry_type = infra_config().docker_registry_type or infer_registry_type( infra_config().docker_repo_prefix ) if CIRCLECI: diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index c689d8890..d499bc85d 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -110,6 +110,17 @@ def use_config_context(config_path: str): _infra_config = InfraConfig.from_yaml(config_path) +def infer_registry_type(prefix: str) -> str: + """Infer docker registry type from docker_repo_prefix.""" + if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix: + return "ecr" + if prefix.endswith(".azurecr.io"): + return "acr" + if "-docker.pkg.dev" in prefix: + return "gar" + return "generic" + + def get_config_path_for_env_name(env_name: str) -> Path: path = DEFAULT_CONFIG_PATH.parent / f"{env_name}.yaml" if not path.exists(): diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index f83dc81c8..285102b2f 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -40,11 +40,12 @@ from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( SQSQueueEndpointResourceDelegate, ) -from model_engine_server.api.dependencies import _infer_registry_type +from model_engine_server.core.config import infer_registry_type from model_engine_server.infra.repositories import ( ACRDockerRepository, ECRDockerRepository, FakeDockerRepository, + GARDockerRepository, GenericDockerRepository, ) from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( @@ -130,7 +131,7 @@ async def main(args: Any): ) image_cache_gateway = ImageCacheGateway() docker_repo: DockerRepository - registry_type = infra_config().docker_registry_type or _infer_registry_type( + registry_type = infra_config().docker_registry_type or infer_registry_type( infra_config().docker_repo_prefix ) if CIRCLECI: @@ -139,6 +140,8 @@ async def main(args: Any): docker_repo = ECRDockerRepository() elif registry_type == "acr": docker_repo = ACRDockerRepository() + elif registry_type == "gar": + docker_repo = GARDockerRepository() elif registry_type == "onprem": docker_repo = OnPremDockerRepository() else: diff --git a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py index 97c758b91..9a48259de 100644 --- a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py @@ -48,9 +48,12 @@ class GenericDockerRepository(DockerRepository): def image_exists( self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None ) -> bool: - prefix = infra_config().docker_repo_prefix - registry = prefix.rstrip("/") - manifest_url = f"https://{registry}/v2/{repository_name}/manifests/{image_tag}" + prefix = infra_config().docker_repo_prefix.rstrip("/") + parts = prefix.split("/", 1) + registry_host = parts[0] + path_prefix = parts[1] if len(parts) > 1 else "" + full_repo = f"{path_prefix}/{repository_name}" if path_prefix else repository_name + manifest_url = f"https://{registry_host}/v2/{full_repo}/manifests/{image_tag}" headers = { "Accept": ", ".join([ "application/vnd.docker.distribution.manifest.v2+json", diff --git a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py index b6a51d0b6..94c45f0c5 100644 --- a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py @@ -54,7 +54,7 @@ def test_returns_true_on_200(self, generic_docker_repo, mock_infra_config): assert result is True mock_requests.head.assert_called_once() call_url = mock_requests.head.call_args[0][0] - assert call_url == "https://public.ecr.aws/b2z8n5q1/v2/model-engine/vllm/manifests/v0.4.0" + assert call_url == "https://public.ecr.aws/v2/b2z8n5q1/model-engine/vllm/manifests/v0.4.0" def test_returns_false_on_404(self, generic_docker_repo, mock_infra_config): with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: From 3a18cfa0d5bf2584db3358bf79cb9430530b36cf Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Tue, 10 Mar 2026 12:06:46 +0000 Subject: [PATCH 3/8] add infer_registry_type to module's __all__ --- model-engine/model_engine_server/core/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index d499bc85d..84b4fda53 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -23,6 +23,7 @@ "CONFIG_PATH", "config_context", "get_config_path_for_env_name", + "infer_registry_type", "infra_config", "use_config_context", ) From 6600f8d3f93859ae9f554e057ff0e19b33b976f0 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Thu, 12 Mar 2026 11:25:42 +0000 Subject: [PATCH 4/8] address greptile --- model-engine/model_engine_server/core/config.py | 2 +- .../infra/repositories/generic_docker_repository.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 84b4fda53..6886174ff 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -115,7 +115,7 @@ def infer_registry_type(prefix: str) -> str: """Infer docker registry type from docker_repo_prefix.""" if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix: return "ecr" - if prefix.endswith(".azurecr.io"): + if ".azurecr.io" in prefix: return "acr" if "-docker.pkg.dev" in prefix: return "gar" diff --git a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py index 9a48259de..9c9aad552 100644 --- a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py @@ -31,7 +31,8 @@ def _get_token(realm: str, service: Optional[str], scope: Optional[str]) -> Opti query["service"] = service if scope: query["scope"] = scope - url = f"{realm}?{urlencode(query)}" if query else realm + separator = "&" if "?" in realm else "?" + url = f"{realm}{separator}{urlencode(query)}" if query else realm try: resp = requests.get(url, timeout=_REQUEST_TIMEOUT) if resp.status_code == 200: From 012cae6b9bf2fc53ca9b9b09402ca61c5d7cade5 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Thu, 12 Mar 2026 13:17:35 +0000 Subject: [PATCH 5/8] run black --- .../common/aiohttp_sse_client.py | 1 + .../core/celery/celery_autoscaler.py | 2 +- .../core/docker/docker_image.py | 8 ++--- .../repositories/generic_docker_repository.py | 16 +++++----- .../live_batch_job_orchestration_service.py | 2 +- model-engine/tests/unit/api/test_llms.py | 6 ++-- .../tests/unit/domain/test_llm_use_cases.py | 12 +++----- .../test_generic_docker_repository.py | 30 ++++++++++++++----- 8 files changed, 42 insertions(+), 35 deletions(-) diff --git a/model-engine/model_engine_server/common/aiohttp_sse_client.py b/model-engine/model_engine_server/common/aiohttp_sse_client.py index 20e71f3d4..ddec98747 100644 --- a/model-engine/model_engine_server/common/aiohttp_sse_client.py +++ b/model-engine/model_engine_server/common/aiohttp_sse_client.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- """Main module.""" + # import asyncio # import logging from datetime import timedelta diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py index 923244a66..679908616 100644 --- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -381,7 +381,7 @@ async def _get_connection_count(self): connection_count = info.get("connected_clients") max_connections = info.get("maxclients") else: - (info, config) = await aio.gather( + info, config = await aio.gather( redis_client.info(), redis_client.config_get("maxclients"), ) diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index 8d68f8c8b..70f5f0e56 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -144,14 +144,10 @@ def build( ) if test_command: - logger.info( - textwrap.dedent( - f""" + logger.info(textwrap.dedent(f""" Testing with 'docker run' on the built image. ARGS: {test_command} - (NOTE: Expecting the test command to terminate. """ - ) - ) + (NOTE: Expecting the test command to terminate. """)) home_dir = str(pathlib.Path.home()) output = docker_client.containers.run( # pylint:disable=no-member image=image, diff --git a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py index 9c9aad552..08192903a 100644 --- a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py +++ b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py @@ -15,7 +15,7 @@ def _parse_www_authenticate(header: str) -> Optional[dict]: """Parse a Www-Authenticate Bearer header into realm, service, and scope.""" - match = re.match(r'Bearer\s+(.*)', header, re.IGNORECASE) + match = re.match(r"Bearer\s+(.*)", header, re.IGNORECASE) if not match: return None params = {} @@ -56,12 +56,14 @@ def image_exists( full_repo = f"{path_prefix}/{repository_name}" if path_prefix else repository_name manifest_url = f"https://{registry_host}/v2/{full_repo}/manifests/{image_tag}" headers = { - "Accept": ", ".join([ - "application/vnd.docker.distribution.manifest.v2+json", - "application/vnd.oci.image.manifest.v1+json", - "application/vnd.docker.distribution.manifest.list.v2+json", - "application/vnd.oci.image.index.v1+json", - ]) + "Accept": ", ".join( + [ + "application/vnd.docker.distribution.manifest.v2+json", + "application/vnd.oci.image.manifest.v1+json", + "application/vnd.docker.distribution.manifest.list.v2+json", + "application/vnd.oci.image.index.v1+json", + ] + ) } try: diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index 2562a3c8e..f4bb17d4c 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -246,7 +246,7 @@ async def _read_or_submit_tasks( task_ids = [ BatchEndpointInProgressTask.deserialize(tid) for tid in task_ids_serialized ] - num_task_ids = len(task_ids) # type:ignore + num_task_ids = len(task_ids) # type: ignore logger.info(f"Found {num_task_ids} pending tasks for batch job {job_id}") finally: if task_ids is None: diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 163466c2c..0194e7a27 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -108,13 +108,11 @@ def test_completion_sync_success( fake_docker_image_batch_job_bundle_repository_contents={}, fake_sync_inference_content=SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """{ + result={"result": """{ "text": "output", "count_prompt_tokens": 1, "count_output_tokens": 1 - }""" - }, + }"""}, traceback=None, status_code=200, ), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 1738bd1ce..c89f90f6d 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1077,8 +1077,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """ + result={"result": """ { "generated_text": " Deep Learning is a new type of machine learning", "details": { @@ -1146,8 +1145,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( ] } } -""" - }, +"""}, traceback=None, status_code=200, ) @@ -1376,14 +1374,12 @@ async def test_completion_sync_use_case_predict_failed_with_errors( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """ + result={"result": """ { "error": "Request failed during generation: Server error: transport error", "error_type": "generation" } -""" - }, +"""}, traceback="failed to predict", status_code=500, ) diff --git a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py index 94c45f0c5..77112d770 100644 --- a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py @@ -33,7 +33,7 @@ def test_parses_bearer_header(self): } def test_returns_none_for_basic_auth(self): - assert _parse_www_authenticate("Basic realm=\"registry\"") is None + assert _parse_www_authenticate('Basic realm="registry"') is None def test_returns_none_for_missing_realm(self): assert _parse_www_authenticate('Bearer service="foo"') is None @@ -44,7 +44,9 @@ def test_returns_none_for_empty_string(self): class TestImageExists: def test_returns_true_on_200(self, generic_docker_repo, mock_infra_config): - with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: mock_resp = mock.Mock() mock_resp.status_code = 200 mock_requests.head.return_value = mock_resp @@ -54,10 +56,14 @@ def test_returns_true_on_200(self, generic_docker_repo, mock_infra_config): assert result is True mock_requests.head.assert_called_once() call_url = mock_requests.head.call_args[0][0] - assert call_url == "https://public.ecr.aws/v2/b2z8n5q1/model-engine/vllm/manifests/v0.4.0" + assert ( + call_url == "https://public.ecr.aws/v2/b2z8n5q1/model-engine/vllm/manifests/v0.4.0" + ) def test_returns_false_on_404(self, generic_docker_repo, mock_infra_config): - with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: mock_resp = mock.Mock() mock_resp.status_code = 404 mock_requests.head.return_value = mock_resp @@ -67,7 +73,9 @@ def test_returns_false_on_404(self, generic_docker_repo, mock_infra_config): assert result is False def test_returns_false_on_connection_error(self, generic_docker_repo, mock_infra_config): - with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: mock_requests.head.side_effect = requests.ConnectionError("unreachable") mock_requests.ConnectionError = requests.ConnectionError mock_requests.RequestException = requests.RequestException @@ -77,7 +85,9 @@ def test_returns_false_on_connection_error(self, generic_docker_repo, mock_infra assert result is False def test_token_auth_on_401(self, generic_docker_repo, mock_infra_config): - with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: mock_requests.RequestException = requests.RequestException # First HEAD returns 401 with Www-Authenticate @@ -107,8 +117,12 @@ def test_token_auth_on_401(self, generic_docker_repo, mock_infra_config): second_call_headers = mock_requests.head.call_args_list[1][1]["headers"] assert second_call_headers["Authorization"] == "Bearer test-token-123" - def test_returns_false_on_401_without_www_authenticate(self, generic_docker_repo, mock_infra_config): - with mock.patch("model_engine_server.infra.repositories.generic_docker_repository.requests") as mock_requests: + def test_returns_false_on_401_without_www_authenticate( + self, generic_docker_repo, mock_infra_config + ): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: mock_requests.RequestException = requests.RequestException mock_resp = mock.Mock() From a56da1e613faee4fb11a4192bbb874f9fb24cb75 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Thu, 12 Mar 2026 14:06:08 +0000 Subject: [PATCH 6/8] ran linters as in ci --- model-engine/model_engine_server/api/dependencies.py | 2 +- model-engine/model_engine_server/entrypoints/k8s_cache.py | 3 +-- .../model_engine_server/infra/repositories/__init__.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index 181bb4567..4ca4a3001 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -122,9 +122,9 @@ ECRDockerRepository, FakeDockerRepository, GARDockerRepository, - GenericDockerRepository, GCSFileLLMFineTuneEventsRepository, GCSFileLLMFineTuneRepository, + GenericDockerRepository, LiveTokenizerRepository, LLMFineTuneRepository, OnPremDockerRepository, diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 285102b2f..36bd8e96f 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -14,7 +14,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.env_vars import CIRCLECI -from model_engine_server.core.config import infra_config +from model_engine_server.core.config import infer_registry_type, infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import get_session_async_null_pool from model_engine_server.domain.repositories import DockerRepository @@ -40,7 +40,6 @@ from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( SQSQueueEndpointResourceDelegate, ) -from model_engine_server.core.config import infer_registry_type from model_engine_server.infra.repositories import ( ACRDockerRepository, ECRDockerRepository, diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index 9682146cb..adeb909c7 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -13,9 +13,9 @@ from .fake_docker_repository import FakeDockerRepository from .feature_flag_repository import FeatureFlagRepository from .gar_docker_repository import GARDockerRepository -from .generic_docker_repository import GenericDockerRepository from .gcs_file_llm_fine_tune_events_repository import GCSFileLLMFineTuneEventsRepository from .gcs_file_llm_fine_tune_repository import GCSFileLLMFineTuneRepository +from .generic_docker_repository import GenericDockerRepository from .live_tokenizer_repository import LiveTokenizerRepository from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository From 1ec5ec6cc4080fcb317d66720a094d194718d5f1 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Thu, 12 Mar 2026 14:31:16 +0000 Subject: [PATCH 7/8] ran linters as in ci --- .../model_engine_server/core/docker/docker_image.py | 8 ++++++-- model-engine/tests/unit/api/test_llms.py | 6 ++++-- model-engine/tests/unit/domain/test_llm_use_cases.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index 70f5f0e56..8d68f8c8b 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -144,10 +144,14 @@ def build( ) if test_command: - logger.info(textwrap.dedent(f""" + logger.info( + textwrap.dedent( + f""" Testing with 'docker run' on the built image. ARGS: {test_command} - (NOTE: Expecting the test command to terminate. """)) + (NOTE: Expecting the test command to terminate. """ + ) + ) home_dir = str(pathlib.Path.home()) output = docker_client.containers.run( # pylint:disable=no-member image=image, diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 0194e7a27..163466c2c 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -108,11 +108,13 @@ def test_completion_sync_success( fake_docker_image_batch_job_bundle_repository_contents={}, fake_sync_inference_content=SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """{ + result={ + "result": """{ "text": "output", "count_prompt_tokens": 1, "count_output_tokens": 1 - }"""}, + }""" + }, traceback=None, status_code=200, ), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index c89f90f6d..1738bd1ce 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1077,7 +1077,8 @@ async def test_completion_sync_text_generation_inference_use_case_success( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """ + result={ + "result": """ { "generated_text": " Deep Learning is a new type of machine learning", "details": { @@ -1145,7 +1146,8 @@ async def test_completion_sync_text_generation_inference_use_case_success( ] } } -"""}, +""" + }, traceback=None, status_code=200, ) @@ -1374,12 +1376,14 @@ async def test_completion_sync_use_case_predict_failed_with_errors( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """ + result={ + "result": """ { "error": "Request failed during generation: Server error: transport error", "error_type": "generation" } -"""}, +""" + }, traceback="failed to predict", status_code=500, ) From b24cbb5abe8901407705134102d2d7c2807708e9 Mon Sep 17 00:00:00 2001 From: Lukas Ewecker Date: Thu, 12 Mar 2026 15:02:39 +0000 Subject: [PATCH 8/8] fix test --- model-engine/tests/unit/api/test_dependencies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model-engine/tests/unit/api/test_dependencies.py b/model-engine/tests/unit/api/test_dependencies.py index 6d67d27c0..15491c368 100644 --- a/model-engine/tests/unit/api/test_dependencies.py +++ b/model-engine/tests/unit/api/test_dependencies.py @@ -101,6 +101,7 @@ def test_gcp_provider_selects_gcp_implementations(): mock_config_instance.cloud_provider = "gcp" mock_config_instance.celery_broker_type_redis = None mock_config_instance.docker_repo_prefix = "us-docker.pkg.dev/my-project/my-repo" + mock_config_instance.docker_registry_type = None mock_config.return_value = mock_config_instance mock_session = MagicMock()