diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py
index cac68cda2..246a813f5 100644
--- a/model-engine/model_engine_server/api/app.py
+++ b/model-engine/model_engine_server/api/app.py
@@ -312,6 +312,56 @@ def load_redis():
     get_or_create_aioredis_pool()
 
 
+@app.on_event("startup")
+def init_model_weights_manager():
+    from model_engine_server.core.config import infra_config
+    from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
+    from model_engine_server.infra.gateways import (
+        ABSLLMArtifactGateway,
+        GCSLLMArtifactGateway,
+        S3LLMArtifactGateway,
+    )
+
+    provider = infra_config().cloud_provider
+    if provider == "azure":
+        gateway = ABSLLMArtifactGateway()
+    elif provider == "gcp":
+        gateway = GCSLLMArtifactGateway()
+    else:
+        gateway = S3LLMArtifactGateway()
+    app.state.model_weights_manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+
+@app.on_event("startup")
+async def recover_hf_syncs():
+    """Re-trigger weight syncs for endpoints that were syncing when server last stopped."""
+    from model_engine_server.db.base import get_session_async
+    from model_engine_server.infra.repositories.live_tokenizer_repository import (
+        SUPPORTED_MODELS_INFO,
+    )
+    from sqlalchemy import text
+
+    session_factory = get_session_async()
+    try:
+        async with session_factory() as session:
+            result = await session.execute(
+                text(
+                    "SELECT DISTINCT endpoint_metadata->'_llm'->>'model_name' AS model_name "
+                    "FROM hosted_model_inference.endpoints "
+                    "WHERE (endpoint_metadata->'_llm'->>'hf_weights_syncing')::boolean = true"
+                )
+            )
+            model_names = [row.model_name for row in result if row.model_name]
+    except Exception:
+        logger.warning("Could not query pending HF sync endpoints at startup", exc_info=True)
+        return
+    for model_name in model_names:
+        info = SUPPORTED_MODELS_INFO.get(model_name)
+        if info and info.hf_repo:
+            app.state.model_weights_manager.ensure_model_weights_available(info.hf_repo)
+            logger.info(f"Startup: re-triggered HF weight sync for {model_name}")
+
+
 def healthcheck() -> Response:
     """Returns 200 if the app is healthy."""
     return Response(status_code=200)
diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 35f55fa5e..99af6a273 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -148,14 +148,15 @@ def handle_streaming_exception(
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     wrapped_request: RootModel[CreateLLMModelEndpointV1Request],
+    request: Request,
     auth: User = Depends(verify_authentication),
     external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
 ) -> CreateLLMModelEndpointV1Response:
-    request = wrapped_request.root
+    llm_request = wrapped_request.root
     """
     Creates an LLM endpoint for the current user.
     """
-    logger.info(f"POST /llm/model-endpoints with {request} for {auth}")
+    logger.info(f"POST /llm/model-endpoints with {llm_request} for {auth}")
     try:
         create_model_bundle_use_case = CreateModelBundleV2UseCase(
             model_bundle_repository=external_interfaces.model_bundle_repository,
@@ -168,13 +169,15 @@ async def create_model_endpoint(
             llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
             docker_repository=external_interfaces.docker_repository,
         )
+        model_weights_manager = request.app.state.model_weights_manager
         use_case = CreateLLMModelEndpointV1UseCase(
             create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
             model_endpoint_service=external_interfaces.model_endpoint_service,
             docker_repository=external_interfaces.docker_repository,
             llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
+            model_weights_manager=model_weights_manager,
         )
-        return await use_case.execute(user=auth, request=request)
+        return await use_case.execute(user=auth, request=llm_request)
     except ObjectAlreadyExistsException as exc:
         raise HTTPException(
             status_code=400,
diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py
index dc6c06090..cf4e4a3e3 100644
--- a/model-engine/model_engine_server/domain/entities/llm_entity.py
+++ b/model-engine/model_engine_server/domain/entities/llm_entity.py
@@ -31,3 +31,4 @@ class LLMMetadata:
     quantize: Optional[Quantization] = None
     checkpoint_path: Optional[str] = None
     chat_template_override: Optional[str] = None
+    hf_weights_syncing: bool = False
diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py
index 8f8ece698..9f12cf787 100644
--- a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py
+++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py
@@ -40,6 +40,17 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[
         """
         pass
 
+    @abstractmethod
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        """
+        Upload all files from a local directory to a remote path.
+
+        Args:
+            local_path (str): local directory containing files to upload
+            remote_path (str): remote destination path (s3://, gs://, or https://)
+        """
+        pass
+
     @abstractmethod
     def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
         """
diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 50b000488..9bb347df9 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -652,7 +652,8 @@ def load_model_weights_sub_commands_s3(
         s5cmd = "./s5cmd"
 
         checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
-        validate_checkpoint_files(checkpoint_files)
+        if checkpoint_files:
+            validate_checkpoint_files(checkpoint_files)
 
         # filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
         # For models that are not supported by transformers directly, we need to include '*.py' and '*.bin'
@@ -1322,12 +1323,14 @@ def __init__(
         model_endpoint_service: ModelEndpointService,
         docker_repository: DockerRepository,
         llm_artifact_gateway: LLMArtifactGateway,
+        model_weights_manager=None,
     ):
         self.authz_module = LiveAuthorizationModule()
         self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
         self.model_endpoint_service = model_endpoint_service
         self.docker_repository = docker_repository
         self.llm_artifact_gateway = llm_artifact_gateway
+        self.model_weights_manager = model_weights_manager
 
     async def execute(
         self, user: User, request: CreateLLMModelEndpointV1Request
@@ -1387,6 +1390,21 @@ async def execute(
                 "Multinode endpoints are only supported for VLLM models."
             )
 
+        # Resolve checkpoint path: fires background sync and returns expected path immediately
+        checkpoint_path = request.checkpoint_path
+        hf_weights_syncing = False
+        if (
+            checkpoint_path is None
+            and request.source == LLMSource.HUGGING_FACE
+            and self.model_weights_manager is not None
+        ):
+            models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
+            if models_info and models_info.hf_repo:
+                checkpoint_path = self.model_weights_manager.ensure_model_weights_available(
+                    models_info.hf_repo
+                )
+                hf_weights_syncing = True
+
         bundle = await self.create_llm_model_bundle_use_case.execute(
             user,
             endpoint_name=request.name,
@@ -1397,7 +1415,7 @@ async def execute(
             endpoint_type=request.endpoint_type,
             num_shards=request.num_shards,
             quantize=request.quantize,
-            checkpoint_path=request.checkpoint_path,
+            checkpoint_path=checkpoint_path,
             chat_template_override=request.chat_template_override,
             nodes_per_worker=request.nodes_per_worker,
             additional_args=request.model_dump(exclude_none=True),
@@ -1430,8 +1448,9 @@ async def execute(
                 inference_framework_image_tag=request.inference_framework_image_tag,
                 num_shards=request.num_shards,
                 quantize=request.quantize,
-                checkpoint_path=request.checkpoint_path,
+                checkpoint_path=checkpoint_path,
                 chat_template_override=request.chat_template_override,
+                hf_weights_syncing=hf_weights_syncing,
             )
         )
 
diff --git a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py
new file mode 100644
index 000000000..88b3bb4b3
--- /dev/null
+++ b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py
@@ -0,0 +1,100 @@
+import asyncio
+import functools
+import tempfile
+from typing import Dict, List, Set
+
+from huggingface_hub import snapshot_download
+from model_engine_server.common.config import hmi_config
+from model_engine_server.core.loggers import logger_name, make_logger
+from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
+
+logger = make_logger(logger_name())
+
+# Match the internal sync_model_weights.py inclusion/exclusion patterns
+HF_IGNORE_PATTERNS: List[str] = [
+    "optimizer*",
+    "*.msgpack",
+    "*.h5",
+    "flax_model*",
+    "tf_model*",
+    "rust_model*",
+]
+
+
+class ModelWeightsManager:
+    def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
+        self.llm_artifact_gateway = llm_artifact_gateway
+        self._background_tasks: Set[asyncio.Task] = set()
+        self._in_progress: Dict[str, asyncio.Task] = {}
+
+    def get_remote_path(self, hf_repo: str) -> str:
+        prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
+        return f"{prefix}/{hf_repo}"
+
+    def ensure_model_weights_available(self, hf_repo: str) -> str:
+        """
+        Returns the expected remote path for ``hf_repo`` immediately and starts
+        syncing weights from HuggingFace Hub to that path in the background.
+
+        If the weights are already cached the background task exits early.
+        Callers receive the checkpoint path right away and can proceed with
+        any following actions (e.g. endpoint creation) without blocking.
+
+        A second call for the same ``hf_repo`` while a sync is already in
+        progress is a no-op: the existing task is reused and the same remote
+        path is returned.
+
+        Args:
+            hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.
+
+        Returns:
+            The remote path (s3://, gs://, or https://) where the weights will be stored.
+        """
+        remote_path = self.get_remote_path(hf_repo)
+        if hf_repo not in self._in_progress:
+            task = asyncio.create_task(self._sync_weights(hf_repo, remote_path))
+            self._background_tasks.add(task)
+            self._in_progress[hf_repo] = task
+            task.add_done_callback(lambda t: self._on_task_done(t, hf_repo))
+        return remote_path
+
+    def _on_task_done(self, task: asyncio.Task, hf_repo: str) -> None:
+        self._background_tasks.discard(task)
+        self._in_progress.pop(hf_repo, None)
+        if not task.cancelled():
+            exc = task.exception()
+            if exc:
+                logger.error(
+                    f"Background weight sync failed for {hf_repo}: {exc}",
+                    exc_info=exc,
+                )
+
+    async def _sync_weights(self, hf_repo: str, remote_path: str) -> None:
+        """Downloads weights from HuggingFace Hub and uploads to remote storage if not cached."""
+        files = self.llm_artifact_gateway.list_files(remote_path)
+        if files:
+            logger.info(f"Cache hit: {len(files)} files at {remote_path}")
+            return
+
+        logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
+        loop = asyncio.get_running_loop()
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            await loop.run_in_executor(
+                None,
+                functools.partial(
+                    snapshot_download,
+                    repo_id=hf_repo,
+                    local_dir=tmp_dir,
+                    ignore_patterns=HF_IGNORE_PATTERNS,
+                ),
+            )
+            await loop.run_in_executor(
+                None,
+                functools.partial(
+                    self.llm_artifact_gateway.upload_files,
+                    tmp_dir,
+                    remote_path,
+                ),
+            )
+
+        logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py
index a12361383..df574841e 100644
--- a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py
@@ -59,6 +59,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
             downloaded_files.append(local_path)
         return downloaded_files
 
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        parsed = parse_attachment_url(remote_path, clean_key=False)
+        container_client = _get_abs_container_client(parsed.bucket)
+        for root, _, files in os.walk(local_path):
+            for file in files:
+                local_file = os.path.join(root, file)
+                blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
+                with open(local_file, "rb") as f:
+                    container_client.upload_blob(name=blob_name, data=f, overwrite=True)
+
     def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
         parsed_remote = parse_attachment_url(
             hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
diff --git a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
index e8d78e5e6..3d726491d 100644
--- a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py
@@ -52,6 +52,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
             downloaded_files.append(local_path)
         return downloaded_files
 
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        parsed = parse_attachment_url(remote_path, clean_key=False)
+        client = get_gcs_sync_client()
+        bucket = client.bucket(parsed.bucket)
+        for root, _, files in os.walk(local_path):
+            for file in files:
+                local_file = os.path.join(root, file)
+                blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
+                bucket.blob(blob_name).upload_from_filename(local_file)
+
     def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
         parsed_remote = parse_attachment_url(
             hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
index 9c6de78c1..2f922fb49 100644
--- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
+++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
@@ -62,6 +62,28 @@
 BASE_PATH_IN_ENDPOINT = "/app"
 
 DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"}
+
+# Key under which LLM metadata is stored in model_endpoint_record.metadata
+_LLM_METADATA_KEY = "_llm"
+
+# Python script run by the init container to poll storage until HF weights are present.
+_HF_WEIGHTS_POLL_SCRIPT = """\
+import boto3, os, sys, time
+from urllib.parse import urlparse
+
+cp = os.environ["CHECKPOINT_PATH"]
+url = urlparse(cp)
+bucket = url.netloc
+prefix = url.path.lstrip("/")
+s3 = boto3.client("s3")
+while True:
+    resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
+    if resp.get("Contents"):
+        print(f"Model weights ready at {cp}", flush=True)
+        sys.exit(0)
+    print(f"Waiting for model weights at {cp}...", flush=True)
+    time.sleep(30)
+"""
 LWS_DEFAULT_ENV_VAR = {
     "K8S_OWN_POD_NAME",
     "K8S_OWN_NAMESPACE",
@@ -339,6 +361,42 @@ def add_pod_metadata_env_to_container(container: Dict[str, Any]) -> None:
     )
 
 
+def add_hf_weights_init_container(
+    deployment_template: Dict[str, Any],
+    checkpoint_path: str,
+) -> None:
+    """Prepend an init container that polls storage until HF weights are present.
+
+    Uses the forwarder image (model-engine gateway image, which has Python and
+    boto3) so no additional image pull is required. Authentication relies on
+    the pod's service account (IRSA / workload-identity).
+    """
+    containers = deployment_template["spec"]["template"]["spec"]["containers"]
+    # Prefer the forwarder container image; fall back to the first container.
+    forwarder_image = next(
+        (c["image"] for c in containers if c["name"] in ("http-forwarder", "celery-forwarder")),
+        containers[0]["image"],
+    )
+
+    init_container: Dict[str, Any] = {
+        "name": "wait-for-model-weights",
+        "image": forwarder_image,
+        "env": [{"name": "CHECKPOINT_PATH", "value": checkpoint_path}],
+        "command": ["python3", "-c", _HF_WEIGHTS_POLL_SCRIPT],
+    }
+
+    # Reuse the AWS config volume mount if the volume is present in the pod spec
+    volumes = deployment_template["spec"]["template"]["spec"].get("volumes", [])
+    if any(v["name"] == "config-volume" for v in volumes):
+        init_container["volumeMounts"] = [
+            {"name": "config-volume", "mountPath": "/opt/.aws/config", "subPath": "config"}
+        ]
+
+    deployment_template["spec"]["template"]["spec"].setdefault("initContainers", []).append(
+        init_container
+    )
+
+
 def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None:
     container_envs = []
     container_envs.extend(
@@ -1657,6 +1715,9 @@ async def _create_or_update_resources(
             user_container = get_main_container_from_deployment_template(deployment_template)
             add_datadog_env_to_container(deployment_template, user_container)
             add_pod_metadata_env_to_container(user_container)
+            llm_metadata = (model_endpoint_record.metadata or {}).get(_LLM_METADATA_KEY, {})
+            if llm_metadata.get("hf_weights_syncing") and llm_metadata.get("checkpoint_path"):
+                add_hf_weights_init_container(deployment_template, llm_metadata["checkpoint_path"])
             await self._create_deployment(
                 model_endpoint_record=request.build_endpoint_request.model_endpoint_record,
                 deployment=deployment_template,
diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py
index 7b4219787..cfa80cf64 100644
--- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py
+++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py
@@ -56,6 +56,17 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
         logger.info(f"Downloaded {len(downloaded_files)} files to {target_path}")
         return downloaded_files
 
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        s3 = get_s3_resource(kwargs)
+        parsed = parse_attachment_url(remote_path, clean_key=False)
+        bucket = s3.Bucket(parsed.bucket)
+        for root, _, files in os.walk(local_path):
+            for file in files:
+                local_file = os.path.join(root, file)
+                s3_key = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
+                logger.info(f"Uploading {local_file} → s3://{parsed.bucket}/{s3_key}")
+                bucket.upload_file(local_file, s3_key)
+
     def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
         s3 = get_s3_resource(kwargs)
         parsed_remote = parse_attachment_url(
diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py
index 54e6436c1..ad78584ec 100644
--- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py
+++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py
@@ -3,7 +3,11 @@
 from typing import Dict, NamedTuple, Optional
 
 from huggingface_hub import list_repo_refs
-from huggingface_hub.utils._errors import RepositoryNotFoundError
+
+try:
+    from huggingface_hub.utils._errors import RepositoryNotFoundError
+except ImportError:
+    from huggingface_hub.errors import RepositoryNotFoundError
 from model_engine_server.core.loggers import logger_name, make_logger
 from model_engine_server.domain.exceptions import ObjectNotFoundException
 from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py
index dce1b85ae..04dd9cfe1 100644
--- a/model-engine/tests/unit/conftest.py
+++ b/model-engine/tests/unit/conftest.py
@@ -862,6 +862,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
         if path in self.s3_bucket:
             return self.s3_bucket[path]
 
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        pass
+
     def get_model_weights_urls(self, owner: str, model_name: str):
         if (owner, model_name) in self.existing_models:
             return self.urls
diff --git a/model-engine/tests/unit/domain/test_model_weights_manager.py b/model-engine/tests/unit/domain/test_model_weights_manager.py
new file mode 100644
index 000000000..78f8054c2
--- /dev/null
+++ b/model-engine/tests/unit/domain/test_model_weights_manager.py
@@ -0,0 +1,275 @@
+"""Unit tests for ModelWeightsManager."""
+
+from typing import Any, Dict, List, Optional
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway
+from model_engine_server.domain.use_cases.model_weights_manager import (
+    HF_IGNORE_PATTERNS,
+    ModelWeightsManager,
+)
+
+
+class FakeArtifactGateway(LLMArtifactGateway):
+    """Minimal fake gateway for testing."""
+
+    def __init__(self, existing_files: Optional[List[str]] = None, uploaded: Optional[List[tuple]] = None):
+        self._existing_files = existing_files if existing_files is not None else []
+        self.uploaded: List[tuple] = uploaded if uploaded is not None else []
+
+    def list_files(self, path: str, **kwargs) -> List[str]:
+        return self._existing_files
+
+    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
+        self.uploaded.append((local_path, remote_path))
+
+    def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
+        return []
+
+    def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
+        return []
+
+    def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
+        return {}
+
+
+@pytest.mark.asyncio
+async def test_cache_hit_skips_download():
+    """When list_files returns non-empty, no download or upload should occur."""
+    gateway = FakeArtifactGateway(existing_files=["model.safetensors"])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    with (
+        patch(f"{mwm_base}.snapshot_download") as mock_download,
+        patch(f"{mwm_base}.asyncio.create_task") as mock_create_task,
+    ):
+        result = manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B")
+        # Run the background sync task to assert on side-effects
+        await mock_create_task.call_args[0][0]
+
+    mock_download.assert_not_called()
+    assert len(gateway.uploaded) == 0
+    assert "meta-llama/Meta-Llama-3-8B" in result
+
+
+@pytest.mark.asyncio
+async def test_cache_hit_returns_correct_s3_path(monkeypatch):
+    """On cache hit the returned path should be {prefix}/{hf_repo}."""
+    monkeypatch.setattr(
+        "model_engine_server.domain.use_cases.model_weights_manager.hmi_config",
+        MagicMock(hf_user_fine_tuned_weights_prefix="s3://my-bucket/weights"),
+    )
+    gateway = FakeArtifactGateway(existing_files=["file.bin"])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task:
+        result = manager.ensure_model_weights_available("org/model")
+        await mock_create_task.call_args[0][0]
+
+    assert result == "s3://my-bucket/weights/org/model"
+
+
+@pytest.mark.asyncio
+async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypatch):
+    """On cache miss, snapshot_download and upload_files should both be called."""
+    monkeypatch.setattr(
+        "model_engine_server.domain.use_cases.model_weights_manager.hmi_config",
+        MagicMock(hf_user_fine_tuned_weights_prefix="s3://my-bucket/weights"),
+    )
+
+    gateway = FakeArtifactGateway(existing_files=[])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    with (
+        patch(f"{mwm_base}.snapshot_download") as mock_download,
+        patch(f"{mwm_base}.asyncio.create_task") as mock_create_task,
+    ):
+        result = manager.ensure_model_weights_available("org/model")
+        # Run the background sync task so we can assert on its side-effects
+        await mock_create_task.call_args[0][0]
+
+    mock_download.assert_called_once()
+    call_kwargs = mock_download.call_args
+    assert call_kwargs.kwargs["repo_id"] == "org/model"
+    assert call_kwargs.kwargs["ignore_patterns"] == HF_IGNORE_PATTERNS
+
+    assert len(gateway.uploaded) == 1
+    _local, remote = gateway.uploaded[0]
+    assert remote == "s3://my-bucket/weights/org/model"
+    assert result == "s3://my-bucket/weights/org/model"
+
+
+def test_s3_path_construction(monkeypatch):
+    """Remote path should be {prefix}/{hf_repo} with correct stripping of trailing slash."""
+    monkeypatch.setattr(
+        "model_engine_server.domain.use_cases.model_weights_manager.hmi_config",
+        MagicMock(hf_user_fine_tuned_weights_prefix="s3://bucket/prefix/"),
+    )
+    gateway = FakeArtifactGateway(existing_files=[])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    path = manager.get_remote_path("myorg/mymodel")
+    assert path == "s3://bucket/prefix/myorg/mymodel"
+
+
+def test_deduplication_same_hf_repo():
+    """Second call for same hf_repo while a sync is in progress should not create a new task."""
+    gateway = FakeArtifactGateway(existing_files=[])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task:
+        result1 = manager.ensure_model_weights_available("org/model")
+        result2 = manager.ensure_model_weights_available("org/model")
+
+    assert mock_create_task.call_count == 1
+    assert result1 == result2
+
+
+def test_task_reference_held_until_done():
+    """_background_tasks should hold a reference to the task until _on_task_done fires."""
+    gateway = FakeArtifactGateway(existing_files=[])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    mock_task = MagicMock()
+    with patch(f"{mwm_base}.asyncio.create_task", return_value=mock_task):
+        manager.ensure_model_weights_available("org/model")
+
+    assert mock_task in manager._background_tasks
+    assert "org/model" in manager._in_progress
+
+    # Simulate successful task completion via the done callback
+    mock_task.cancelled.return_value = False
+    mock_task.exception.return_value = None
+    manager._on_task_done(mock_task, "org/model")
+
+    assert mock_task not in manager._background_tasks
+    assert "org/model" not in manager._in_progress
+
+
+def test_error_surfaced_on_task_failure():
+    """When the background task raises, _on_task_done should log the error."""
+    gateway = FakeArtifactGateway(existing_files=[])
+    manager = ModelWeightsManager(llm_artifact_gateway=gateway)
+
+    mock_task = MagicMock()
+    mock_task.cancelled.return_value = False
+    exc = RuntimeError("Download failed")
+    mock_task.exception.return_value = exc
+
+    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
+    with patch(f"{mwm_base}.logger") as mock_logger:
+        manager._on_task_done(mock_task, "org/model")
+        mock_logger.error.assert_called_once()
+        call_args = mock_logger.error.call_args
+        assert "org/model" in call_args[0][0]
+        assert call_args[1]["exc_info"] == exc
+
+
+@pytest.mark.asyncio
+async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source():
+    """CreateLLMModelEndpointV1UseCase should call ensure_model_weights_available (sync),
+    which returns the expected checkpoint path immediately and fires weight sync in the
+    background. All following actions (bundle, endpoint creation) proceed without blocking."""
+    from model_engine_server.domain.entities import LLMSource
+    from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
+
+    mock_manager = MagicMock(spec=ModelWeightsManager)
+    mock_manager.ensure_model_weights_available.return_value = (
+        "s3://bucket/weights/huggyllama/llama-7b"
+    )
+
+    # Use a real SUPPORTED_MODELS_INFO entry: "llama-2-7b" -> "huggyllama/llama-7b"
+    from tests.unit.conftest import FakeLLMArtifactGateway
+
+    fake_gateway = FakeLLMArtifactGateway()
+
+    from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import (
+        CreateLLMModelEndpointV1UseCase,
+    )
+
+    mock_bundle_use_case = MagicMock()
+    mock_bundle = MagicMock()
+    mock_bundle.id = "bundle-id"
+    mock_bundle_use_case.execute = AsyncMock(return_value=mock_bundle)
+
+    mock_endpoint_service = MagicMock()
+    mock_endpoint_record = MagicMock()
+    mock_endpoint_record.id = "endpoint-id"
+    mock_endpoint_record.creation_task_id = "task-123"
+    mock_endpoint_service.create_model_endpoint = AsyncMock(return_value=mock_endpoint_record)
+    mock_endpoint_service.can_scale_http_endpoint_from_zero = MagicMock(return_value=False)
+    mock_endpoint_service.get_inference_autoscaling_metrics_gateway.return_value.emit_prewarm_metric = (
+        AsyncMock()
+    )
+
+    mock_docker_repository = MagicMock()
+
+    use_case = CreateLLMModelEndpointV1UseCase(
+        create_llm_model_bundle_use_case=mock_bundle_use_case,
+        model_endpoint_service=mock_endpoint_service,
+        docker_repository=mock_docker_repository,
+        llm_artifact_gateway=fake_gateway,
+        model_weights_manager=mock_manager,
+    )
+
+    from model_engine_server.common.dtos.llms import CreateLLMModelEndpointV1Request
+    from model_engine_server.core.auth.authentication_repository import User
+    from pydantic import TypeAdapter
+
+    user = User(user_id="test-user", team_id="test-team", is_privileged_user=True)
+    request = TypeAdapter(CreateLLMModelEndpointV1Request).validate_python(
+        {
+            "name": "test-endpoint",
+            "model_name": "llama-2-7b",
+            "source": LLMSource.HUGGING_FACE,
+            "inference_framework": "vllm",
+            "inference_framework_image_tag": "0.1.0",
+            "num_shards": 1,
+            "endpoint_type": "streaming",
+            "checkpoint_path": None,
+            "min_workers": 1,
+            "max_workers": 1,
+            "per_worker": 10,
+            "cpus": 4,
+            "memory": "16Gi",
+            "storage": "50Gi",
+            "gpus": 1,
+            "gpu_type": "nvidia-ampere-a10",
+            "nodes_per_worker": 1,
+            "labels": {"team": "test"},
+            "metadata": {},
+        }
+    )
+
+    # Patch infrastructure helpers to keep the test focused on weights manager behavior
+    base = "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases"
+    with (
+        patch(f"{base}._fill_hardware_info"),
+        patch(f"{base}.validate_resource_requests"),
+        patch(f"{base}.validate_deployment_resources"),
+        patch(f"{base}.validate_labels"),
+        patch(f"{base}.validate_billing_tags"),
+        patch(f"{base}.validate_post_inference_hooks"),
+        patch(f"{base}.validate_model_name"),
+        patch(f"{base}.validate_num_shards"),
+        patch(f"{base}.validate_quantization"),
+        patch(f"{base}.validate_chat_template"),
+        patch(f"{base}.LiveAuthorizationModule") as mock_authz,
+    ):
+        mock_authz.return_value.get_aws_role_for_user = MagicMock(
+            return_value="arn:aws:iam::123:role/test"
+        )
+        mock_authz.return_value.get_s3_bucket_for_user = MagicMock(return_value="test-bucket")
+        await use_case.execute(user=user, request=request)
+
+    # ensure_model_weights_available is called synchronously — no await, no blocking
+    mock_manager.ensure_model_weights_available.assert_called_once_with("huggyllama/llama-7b")
+    # Verify that the resolved checkpoint path was forwarded to the bundle use case
+    bundle_call_kwargs = mock_bundle_use_case.execute.call_args.kwargs
+    assert bundle_call_kwargs["checkpoint_path"] == "s3://bucket/weights/huggyllama/llama-7b"