From 3e7a2a4d505ad3d5c2609bf7bee83c8259b9ea06 Mon Sep 17 00:00:00 2001 From: lilyz-ai Date: Fri, 20 Feb 2026 01:50:37 +0000 Subject: [PATCH 1/4] feat: add ModelWeightsManager to auto-sync HF weights on endpoint creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a model endpoint is created via POST /v1/llm/model-endpoints with source=HUGGING_FACE and no checkpoint_path, ModelWeightsManager now automatically checks the configured S3/GCS/ABS prefix for cached weights and downloads from HuggingFace Hub + uploads if missing — eliminating the manual sync_model_weights.py step. - Add ModelWeightsManager with ensure_model_weights_available() (async-safe via run_in_executor, cache-hit skips all I/O) - Add upload_files() abstract method to LLMArtifactGateway with implementations for S3, GCS, and ABS - Wire ModelWeightsManager into CreateLLMModelEndpointV1UseCase and the create_model_endpoint API handler - Fix huggingface_hub.utils._errors import for hub>=0.36 compatibility - Add unit tests covering cache hit/miss, path construction, and end-to-end integration with CreateLLMModelEndpointV1UseCase Co-Authored-By: Claude Sonnet 4.6 --- .../model_engine_server/api/llms_v1.py | 5 + .../domain/gateways/llm_artifact_gateway.py | 11 + .../use_cases/llm_model_endpoint_use_cases.py | 19 +- .../domain/use_cases/model_weights_manager.py | 73 ++++++ .../gateways/abs_llm_artifact_gateway.py | 10 + .../gateways/gcs_llm_artifact_gateway.py | 10 + .../infra/gateways/s3_llm_artifact_gateway.py | 11 + .../repositories/live_tokenizer_repository.py | 6 +- model-engine/tests/unit/conftest.py | 3 + .../unit/domain/test_model_weights_manager.py | 212 ++++++++++++++++++ 10 files changed, 357 insertions(+), 3 deletions(-) create mode 100644 model-engine/model_engine_server/domain/use_cases/model_weights_manager.py create mode 100644 model-engine/tests/unit/domain/test_model_weights_manager.py diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 35f55fa5e..6c5c8c244 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -86,6 +86,7 @@ UpdateLLMModelEndpointV1UseCase, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager from pydantic import RootModel from sse_starlette.sse import EventSourceResponse @@ -168,11 +169,15 @@ async def create_model_endpoint( llm_artifact_gateway=external_interfaces.llm_artifact_gateway, docker_repository=external_interfaces.docker_repository, ) + model_weights_manager = ModelWeightsManager( + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, docker_repository=external_interfaces.docker_repository, llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + model_weights_manager=model_weights_manager, ) return await use_case.execute(user=auth, request=request) except ObjectAlreadyExistsException as exc: diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py index 8f8ece698..9f12cf787 100644 --- a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -40,6 +40,17 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[ """ pass + @abstractmethod + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + """ + Upload all files from a local directory to a remote path. + + Args: + local_path (str): local directory containing files to upload + remote_path (str): remote destination path (s3://, gs://, or https://) + """ + pass + @abstractmethod def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: """ diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 50b000488..0804521d4 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -1322,12 +1322,14 @@ def __init__( model_endpoint_service: ModelEndpointService, docker_repository: DockerRepository, llm_artifact_gateway: LLMArtifactGateway, + model_weights_manager=None, ): self.authz_module = LiveAuthorizationModule() self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case self.model_endpoint_service = model_endpoint_service self.docker_repository = docker_repository self.llm_artifact_gateway = llm_artifact_gateway + self.model_weights_manager = model_weights_manager async def execute( self, user: User, request: CreateLLMModelEndpointV1Request @@ -1387,6 +1389,19 @@ async def execute( "Multinode endpoints are only supported for VLLM models." ) + # Resolve checkpoint path: auto-download from HF Hub to remote storage if not cached + checkpoint_path = request.checkpoint_path + if ( + checkpoint_path is None + and request.source == LLMSource.HUGGING_FACE + and self.model_weights_manager is not None + ): + models_info = SUPPORTED_MODELS_INFO.get(request.model_name) + if models_info and models_info.hf_repo: + checkpoint_path = await self.model_weights_manager.ensure_model_weights_available( + hf_repo=models_info.hf_repo + ) + bundle = await self.create_llm_model_bundle_use_case.execute( user, endpoint_name=request.name, @@ -1397,7 +1412,7 @@ async def execute( endpoint_type=request.endpoint_type, num_shards=request.num_shards, quantize=request.quantize, - checkpoint_path=request.checkpoint_path, + checkpoint_path=checkpoint_path, chat_template_override=request.chat_template_override, nodes_per_worker=request.nodes_per_worker, additional_args=request.model_dump(exclude_none=True), @@ -1430,7 +1445,7 @@ async def execute( inference_framework_image_tag=request.inference_framework_image_tag, num_shards=request.num_shards, quantize=request.quantize, - checkpoint_path=request.checkpoint_path, + checkpoint_path=checkpoint_path, chat_template_override=request.chat_template_override, ) ) diff --git a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py new file mode 100644 index 000000000..aefd282b3 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py @@ -0,0 +1,73 @@ +import asyncio +import functools +import tempfile +from typing import List + +from huggingface_hub import snapshot_download +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway + +logger = make_logger(logger_name()) + +# Match the internal sync_model_weights.py inclusion/exclusion patterns +HF_IGNORE_PATTERNS: List[str] = [ + "optimizer*", + "*.msgpack", + "*.h5", + "flax_model*", + "tf_model*", + "rust_model*", +] + + +class ModelWeightsManager: + def __init__(self, llm_artifact_gateway: LLMArtifactGateway): + self.llm_artifact_gateway = llm_artifact_gateway + + def _get_remote_path(self, hf_repo: str) -> str: + prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/") + return f"{prefix}/{hf_repo}" + + async def ensure_model_weights_available(self, hf_repo: str) -> str: + """ + Ensures model weights for ``hf_repo`` are available at the configured remote path. + + If the weights are already cached (remote path is non-empty), returns immediately. + Otherwise downloads from HuggingFace Hub and uploads to the remote path. + + Args: + hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``. + + Returns: + The remote path (s3://, gs://, or https://) where the weights are stored. + """ + remote_path = self._get_remote_path(hf_repo) + files = self.llm_artifact_gateway.list_files(remote_path) + if files: + logger.info(f"Cache hit: {len(files)} files at {remote_path}") + return remote_path + + logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...") + loop = asyncio.get_event_loop() + with tempfile.TemporaryDirectory() as tmp_dir: + await loop.run_in_executor( + None, + functools.partial( + snapshot_download, + repo_id=hf_repo, + local_dir=tmp_dir, + ignore_patterns=HF_IGNORE_PATTERNS, + ), + ) + await loop.run_in_executor( + None, + functools.partial( + self.llm_artifact_gateway.upload_files, + tmp_dir, + remote_path, + ), + ) + + logger.info(f"Weights for {hf_repo} uploaded to {remote_path}") + return remote_path diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py index a12361383..df574841e 100644 --- a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py @@ -59,6 +59,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) downloaded_files.append(local_path) return downloaded_files + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + parsed = parse_attachment_url(remote_path, clean_key=False) + container_client = _get_abs_container_client(parsed.bucket) + for root, _, files in os.walk(local_path): + for file in files: + local_file = os.path.join(root, file) + blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path)) + with open(local_file, "rb") as f: + container_client.upload_blob(name=blob_name, data=f, overwrite=True) + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: parsed_remote = parse_attachment_url( hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False diff --git a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py index e8d78e5e6..3d726491d 100644 --- a/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/gcs_llm_artifact_gateway.py @@ -52,6 +52,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) downloaded_files.append(local_path) return downloaded_files + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + parsed = parse_attachment_url(remote_path, clean_key=False) + client = get_gcs_sync_client() + bucket = client.bucket(parsed.bucket) + for root, _, files in os.walk(local_path): + for file in files: + local_file = os.path.join(root, file) + blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path)) + bucket.blob(blob_name).upload_from_filename(local_file) + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: parsed_remote = parse_attachment_url( hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py index 7b4219787..cfa80cf64 100644 --- a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -56,6 +56,17 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) logger.info(f"Downloaded {len(downloaded_files)} files to {target_path}") return downloaded_files + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + s3 = get_s3_resource(kwargs) + parsed = parse_attachment_url(remote_path, clean_key=False) + bucket = s3.Bucket(parsed.bucket) + for root, _, files in os.walk(local_path): + for file in files: + local_file = os.path.join(root, file) + s3_key = os.path.join(parsed.key, os.path.relpath(local_file, local_path)) + logger.info(f"Uploading {local_file} → s3://{parsed.bucket}/{s3_key}") + bucket.upload_file(local_file, s3_key) + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: s3 = get_s3_resource(kwargs) parsed_remote = parse_attachment_url( diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py index 54e6436c1..ad78584ec 100644 --- a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -3,7 +3,11 @@ from typing import Dict, NamedTuple, Optional from huggingface_hub import list_repo_refs -from huggingface_hub.utils._errors import RepositoryNotFoundError + +try: + from huggingface_hub.utils._errors import RepositoryNotFoundError +except ImportError: + from huggingface_hub.errors import RepositoryNotFoundError from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.exceptions import ObjectNotFoundException from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index dce1b85ae..04dd9cfe1 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -862,6 +862,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) if path in self.s3_bucket: return self.s3_bucket[path] + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + pass + def get_model_weights_urls(self, owner: str, model_name: str): if (owner, model_name) in self.existing_models: return self.urls diff --git a/model-engine/tests/unit/domain/test_model_weights_manager.py b/model-engine/tests/unit/domain/test_model_weights_manager.py new file mode 100644 index 000000000..1f44a4654 --- /dev/null +++ b/model-engine/tests/unit/domain/test_model_weights_manager.py @@ -0,0 +1,212 @@ +"""Unit tests for ModelWeightsManager.""" + +from typing import Any, Dict, List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway +from model_engine_server.domain.use_cases.model_weights_manager import ( + HF_IGNORE_PATTERNS, + ModelWeightsManager, +) + + +class FakeArtifactGateway(LLMArtifactGateway): + """Minimal fake gateway for testing.""" + + def __init__(self, existing_files: List[str] = None, uploaded: List[tuple] = None): + self._existing_files = existing_files if existing_files is not None else [] + self.uploaded: List[tuple] = uploaded if uploaded is not None else [] + + def list_files(self, path: str, **kwargs) -> List[str]: + return self._existing_files + + def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None: + self.uploaded.append((local_path, remote_path)) + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + return [] + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + return [] + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + return {} + + +@pytest.mark.asyncio +async def test_cache_hit_skips_download(): + """When list_files returns non-empty, no download or upload should occur.""" + gateway = FakeArtifactGateway(existing_files=["model.safetensors"]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + with patch( + "model_engine_server.domain.use_cases.model_weights_manager.snapshot_download" + ) as mock_download: + result = await manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B") + + mock_download.assert_not_called() + assert len(gateway.uploaded) == 0 + assert "meta-llama/Meta-Llama-3-8B" in result + + +@pytest.mark.asyncio +async def test_cache_hit_returns_correct_s3_path(monkeypatch): + """On cache hit the returned path should be {prefix}/{hf_repo}.""" + monkeypatch.setattr( + "model_engine_server.domain.use_cases.model_weights_manager.hmi_config", + MagicMock(hf_user_fine_tuned_weights_prefix="s3://my-bucket/weights"), + ) + gateway = FakeArtifactGateway(existing_files=["file.bin"]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + with patch("model_engine_server.domain.use_cases.model_weights_manager.snapshot_download"): + result = await manager.ensure_model_weights_available("org/model") + + assert result == "s3://my-bucket/weights/org/model" + + +@pytest.mark.asyncio +async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypatch): + """On cache miss, snapshot_download and upload_files should both be called.""" + monkeypatch.setattr( + "model_engine_server.domain.use_cases.model_weights_manager.hmi_config", + MagicMock(hf_user_fine_tuned_weights_prefix="s3://my-bucket/weights"), + ) + + gateway = FakeArtifactGateway(existing_files=[]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + with patch( + "model_engine_server.domain.use_cases.model_weights_manager.snapshot_download" + ) as mock_download: + result = await manager.ensure_model_weights_available("org/model") + + mock_download.assert_called_once() + call_kwargs = mock_download.call_args + assert call_kwargs.kwargs["repo_id"] == "org/model" + assert call_kwargs.kwargs["ignore_patterns"] == HF_IGNORE_PATTERNS + + assert len(gateway.uploaded) == 1 + _local, remote = gateway.uploaded[0] + assert remote == "s3://my-bucket/weights/org/model" + assert result == "s3://my-bucket/weights/org/model" + + +@pytest.mark.asyncio +async def test_s3_path_construction(monkeypatch): + """Remote path should be {prefix}/{hf_repo} with correct stripping of trailing slash.""" + monkeypatch.setattr( + "model_engine_server.domain.use_cases.model_weights_manager.hmi_config", + MagicMock(hf_user_fine_tuned_weights_prefix="s3://bucket/prefix/"), + ) + gateway = FakeArtifactGateway(existing_files=[]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + path = manager._get_remote_path("myorg/mymodel") + assert path == "s3://bucket/prefix/myorg/mymodel" + + +@pytest.mark.asyncio +async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source(): + """CreateLLMModelEndpointV1UseCase should call model_weights_manager when source is HF and checkpoint_path is None.""" + from model_engine_server.domain.entities import LLMSource + from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager + + mock_manager = MagicMock(spec=ModelWeightsManager) + mock_manager.ensure_model_weights_available = AsyncMock( + return_value="s3://bucket/weights/huggyllama/llama-7b" + ) + + # Use a real SUPPORTED_MODELS_INFO entry: "llama-2-7b" -> "huggyllama/llama-7b" + from tests.unit.conftest import FakeLLMArtifactGateway + + fake_gateway = FakeLLMArtifactGateway() + # Ensure the resolved checkpoint path is found in the fake bucket + fake_gateway.s3_bucket["s3://bucket/weights/huggyllama/llama-7b"] = ["model.safetensors"] + + from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CreateLLMModelEndpointV1UseCase, + ) + + mock_bundle_use_case = MagicMock() + mock_bundle = MagicMock() + mock_bundle.id = "bundle-id" + mock_bundle_use_case.execute = AsyncMock(return_value=mock_bundle) + + mock_endpoint_service = MagicMock() + mock_endpoint_record = MagicMock() + mock_endpoint_record.id = "endpoint-id" + mock_endpoint_record.creation_task_id = "task-123" + mock_endpoint_service.create_model_endpoint = AsyncMock(return_value=mock_endpoint_record) + mock_endpoint_service.can_scale_http_endpoint_from_zero = MagicMock(return_value=False) + mock_endpoint_service.get_inference_autoscaling_metrics_gateway.return_value.emit_prewarm_metric = ( + AsyncMock() + ) + + mock_docker_repository = MagicMock() + + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=mock_bundle_use_case, + model_endpoint_service=mock_endpoint_service, + docker_repository=mock_docker_repository, + llm_artifact_gateway=fake_gateway, + model_weights_manager=mock_manager, + ) + + from model_engine_server.common.dtos.llms import CreateLLMModelEndpointV1Request + from model_engine_server.core.auth.authentication_repository import User + from pydantic import TypeAdapter + + user = User(user_id="test-user", team_id="test-team", is_privileged_user=True) + request = TypeAdapter(CreateLLMModelEndpointV1Request).validate_python( + { + "name": "test-endpoint", + "model_name": "llama-2-7b", + "source": LLMSource.HUGGING_FACE, + "inference_framework": "vllm", + "inference_framework_image_tag": "0.1.0", + "num_shards": 1, + "endpoint_type": "streaming", + "checkpoint_path": None, + "min_workers": 1, + "max_workers": 1, + "per_worker": 10, + "cpus": 4, + "memory": "16Gi", + "storage": "50Gi", + "gpus": 1, + "gpu_type": "nvidia-ampere-a10", + "nodes_per_worker": 1, + "labels": {"team": "test"}, + "metadata": {}, + } + ) + + # Patch infrastructure helpers to keep the test focused on weights manager behavior + base = "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases" + with ( + patch(f"{base}._fill_hardware_info"), + patch(f"{base}.validate_resource_requests"), + patch(f"{base}.validate_deployment_resources"), + patch(f"{base}.validate_labels"), + patch(f"{base}.validate_billing_tags"), + patch(f"{base}.validate_post_inference_hooks"), + patch(f"{base}.validate_model_name"), + patch(f"{base}.validate_num_shards"), + patch(f"{base}.validate_quantization"), + patch(f"{base}.validate_chat_template"), + patch(f"{base}.LiveAuthorizationModule") as mock_authz, + ): + mock_authz.return_value.get_aws_role_for_user = MagicMock( + return_value="arn:aws:iam::123:role/test" + ) + mock_authz.return_value.get_s3_bucket_for_user = MagicMock(return_value="test-bucket") + await use_case.execute(user=user, request=request) + + mock_manager.ensure_model_weights_available.assert_called_once_with( + hf_repo="huggyllama/llama-7b" + ) + # Verify that the resolved checkpoint path was forwarded to the bundle use case + bundle_call_kwargs = mock_bundle_use_case.execute.call_args.kwargs + assert bundle_call_kwargs["checkpoint_path"] == "s3://bucket/weights/huggyllama/llama-7b" From b14deb6ebaa554ede6aa3ce57cc0c6c9cd3e3b6a Mon Sep 17 00:00:00 2001 From: lilyz-ai Date: Fri, 20 Feb 2026 05:29:06 +0000 Subject: [PATCH 2/4] feat: make HF weights sync non-blocking with K8s init container MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ensure_model_weights_available is now synchronous — it returns the expected checkpoint path immediately and fires a background asyncio task to sync weights from HuggingFace Hub. An init container is injected into the K8s deployment to poll storage until the weights are present before the main container starts. LLMMetadata gains an hf_weights_syncing flag to signal this flow downstream. Co-Authored-By: Claude Sonnet 4.6 --- .../domain/entities/llm_entity.py | 1 + .../use_cases/llm_model_endpoint_use_cases.py | 12 ++-- .../domain/use_cases/model_weights_manager.py | 24 +++++--- .../k8s_endpoint_resource_delegate.py | 61 +++++++++++++++++++ .../unit/domain/test_model_weights_manager.py | 50 ++++++++------- 5 files changed, 114 insertions(+), 34 deletions(-) diff --git a/model-engine/model_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py index dc6c06090..cf4e4a3e3 100644 --- a/model-engine/model_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -31,3 +31,4 @@ class LLMMetadata: quantize: Optional[Quantization] = None checkpoint_path: Optional[str] = None chat_template_override: Optional[str] = None + hf_weights_syncing: bool = False diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 0804521d4..9bb347df9 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -652,7 +652,8 @@ def load_model_weights_sub_commands_s3( s5cmd = "./s5cmd" checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) - validate_checkpoint_files(checkpoint_files) + if checkpoint_files: + validate_checkpoint_files(checkpoint_files) # filter to configs ('*.model' and '*.json') and weights ('*.safetensors') # For models that are not supported by transformers directly, we need to include '*.py' and '*.bin' @@ -1389,8 +1390,9 @@ async def execute( "Multinode endpoints are only supported for VLLM models." ) - # Resolve checkpoint path: auto-download from HF Hub to remote storage if not cached + # Resolve checkpoint path: fires background sync and returns expected path immediately checkpoint_path = request.checkpoint_path + hf_weights_syncing = False if ( checkpoint_path is None and request.source == LLMSource.HUGGING_FACE @@ -1398,9 +1400,10 @@ async def execute( ): models_info = SUPPORTED_MODELS_INFO.get(request.model_name) if models_info and models_info.hf_repo: - checkpoint_path = await self.model_weights_manager.ensure_model_weights_available( - hf_repo=models_info.hf_repo + checkpoint_path = self.model_weights_manager.ensure_model_weights_available( + models_info.hf_repo ) + hf_weights_syncing = True bundle = await self.create_llm_model_bundle_use_case.execute( user, @@ -1447,6 +1450,7 @@ async def execute( quantize=request.quantize, checkpoint_path=checkpoint_path, chat_template_override=request.chat_template_override, + hf_weights_syncing=hf_weights_syncing, ) ) diff --git a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py index aefd282b3..f5c422df3 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py +++ b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py @@ -25,28 +25,35 @@ class ModelWeightsManager: def __init__(self, llm_artifact_gateway: LLMArtifactGateway): self.llm_artifact_gateway = llm_artifact_gateway - def _get_remote_path(self, hf_repo: str) -> str: + def get_remote_path(self, hf_repo: str) -> str: prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/") return f"{prefix}/{hf_repo}" - async def ensure_model_weights_available(self, hf_repo: str) -> str: + def ensure_model_weights_available(self, hf_repo: str) -> str: """ - Ensures model weights for ``hf_repo`` are available at the configured remote path. + Returns the expected remote path for ``hf_repo`` immediately and starts + syncing weights from HuggingFace Hub to that path in the background. - If the weights are already cached (remote path is non-empty), returns immediately. - Otherwise downloads from HuggingFace Hub and uploads to the remote path. + If the weights are already cached the background task exits early. + Callers receive the checkpoint path right away and can proceed with + any following actions (e.g. endpoint creation) without blocking. Args: hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``. Returns: - The remote path (s3://, gs://, or https://) where the weights are stored. + The remote path (s3://, gs://, or https://) where the weights will be stored. """ - remote_path = self._get_remote_path(hf_repo) + remote_path = self.get_remote_path(hf_repo) + asyncio.create_task(self._sync_weights(hf_repo, remote_path)) + return remote_path + + async def _sync_weights(self, hf_repo: str, remote_path: str) -> None: + """Downloads weights from HuggingFace Hub and uploads to remote storage if not cached.""" files = self.llm_artifact_gateway.list_files(remote_path) if files: logger.info(f"Cache hit: {len(files)} files at {remote_path}") - return remote_path + return logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...") loop = asyncio.get_event_loop() @@ -70,4 +77,3 @@ async def ensure_model_weights_available(self, hf_repo: str) -> str: ) logger.info(f"Weights for {hf_repo} uploaded to {remote_path}") - return remote_path diff --git a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 9c6de78c1..2f922fb49 100644 --- a/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -62,6 +62,28 @@ BASE_PATH_IN_ENDPOINT = "/app" DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} + +# Key under which LLM metadata is stored in model_endpoint_record.metadata +_LLM_METADATA_KEY = "_llm" + +# Python script run by the init container to poll storage until HF weights are present. +_HF_WEIGHTS_POLL_SCRIPT = """\ +import boto3, os, sys, time +from urllib.parse import urlparse + +cp = os.environ["CHECKPOINT_PATH"] +url = urlparse(cp) +bucket = url.netloc +prefix = url.path.lstrip("/") +s3 = boto3.client("s3") +while True: + resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1) + if resp.get("Contents"): + print(f"Model weights ready at {cp}", flush=True) + sys.exit(0) + print(f"Waiting for model weights at {cp}...", flush=True) + time.sleep(30) +""" LWS_DEFAULT_ENV_VAR = { "K8S_OWN_POD_NAME", "K8S_OWN_NAMESPACE", @@ -339,6 +361,42 @@ def add_pod_metadata_env_to_container(container: Dict[str, Any]) -> None: ) +def add_hf_weights_init_container( + deployment_template: Dict[str, Any], + checkpoint_path: str, +) -> None: + """Prepend an init container that polls storage until HF weights are present. + + Uses the forwarder image (model-engine gateway image, which has Python and + boto3) so no additional image pull is required. Authentication relies on + the pod's service account (IRSA / workload-identity). + """ + containers = deployment_template["spec"]["template"]["spec"]["containers"] + # Prefer the forwarder container image; fall back to the first container. + forwarder_image = next( + (c["image"] for c in containers if c["name"] in ("http-forwarder", "celery-forwarder")), + containers[0]["image"], + ) + + init_container: Dict[str, Any] = { + "name": "wait-for-model-weights", + "image": forwarder_image, + "env": [{"name": "CHECKPOINT_PATH", "value": checkpoint_path}], + "command": ["python3", "-c", _HF_WEIGHTS_POLL_SCRIPT], + } + + # Reuse the AWS config volume mount if the volume is present in the pod spec + volumes = deployment_template["spec"]["template"]["spec"].get("volumes", []) + if any(v["name"] == "config-volume" for v in volumes): + init_container["volumeMounts"] = [ + {"name": "config-volume", "mountPath": "/opt/.aws/config", "subPath": "config"} + ] + + deployment_template["spec"]["template"]["spec"].setdefault("initContainers", []).append( + init_container + ) + + def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None: container_envs = [] container_envs.extend( @@ -1657,6 +1715,9 @@ async def _create_or_update_resources( user_container = get_main_container_from_deployment_template(deployment_template) add_datadog_env_to_container(deployment_template, user_container) add_pod_metadata_env_to_container(user_container) + llm_metadata = (model_endpoint_record.metadata or {}).get(_LLM_METADATA_KEY, {}) + if llm_metadata.get("hf_weights_syncing") and llm_metadata.get("checkpoint_path"): + add_hf_weights_init_container(deployment_template, llm_metadata["checkpoint_path"]) await self._create_deployment( model_endpoint_record=request.build_endpoint_request.model_endpoint_record, deployment=deployment_template, diff --git a/model-engine/tests/unit/domain/test_model_weights_manager.py b/model-engine/tests/unit/domain/test_model_weights_manager.py index 1f44a4654..60431d943 100644 --- a/model-engine/tests/unit/domain/test_model_weights_manager.py +++ b/model-engine/tests/unit/domain/test_model_weights_manager.py @@ -40,10 +40,14 @@ async def test_cache_hit_skips_download(): gateway = FakeArtifactGateway(existing_files=["model.safetensors"]) manager = ModelWeightsManager(llm_artifact_gateway=gateway) - with patch( - "model_engine_server.domain.use_cases.model_weights_manager.snapshot_download" - ) as mock_download: - result = await manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B") + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + with ( + patch(f"{mwm_base}.snapshot_download") as mock_download, + patch(f"{mwm_base}.asyncio.create_task") as mock_create_task, + ): + result = manager.ensure_model_weights_available("meta-llama/Meta-Llama-3-8B") + # Run the background sync task to assert on side-effects + await mock_create_task.call_args[0][0] mock_download.assert_not_called() assert len(gateway.uploaded) == 0 @@ -60,8 +64,10 @@ async def test_cache_hit_returns_correct_s3_path(monkeypatch): gateway = FakeArtifactGateway(existing_files=["file.bin"]) manager = ModelWeightsManager(llm_artifact_gateway=gateway) - with patch("model_engine_server.domain.use_cases.model_weights_manager.snapshot_download"): - result = await manager.ensure_model_weights_available("org/model") + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task: + result = manager.ensure_model_weights_available("org/model") + await mock_create_task.call_args[0][0] assert result == "s3://my-bucket/weights/org/model" @@ -77,10 +83,14 @@ async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypat gateway = FakeArtifactGateway(existing_files=[]) manager = ModelWeightsManager(llm_artifact_gateway=gateway) - with patch( - "model_engine_server.domain.use_cases.model_weights_manager.snapshot_download" - ) as mock_download: - result = await manager.ensure_model_weights_available("org/model") + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + with ( + patch(f"{mwm_base}.snapshot_download") as mock_download, + patch(f"{mwm_base}.asyncio.create_task") as mock_create_task, + ): + result = manager.ensure_model_weights_available("org/model") + # Run the background sync task so we can assert on its side-effects + await mock_create_task.call_args[0][0] mock_download.assert_called_once() call_kwargs = mock_download.call_args @@ -93,8 +103,7 @@ async def test_cache_miss_calls_snapshot_download_and_upload(tmp_path, monkeypat assert result == "s3://my-bucket/weights/org/model" -@pytest.mark.asyncio -async def test_s3_path_construction(monkeypatch): +def test_s3_path_construction(monkeypatch): """Remote path should be {prefix}/{hf_repo} with correct stripping of trailing slash.""" monkeypatch.setattr( "model_engine_server.domain.use_cases.model_weights_manager.hmi_config", @@ -103,27 +112,27 @@ async def test_s3_path_construction(monkeypatch): gateway = FakeArtifactGateway(existing_files=[]) manager = ModelWeightsManager(llm_artifact_gateway=gateway) - path = manager._get_remote_path("myorg/mymodel") + path = manager.get_remote_path("myorg/mymodel") assert path == "s3://bucket/prefix/myorg/mymodel" @pytest.mark.asyncio async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source(): - """CreateLLMModelEndpointV1UseCase should call model_weights_manager when source is HF and checkpoint_path is None.""" + """CreateLLMModelEndpointV1UseCase should call ensure_model_weights_available (sync), + which returns the expected checkpoint path immediately and fires weight sync in the + background. All following actions (bundle, endpoint creation) proceed without blocking.""" from model_engine_server.domain.entities import LLMSource from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager mock_manager = MagicMock(spec=ModelWeightsManager) - mock_manager.ensure_model_weights_available = AsyncMock( - return_value="s3://bucket/weights/huggyllama/llama-7b" + mock_manager.ensure_model_weights_available.return_value = ( + "s3://bucket/weights/huggyllama/llama-7b" ) # Use a real SUPPORTED_MODELS_INFO entry: "llama-2-7b" -> "huggyllama/llama-7b" from tests.unit.conftest import FakeLLMArtifactGateway fake_gateway = FakeLLMArtifactGateway() - # Ensure the resolved checkpoint path is found in the fake bucket - fake_gateway.s3_bucket["s3://bucket/weights/huggyllama/llama-7b"] = ["model.safetensors"] from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( CreateLLMModelEndpointV1UseCase, @@ -204,9 +213,8 @@ async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source(): mock_authz.return_value.get_s3_bucket_for_user = MagicMock(return_value="test-bucket") await use_case.execute(user=user, request=request) - mock_manager.ensure_model_weights_available.assert_called_once_with( - hf_repo="huggyllama/llama-7b" - ) + # ensure_model_weights_available is called synchronously — no await, no blocking + mock_manager.ensure_model_weights_available.assert_called_once_with("huggyllama/llama-7b") # Verify that the resolved checkpoint path was forwarded to the bundle use case bundle_call_kwargs = mock_bundle_use_case.execute.call_args.kwargs assert bundle_call_kwargs["checkpoint_path"] == "s3://bucket/weights/huggyllama/llama-7b" From 493d1cbac6e1140e7391045236357ee4a5eafbeb Mon Sep 17 00:00:00 2001 From: lilyz-ai Date: Fri, 20 Feb 2026 06:26:17 +0000 Subject: [PATCH 3/4] fix: harden ModelWeightsManager background task reliability - Hold a strong set reference to each asyncio.Task to prevent GC cancellation - Deduplicate concurrent sync requests for the same hf_repo via _in_progress dict - Surface task exceptions via logger.error in _on_task_done callback - Store ModelWeightsManager as app.state singleton so state persists across requests - Add recover_hf_syncs startup handler to re-trigger syncs after server restart Co-Authored-By: Claude Sonnet 4.6 --- model-engine/model_engine_server/api/app.py | 50 +++++++++++++++++ .../model_engine_server/api/llms_v1.py | 12 ++-- .../domain/use_cases/model_weights_manager.py | 25 ++++++++- .../unit/domain/test_model_weights_manager.py | 55 +++++++++++++++++++ 4 files changed, 133 insertions(+), 9 deletions(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index cac68cda2..f75e97ffc 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -312,6 +312,56 @@ def load_redis(): get_or_create_aioredis_pool() +@app.on_event("startup") +def init_model_weights_manager(): + from model_engine_server.core.config import infra_config + from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager + from model_engine_server.infra.gateways import ( + ABSLLMArtifactGateway, + GCSLLMArtifactGateway, + S3LLMArtifactGateway, + ) + + provider = infra_config().cloud_provider + if provider == "azure": + gateway = ABSLLMArtifactGateway() + elif provider == "gcp": + gateway = GCSLLMArtifactGateway() + else: + gateway = S3LLMArtifactGateway() + app.state.model_weights_manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + +@app.on_event("startup") +async def recover_hf_syncs(): + """Re-trigger weight syncs for endpoints that were syncing when server last stopped.""" + from model_engine_server.db.base import get_session_async + from model_engine_server.infra.repositories.live_tokenizer_repository import ( + SUPPORTED_MODELS_INFO, + ) + from sqlalchemy import text + + session_factory = get_session_async() + try: + async with session_factory() as session: + result = await session.execute( + text( + "SELECT DISTINCT endpoint_metadata->'_llm'->>'model_name' AS model_name " + "FROM endpoints " + "WHERE (endpoint_metadata->'_llm'->>'hf_weights_syncing')::boolean = true" + ) + ) + model_names = [row.model_name for row in result if row.model_name] + except Exception: + logger.warning("Could not query pending HF sync endpoints at startup", exc_info=True) + return + for model_name in model_names: + info = SUPPORTED_MODELS_INFO.get(model_name) + if info and info.hf_repo: + app.state.model_weights_manager.ensure_model_weights_available(info.hf_repo) + logger.info(f"Startup: re-triggered HF weight sync for {model_name}") + + def healthcheck() -> Response: """Returns 200 if the app is healthy.""" return Response(status_code=200) diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py index 6c5c8c244..99af6a273 100644 --- a/model-engine/model_engine_server/api/llms_v1.py +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -86,7 +86,6 @@ UpdateLLMModelEndpointV1UseCase, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase -from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager from pydantic import RootModel from sse_starlette.sse import EventSourceResponse @@ -149,14 +148,15 @@ def handle_streaming_exception( @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response) async def create_model_endpoint( wrapped_request: RootModel[CreateLLMModelEndpointV1Request], + request: Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CreateLLMModelEndpointV1Response: - request = wrapped_request.root + llm_request = wrapped_request.root """ Creates an LLM endpoint for the current user. """ - logger.info(f"POST /llm/model-endpoints with {request} for {auth}") + logger.info(f"POST /llm/model-endpoints with {llm_request} for {auth}") try: create_model_bundle_use_case = CreateModelBundleV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -169,9 +169,7 @@ async def create_model_endpoint( llm_artifact_gateway=external_interfaces.llm_artifact_gateway, docker_repository=external_interfaces.docker_repository, ) - model_weights_manager = ModelWeightsManager( - llm_artifact_gateway=external_interfaces.llm_artifact_gateway, - ) + model_weights_manager = request.app.state.model_weights_manager use_case = CreateLLMModelEndpointV1UseCase( create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, model_endpoint_service=external_interfaces.model_endpoint_service, @@ -179,7 +177,7 @@ async def create_model_endpoint( llm_artifact_gateway=external_interfaces.llm_artifact_gateway, model_weights_manager=model_weights_manager, ) - return await use_case.execute(user=auth, request=request) + return await use_case.execute(user=auth, request=llm_request) except ObjectAlreadyExistsException as exc: raise HTTPException( status_code=400, diff --git a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py index f5c422df3..88b3bb4b3 100644 --- a/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py +++ b/model-engine/model_engine_server/domain/use_cases/model_weights_manager.py @@ -1,7 +1,7 @@ import asyncio import functools import tempfile -from typing import List +from typing import Dict, List, Set from huggingface_hub import snapshot_download from model_engine_server.common.config import hmi_config @@ -24,6 +24,8 @@ class ModelWeightsManager: def __init__(self, llm_artifact_gateway: LLMArtifactGateway): self.llm_artifact_gateway = llm_artifact_gateway + self._background_tasks: Set[asyncio.Task] = set() + self._in_progress: Dict[str, asyncio.Task] = {} def get_remote_path(self, hf_repo: str) -> str: prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/") @@ -38,6 +40,10 @@ def ensure_model_weights_available(self, hf_repo: str) -> str: Callers receive the checkpoint path right away and can proceed with any following actions (e.g. endpoint creation) without blocking. + A second call for the same ``hf_repo`` while a sync is already in + progress is a no-op: the existing task is reused and the same remote + path is returned. + Args: hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``. @@ -45,9 +51,24 @@ def ensure_model_weights_available(self, hf_repo: str) -> str: The remote path (s3://, gs://, or https://) where the weights will be stored. """ remote_path = self.get_remote_path(hf_repo) - asyncio.create_task(self._sync_weights(hf_repo, remote_path)) + if hf_repo not in self._in_progress: + task = asyncio.create_task(self._sync_weights(hf_repo, remote_path)) + self._background_tasks.add(task) + self._in_progress[hf_repo] = task + task.add_done_callback(lambda t: self._on_task_done(t, hf_repo)) return remote_path + def _on_task_done(self, task: asyncio.Task, hf_repo: str) -> None: + self._background_tasks.discard(task) + self._in_progress.pop(hf_repo, None) + if not task.cancelled(): + exc = task.exception() + if exc: + logger.error( + f"Background weight sync failed for {hf_repo}: {exc}", + exc_info=exc, + ) + async def _sync_weights(self, hf_repo: str, remote_path: str) -> None: """Downloads weights from HuggingFace Hub and uploads to remote storage if not cached.""" files = self.llm_artifact_gateway.list_files(remote_path) diff --git a/model-engine/tests/unit/domain/test_model_weights_manager.py b/model-engine/tests/unit/domain/test_model_weights_manager.py index 60431d943..78f8054c2 100644 --- a/model-engine/tests/unit/domain/test_model_weights_manager.py +++ b/model-engine/tests/unit/domain/test_model_weights_manager.py @@ -116,6 +116,61 @@ def test_s3_path_construction(monkeypatch): assert path == "s3://bucket/prefix/myorg/mymodel" +def test_deduplication_same_hf_repo(): + """Second call for same hf_repo while a sync is in progress should not create a new task.""" + gateway = FakeArtifactGateway(existing_files=[]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task: + result1 = manager.ensure_model_weights_available("org/model") + result2 = manager.ensure_model_weights_available("org/model") + + assert mock_create_task.call_count == 1 + assert result1 == result2 + + +def test_task_reference_held_until_done(): + """_background_tasks should hold a reference to the task until _on_task_done fires.""" + gateway = FakeArtifactGateway(existing_files=[]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + mock_task = MagicMock() + with patch(f"{mwm_base}.asyncio.create_task", return_value=mock_task): + manager.ensure_model_weights_available("org/model") + + assert mock_task in manager._background_tasks + assert "org/model" in manager._in_progress + + # Simulate successful task completion via the done callback + mock_task.cancelled.return_value = False + mock_task.exception.return_value = None + manager._on_task_done(mock_task, "org/model") + + assert mock_task not in manager._background_tasks + assert "org/model" not in manager._in_progress + + +def test_error_surfaced_on_task_failure(): + """When the background task raises, _on_task_done should log the error.""" + gateway = FakeArtifactGateway(existing_files=[]) + manager = ModelWeightsManager(llm_artifact_gateway=gateway) + + mock_task = MagicMock() + mock_task.cancelled.return_value = False + exc = RuntimeError("Download failed") + mock_task.exception.return_value = exc + + mwm_base = "model_engine_server.domain.use_cases.model_weights_manager" + with patch(f"{mwm_base}.logger") as mock_logger: + manager._on_task_done(mock_task, "org/model") + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args + assert "org/model" in call_args[0][0] + assert call_args[1]["exc_info"] == exc + + @pytest.mark.asyncio async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source(): """CreateLLMModelEndpointV1UseCase should call ensure_model_weights_available (sync), From 3a18aee80d8f3565d55d8aecd061319a3a34d3de Mon Sep 17 00:00:00 2001 From: lilyz-ai Date: Fri, 20 Feb 2026 06:44:52 +0000 Subject: [PATCH 4/4] fix: use fully-qualified schema in recover_hf_syncs SQL query The endpoints table lives in hosted_model_inference schema; bare 'FROM endpoints' would fail at runtime. Co-Authored-By: Claude Sonnet 4.6 --- model-engine/model_engine_server/api/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py index f75e97ffc..246a813f5 100644 --- a/model-engine/model_engine_server/api/app.py +++ b/model-engine/model_engine_server/api/app.py @@ -347,7 +347,7 @@ async def recover_hf_syncs(): result = await session.execute( text( "SELECT DISTINCT endpoint_metadata->'_llm'->>'model_name' AS model_name " - "FROM endpoints " + "FROM hosted_model_inference.endpoints " "WHERE (endpoint_metadata->'_llm'->>'hf_weights_syncing')::boolean = true" ) )