diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index a2865e78a..4ca4a3001 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -14,7 +14,7 @@ from model_engine_server.core.auth.fake_authentication_repository import ( FakeAuthenticationRepository, ) -from model_engine_server.core.config import infra_config +from model_engine_server.core.config import infer_registry_type, infra_config from model_engine_server.core.loggers import ( LoggerTagKey, LoggerTagManager, @@ -124,6 +124,7 @@ GARDockerRepository, GCSFileLLMFineTuneEventsRepository, GCSFileLLMFineTuneRepository, + GenericDockerRepository, LiveTokenizerRepository, LLMFineTuneRepository, OnPremDockerRepository, @@ -146,6 +147,7 @@ logger = make_logger(logger_name()) + basic_auth = HTTPBasic(auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) @@ -383,16 +385,21 @@ def _get_external_interfaces( file_storage_gateway = S3FileStorageGateway() docker_repository: DockerRepository + registry_type = infra_config().docker_registry_type or infer_registry_type( + infra_config().docker_repo_prefix + ) if CIRCLECI: docker_repository = FakeDockerRepository() - elif infra_config().cloud_provider == "onprem": - docker_repository = OnPremDockerRepository() - elif infra_config().cloud_provider == "azure": + elif registry_type == "ecr": + docker_repository = ECRDockerRepository() + elif registry_type == "acr": docker_repository = ACRDockerRepository() - elif infra_config().cloud_provider == "gcp": + elif registry_type == "gar": docker_repository = GARDockerRepository() + elif registry_type == "onprem": + docker_repository = OnPremDockerRepository() else: - docker_repository = ECRDockerRepository() + docker_repository = GenericDockerRepository() tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) diff --git a/model-engine/model_engine_server/common/aiohttp_sse_client.py b/model-engine/model_engine_server/common/aiohttp_sse_client.py index 20e71f3d4..ddec98747 100644 --- a/model-engine/model_engine_server/common/aiohttp_sse_client.py +++ b/model-engine/model_engine_server/common/aiohttp_sse_client.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- """Main module.""" + # import asyncio # import logging from datetime import timedelta diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py index 923244a66..679908616 100644 --- a/model-engine/model_engine_server/core/celery/celery_autoscaler.py +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -381,7 +381,7 @@ async def _get_connection_count(self): connection_count = info.get("connected_clients") max_connections = info.get("maxclients") else: - (info, config) = await aio.gather( + info, config = await aio.gather( redis_client.info(), redis_client.config_get("maxclients"), ) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index e721d3c5a..6886174ff 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -23,6 +23,7 @@ "CONFIG_PATH", "config_context", "get_config_path_for_env_name", + "infer_registry_type", "infra_config", "use_config_context", ) @@ -51,6 +52,7 @@ class _InfraConfig: prometheus_server_address: Optional[str] = None celery_broker_type_redis: Optional[bool] = None celery_enable_sha256: Optional[bool] = None + docker_registry_type: Optional[str] = None debug_mode: Optional[bool] = None @@ -109,6 +111,17 @@ def use_config_context(config_path: str): _infra_config = InfraConfig.from_yaml(config_path) +def infer_registry_type(prefix: str) -> str: + """Infer docker registry type from docker_repo_prefix.""" + if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix: + return "ecr" + if ".azurecr.io" in prefix: + return "acr" + if "-docker.pkg.dev" in prefix: + return "gar" + return "generic" + + def get_config_path_for_env_name(env_name: str) -> Path: path = DEFAULT_CONFIG_PATH.parent / f"{env_name}.yaml" if not path.exists(): diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 849d93566..fe0631d7f 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -387,10 +387,6 @@ def __init__( def check_docker_image_exists_for_image_tag( self, framework_image_tag: str, repository_name: str ): - # Skip ECR validation for on-prem deployments - images are in local registry - if infra_config().cloud_provider == "onprem": - return - if not self.docker_repository.image_exists( image_tag=framework_image_tag, repository_name=repository_name, diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index f39d304b8..36bd8e96f 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -14,7 +14,7 @@ from model_engine_server.common.config import hmi_config from model_engine_server.common.constants import READYZ_FPATH from model_engine_server.common.env_vars import CIRCLECI -from model_engine_server.core.config import infra_config +from model_engine_server.core.config import infer_registry_type, infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import get_session_async_null_pool from model_engine_server.domain.repositories import DockerRepository @@ -44,6 +44,8 @@ ACRDockerRepository, ECRDockerRepository, FakeDockerRepository, + GARDockerRepository, + GenericDockerRepository, ) from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, @@ -128,16 +130,21 @@ async def main(args: Any): ) image_cache_gateway = ImageCacheGateway() docker_repo: DockerRepository + registry_type = infra_config().docker_registry_type or infer_registry_type( + infra_config().docker_repo_prefix + ) if CIRCLECI: docker_repo = FakeDockerRepository() - elif infra_config().cloud_provider == "onprem": - docker_repo = OnPremDockerRepository() - elif infra_config().cloud_provider == "azure" or infra_config().docker_repo_prefix.endswith( - "azurecr.io" - ): + elif registry_type == "ecr": + docker_repo = ECRDockerRepository() + elif registry_type == "acr": docker_repo = ACRDockerRepository() + elif registry_type == "gar": + docker_repo = GARDockerRepository() + elif registry_type == "onprem": + docker_repo = OnPremDockerRepository() else: - docker_repo = ECRDockerRepository() + docker_repo = GenericDockerRepository() while True: loop_start = time.time() await loop_iteration( diff --git a/model-engine/model_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py index 5a34f7ec5..adeb909c7 100644 --- a/model-engine/model_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -15,6 +15,7 @@ from .gar_docker_repository import GARDockerRepository from .gcs_file_llm_fine_tune_events_repository import GCSFileLLMFineTuneEventsRepository from .gcs_file_llm_fine_tune_repository import GCSFileLLMFineTuneRepository +from .generic_docker_repository import GenericDockerRepository from .live_tokenizer_repository import LiveTokenizerRepository from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository @@ -38,6 +39,7 @@ "ECRDockerRepository", "FakeDockerRepository", "GARDockerRepository", + "GenericDockerRepository", "FeatureFlagRepository", "GCSFileLLMFineTuneEventsRepository", "GCSFileLLMFineTuneRepository", diff --git a/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py new file mode 100644 index 000000000..08192903a --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/generic_docker_repository.py @@ -0,0 +1,105 @@ +import re +from typing import Optional +from urllib.parse import urlencode + +import requests +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + +_REQUEST_TIMEOUT = 10 + + +def _parse_www_authenticate(header: str) -> Optional[dict]: + """Parse a Www-Authenticate Bearer header into realm, service, and scope.""" + match = re.match(r"Bearer\s+(.*)", header, re.IGNORECASE) + if not match: + return None + params = {} + for m in re.finditer(r'(\w+)="([^"]*)"', match.group(1)): + params[m.group(1)] = m.group(2) + return params if "realm" in params else None + + +def _get_token(realm: str, service: Optional[str], scope: Optional[str]) -> Optional[str]: + """Fetch a bearer token from the registry's token endpoint.""" + query = {} + if service: + query["service"] = service + if scope: + query["scope"] = scope + separator = "&" if "?" in realm else "?" + url = f"{realm}{separator}{urlencode(query)}" if query else realm + try: + resp = requests.get(url, timeout=_REQUEST_TIMEOUT) + if resp.status_code == 200: + data = resp.json() + return data.get("token") or data.get("access_token") + except (requests.RequestException, ValueError): + pass + return None + + +class GenericDockerRepository(DockerRepository): + """Registry-agnostic Docker repository using the OCI Distribution / Docker Registry V2 HTTP API.""" + + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + prefix = infra_config().docker_repo_prefix.rstrip("/") + parts = prefix.split("/", 1) + registry_host = parts[0] + path_prefix = parts[1] if len(parts) > 1 else "" + full_repo = f"{path_prefix}/{repository_name}" if path_prefix else repository_name + manifest_url = f"https://{registry_host}/v2/{full_repo}/manifests/{image_tag}" + headers = { + "Accept": ", ".join( + [ + "application/vnd.docker.distribution.manifest.v2+json", + "application/vnd.oci.image.manifest.v1+json", + "application/vnd.docker.distribution.manifest.list.v2+json", + "application/vnd.oci.image.index.v1+json", + ] + ) + } + + try: + resp = requests.head(manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT) + + if resp.status_code == 200: + return True + + if resp.status_code == 401: + www_auth = resp.headers.get("Www-Authenticate", "") + auth_params = _parse_www_authenticate(www_auth) + if auth_params: + token = _get_token( + realm=auth_params["realm"], + service=auth_params.get("service"), + scope=auth_params.get("scope"), + ) + if token: + headers["Authorization"] = f"Bearer {token}" + resp = requests.head( + manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT + ) + return resp.status_code == 200 + + return False + except requests.RequestException as e: + logger.warning(f"Failed to check image existence at {manifest_url}: {e}") + return False + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + if self.is_repo_name(repository_name): + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + return f"{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("GenericDockerRepository does not support building images") + + def get_latest_image_tag(self, repository_name: str) -> str: + raise NotImplementedError("GenericDockerRepository does not support querying latest tags") diff --git a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index 2562a3c8e..f4bb17d4c 100644 --- a/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -246,7 +246,7 @@ async def _read_or_submit_tasks( task_ids = [ BatchEndpointInProgressTask.deserialize(tid) for tid in task_ids_serialized ] - num_task_ids = len(task_ids) # type:ignore + num_task_ids = len(task_ids) # type: ignore logger.info(f"Found {num_task_ids} pending tasks for batch job {job_id}") finally: if task_ids is None: diff --git a/model-engine/tests/unit/api/test_dependencies.py b/model-engine/tests/unit/api/test_dependencies.py index 6d67d27c0..15491c368 100644 --- a/model-engine/tests/unit/api/test_dependencies.py +++ b/model-engine/tests/unit/api/test_dependencies.py @@ -101,6 +101,7 @@ def test_gcp_provider_selects_gcp_implementations(): mock_config_instance.cloud_provider = "gcp" mock_config_instance.celery_broker_type_redis = None mock_config_instance.docker_repo_prefix = "us-docker.pkg.dev/my-project/my-repo" + mock_config_instance.docker_registry_type = None mock_config.return_value = mock_config_instance mock_session = MagicMock() diff --git a/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py new file mode 100644 index 000000000..77112d770 --- /dev/null +++ b/model-engine/tests/unit/infra/repositories/test_generic_docker_repository.py @@ -0,0 +1,155 @@ +from unittest import mock + +import pytest +import requests +from model_engine_server.infra.repositories.generic_docker_repository import ( + GenericDockerRepository, + _parse_www_authenticate, +) + + +@pytest.fixture +def generic_docker_repo(): + return GenericDockerRepository() + + +@pytest.fixture +def mock_infra_config(): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.infra_config" + ) as mock_config: + mock_config.return_value.docker_repo_prefix = "public.ecr.aws/b2z8n5q1" + yield mock_config + + +class TestParseWwwAuthenticate: + def test_parses_bearer_header(self): + header = 'Bearer realm="https://auth.example.com/token",service="registry.example.com",scope="repository:myrepo:pull"' + result = _parse_www_authenticate(header) + assert result == { + "realm": "https://auth.example.com/token", + "service": "registry.example.com", + "scope": "repository:myrepo:pull", + } + + def test_returns_none_for_basic_auth(self): + assert _parse_www_authenticate('Basic realm="registry"') is None + + def test_returns_none_for_missing_realm(self): + assert _parse_www_authenticate('Bearer service="foo"') is None + + def test_returns_none_for_empty_string(self): + assert _parse_www_authenticate("") is None + + +class TestImageExists: + def test_returns_true_on_200(self, generic_docker_repo, mock_infra_config): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: + mock_resp = mock.Mock() + mock_resp.status_code = 200 + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("v0.4.0", "model-engine/vllm") + + assert result is True + mock_requests.head.assert_called_once() + call_url = mock_requests.head.call_args[0][0] + assert ( + call_url == "https://public.ecr.aws/v2/b2z8n5q1/model-engine/vllm/manifests/v0.4.0" + ) + + def test_returns_false_on_404(self, generic_docker_repo, mock_infra_config): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: + mock_resp = mock.Mock() + mock_resp.status_code = 404 + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("nonexistent", "vllm") + + assert result is False + + def test_returns_false_on_connection_error(self, generic_docker_repo, mock_infra_config): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: + mock_requests.head.side_effect = requests.ConnectionError("unreachable") + mock_requests.ConnectionError = requests.ConnectionError + mock_requests.RequestException = requests.RequestException + + result = generic_docker_repo.image_exists("v1.0", "vllm") + + assert result is False + + def test_token_auth_on_401(self, generic_docker_repo, mock_infra_config): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: + mock_requests.RequestException = requests.RequestException + + # First HEAD returns 401 with Www-Authenticate + unauthed_resp = mock.Mock() + unauthed_resp.status_code = 401 + unauthed_resp.headers = { + "Www-Authenticate": 'Bearer realm="https://public.ecr.aws/token",service="public.ecr.aws",scope="repository:b2z8n5q1/vllm:pull"' + } + + # Second HEAD (with token) returns 200 + authed_resp = mock.Mock() + authed_resp.status_code = 200 + + mock_requests.head.side_effect = [unauthed_resp, authed_resp] + + # Token endpoint returns a token + token_resp = mock.Mock() + token_resp.status_code = 200 + token_resp.json.return_value = {"token": "test-token-123"} + mock_requests.get.return_value = token_resp + + result = generic_docker_repo.image_exists("v0.4.0", "vllm") + + assert result is True + assert mock_requests.head.call_count == 2 + # Verify the second HEAD had the Authorization header + second_call_headers = mock_requests.head.call_args_list[1][1]["headers"] + assert second_call_headers["Authorization"] == "Bearer test-token-123" + + def test_returns_false_on_401_without_www_authenticate( + self, generic_docker_repo, mock_infra_config + ): + with mock.patch( + "model_engine_server.infra.repositories.generic_docker_repository.requests" + ) as mock_requests: + mock_requests.RequestException = requests.RequestException + + mock_resp = mock.Mock() + mock_resp.status_code = 401 + mock_resp.headers = {} + mock_requests.head.return_value = mock_resp + + result = generic_docker_repo.image_exists("v1.0", "vllm") + + assert result is False + + +class TestGetImageUrl: + def test_prepends_prefix_for_simple_repo_name(self, generic_docker_repo, mock_infra_config): + result = generic_docker_repo.get_image_url("v1.0", "vllm") + assert result == "public.ecr.aws/b2z8n5q1/vllm:v1.0" + + def test_no_prefix_for_full_url(self, generic_docker_repo, mock_infra_config): + result = generic_docker_repo.get_image_url("v1.0", "docker.io/library/nginx") + assert result == "docker.io/library/nginx:v1.0" + + +class TestNotImplemented: + def test_build_image_raises(self, generic_docker_repo): + with pytest.raises(NotImplementedError): + generic_docker_repo.build_image(None) + + def test_get_latest_image_tag_raises(self, generic_docker_repo): + with pytest.raises(NotImplementedError): + generic_docker_repo.get_latest_image_tag("vllm")