Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from model_engine_server.core.auth.fake_authentication_repository import (
FakeAuthenticationRepository,
)
from model_engine_server.core.config import infra_config
from model_engine_server.core.config import infer_registry_type, infra_config
from model_engine_server.core.loggers import (
LoggerTagKey,
LoggerTagManager,
Expand Down Expand Up @@ -124,6 +124,7 @@
GARDockerRepository,
GCSFileLLMFineTuneEventsRepository,
GCSFileLLMFineTuneRepository,
GenericDockerRepository,
LiveTokenizerRepository,
LLMFineTuneRepository,
OnPremDockerRepository,
Expand All @@ -146,6 +147,7 @@

logger = make_logger(logger_name())


basic_auth = HTTPBasic(auto_error=False)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)

Expand Down Expand Up @@ -383,16 +385,21 @@ def _get_external_interfaces(
file_storage_gateway = S3FileStorageGateway()

docker_repository: DockerRepository
registry_type = infra_config().docker_registry_type or infer_registry_type(
infra_config().docker_repo_prefix
)
if CIRCLECI:
docker_repository = FakeDockerRepository()
elif infra_config().cloud_provider == "onprem":
docker_repository = OnPremDockerRepository()
elif infra_config().cloud_provider == "azure":
elif registry_type == "ecr":
docker_repository = ECRDockerRepository()
elif registry_type == "acr":
docker_repository = ACRDockerRepository()
elif infra_config().cloud_provider == "gcp":
elif registry_type == "gar":
docker_repository = GARDockerRepository()
elif registry_type == "onprem":
docker_repository = OnPremDockerRepository()
else:
docker_repository = ECRDockerRepository()
docker_repository = GenericDockerRepository()

tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# -*- coding: utf-8 -*-
"""Main module."""

# import asyncio
# import logging
from datetime import timedelta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ async def _get_connection_count(self):
connection_count = info.get("connected_clients")
max_connections = info.get("maxclients")
else:
(info, config) = await aio.gather(
info, config = await aio.gather(
redis_client.info(),
redis_client.config_get("maxclients"),
)
Expand Down
13 changes: 13 additions & 0 deletions model-engine/model_engine_server/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"CONFIG_PATH",
"config_context",
"get_config_path_for_env_name",
"infer_registry_type",
"infra_config",
"use_config_context",
)
Expand Down Expand Up @@ -51,6 +52,7 @@ class _InfraConfig:
prometheus_server_address: Optional[str] = None
celery_broker_type_redis: Optional[bool] = None
celery_enable_sha256: Optional[bool] = None
docker_registry_type: Optional[str] = None
debug_mode: Optional[bool] = None


Expand Down Expand Up @@ -109,6 +111,17 @@ def use_config_context(config_path: str):
_infra_config = InfraConfig.from_yaml(config_path)


def infer_registry_type(prefix: str) -> str:
"""Infer docker registry type from docker_repo_prefix."""
if ".dkr.ecr." in prefix and ".amazonaws.com" in prefix:
return "ecr"
if ".azurecr.io" in prefix:
return "acr"
if "-docker.pkg.dev" in prefix:
return "gar"
return "generic"


def get_config_path_for_env_name(env_name: str) -> Path:
path = DEFAULT_CONFIG_PATH.parent / f"{env_name}.yaml"
if not path.exists():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ def __init__(
def check_docker_image_exists_for_image_tag(
self, framework_image_tag: str, repository_name: str
):
# Skip ECR validation for on-prem deployments - images are in local registry
if infra_config().cloud_provider == "onprem":
return

if not self.docker_repository.image_exists(
image_tag=framework_image_tag,
repository_name=repository_name,
Expand Down
21 changes: 14 additions & 7 deletions model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from model_engine_server.common.config import hmi_config
from model_engine_server.common.constants import READYZ_FPATH
from model_engine_server.common.env_vars import CIRCLECI
from model_engine_server.core.config import infra_config
from model_engine_server.core.config import infer_registry_type, infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.db.base import get_session_async_null_pool
from model_engine_server.domain.repositories import DockerRepository
Expand Down Expand Up @@ -44,6 +44,8 @@
ACRDockerRepository,
ECRDockerRepository,
FakeDockerRepository,
GARDockerRepository,
GenericDockerRepository,
)
from model_engine_server.infra.repositories.db_model_endpoint_record_repository import (
DbModelEndpointRecordRepository,
Expand Down Expand Up @@ -128,16 +130,21 @@ async def main(args: Any):
)
image_cache_gateway = ImageCacheGateway()
docker_repo: DockerRepository
registry_type = infra_config().docker_registry_type or infer_registry_type(
infra_config().docker_repo_prefix
)
if CIRCLECI:
docker_repo = FakeDockerRepository()
elif infra_config().cloud_provider == "onprem":
docker_repo = OnPremDockerRepository()
elif infra_config().cloud_provider == "azure" or infra_config().docker_repo_prefix.endswith(
"azurecr.io"
):
elif registry_type == "ecr":
docker_repo = ECRDockerRepository()
elif registry_type == "acr":
docker_repo = ACRDockerRepository()
elif registry_type == "gar":
docker_repo = GARDockerRepository()
elif registry_type == "onprem":
docker_repo = OnPremDockerRepository()
else:
docker_repo = ECRDockerRepository()
docker_repo = GenericDockerRepository()
while True:
loop_start = time.time()
await loop_iteration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .gar_docker_repository import GARDockerRepository
from .gcs_file_llm_fine_tune_events_repository import GCSFileLLMFineTuneEventsRepository
from .gcs_file_llm_fine_tune_repository import GCSFileLLMFineTuneRepository
from .generic_docker_repository import GenericDockerRepository
from .live_tokenizer_repository import LiveTokenizerRepository
from .llm_fine_tune_repository import LLMFineTuneRepository
from .model_endpoint_cache_repository import ModelEndpointCacheRepository
Expand All @@ -38,6 +39,7 @@
"ECRDockerRepository",
"FakeDockerRepository",
"GARDockerRepository",
"GenericDockerRepository",
"FeatureFlagRepository",
"GCSFileLLMFineTuneEventsRepository",
"GCSFileLLMFineTuneRepository",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import re
from typing import Optional
from urllib.parse import urlencode

import requests
from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.repositories import DockerRepository

logger = make_logger(logger_name())

_REQUEST_TIMEOUT = 10


def _parse_www_authenticate(header: str) -> Optional[dict]:
"""Parse a Www-Authenticate Bearer header into realm, service, and scope."""
match = re.match(r"Bearer\s+(.*)", header, re.IGNORECASE)
if not match:
return None
params = {}
for m in re.finditer(r'(\w+)="([^"]*)"', match.group(1)):
params[m.group(1)] = m.group(2)
return params if "realm" in params else None


def _get_token(realm: str, service: Optional[str], scope: Optional[str]) -> Optional[str]:
"""Fetch a bearer token from the registry's token endpoint."""
query = {}
if service:
query["service"] = service
if scope:
query["scope"] = scope
separator = "&" if "?" in realm else "?"
url = f"{realm}{separator}{urlencode(query)}" if query else realm
try:
resp = requests.get(url, timeout=_REQUEST_TIMEOUT)
if resp.status_code == 200:
data = resp.json()
return data.get("token") or data.get("access_token")
except (requests.RequestException, ValueError):
pass
return None


class GenericDockerRepository(DockerRepository):
"""Registry-agnostic Docker repository using the OCI Distribution / Docker Registry V2 HTTP API."""

def image_exists(
self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None
) -> bool:
prefix = infra_config().docker_repo_prefix.rstrip("/")
parts = prefix.split("/", 1)
registry_host = parts[0]
path_prefix = parts[1] if len(parts) > 1 else ""
full_repo = f"{path_prefix}/{repository_name}" if path_prefix else repository_name
manifest_url = f"https://{registry_host}/v2/{full_repo}/manifests/{image_tag}"
headers = {
"Accept": ", ".join(
[
"application/vnd.docker.distribution.manifest.v2+json",
"application/vnd.oci.image.manifest.v1+json",
"application/vnd.docker.distribution.manifest.list.v2+json",
"application/vnd.oci.image.index.v1+json",
]
)
}

try:
resp = requests.head(manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT)

if resp.status_code == 200:
return True

if resp.status_code == 401:
www_auth = resp.headers.get("Www-Authenticate", "")
auth_params = _parse_www_authenticate(www_auth)
if auth_params:
token = _get_token(
realm=auth_params["realm"],
service=auth_params.get("service"),
scope=auth_params.get("scope"),
)
if token:
headers["Authorization"] = f"Bearer {token}"
resp = requests.head(
manifest_url, headers=headers, timeout=_REQUEST_TIMEOUT
)
return resp.status_code == 200

return False
except requests.RequestException as e:
logger.warning(f"Failed to check image existence at {manifest_url}: {e}")
return False

def get_image_url(self, image_tag: str, repository_name: str) -> str:
if self.is_repo_name(repository_name):
return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}"
return f"{repository_name}:{image_tag}"

def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse:
raise NotImplementedError("GenericDockerRepository does not support building images")

def get_latest_image_tag(self, repository_name: str) -> str:
raise NotImplementedError("GenericDockerRepository does not support querying latest tags")
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def _read_or_submit_tasks(
task_ids = [
BatchEndpointInProgressTask.deserialize(tid) for tid in task_ids_serialized
]
num_task_ids = len(task_ids) # type:ignore
num_task_ids = len(task_ids) # type: ignore
logger.info(f"Found {num_task_ids} pending tasks for batch job {job_id}")
finally:
if task_ids is None:
Expand Down
1 change: 1 addition & 0 deletions model-engine/tests/unit/api/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_gcp_provider_selects_gcp_implementations():
mock_config_instance.cloud_provider = "gcp"
mock_config_instance.celery_broker_type_redis = None
mock_config_instance.docker_repo_prefix = "us-docker.pkg.dev/my-project/my-repo"
mock_config_instance.docker_registry_type = None
mock_config.return_value = mock_config_instance

mock_session = MagicMock()
Expand Down
Loading