Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions model-engine/model_engine_server/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,56 @@ def load_redis():
get_or_create_aioredis_pool()


@app.on_event("startup")
def init_model_weights_manager():
    """Attach a ModelWeightsManager to app.state, backed by the cloud-appropriate gateway."""
    from model_engine_server.core.config import infra_config
    from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
    from model_engine_server.infra.gateways import (
        ABSLLMArtifactGateway,
        GCSLLMArtifactGateway,
        S3LLMArtifactGateway,
    )

    # Map cloud provider to its artifact gateway class; anything unrecognized
    # falls back to S3 (matching the previous if/elif default branch).
    gateway_by_provider = {
        "azure": ABSLLMArtifactGateway,
        "gcp": GCSLLMArtifactGateway,
    }
    gateway_cls = gateway_by_provider.get(infra_config().cloud_provider, S3LLMArtifactGateway)
    app.state.model_weights_manager = ModelWeightsManager(llm_artifact_gateway=gateway_cls())


@app.on_event("startup")
async def recover_hf_syncs():
    """Re-trigger weight syncs for endpoints that were syncing when server last stopped.

    The ``hf_weights_syncing`` flag in endpoint metadata is not cleared in the
    DB after a successful sync, so every endpoint that ever synced matches the
    query below. To avoid re-triggering downloads indefinitely on every
    restart, each model is first checked against remote storage and skipped
    when its weights are already present.
    """
    from model_engine_server.db.base import get_session_async
    from model_engine_server.infra.repositories.live_tokenizer_repository import (
        SUPPORTED_MODELS_INFO,
    )
    from sqlalchemy import text

    session_factory = get_session_async()
    try:
        async with session_factory() as session:
            result = await session.execute(
                text(
                    "SELECT DISTINCT endpoint_metadata->'_llm'->>'model_name' AS model_name "
                    "FROM hosted_model_inference.endpoints "
                    "WHERE (endpoint_metadata->'_llm'->>'hf_weights_syncing')::boolean = true"
                )
            )
            model_names = [row.model_name for row in result if row.model_name]
    except Exception:
        # Best-effort recovery: a DB hiccup at startup should not prevent the
        # server from coming up.
        logger.warning("Could not query pending HF sync endpoints at startup", exc_info=True)
        return

    manager = app.state.model_weights_manager
    for model_name in model_names:
        info = SUPPORTED_MODELS_INFO.get(model_name)
        if not (info and info.hf_repo):
            continue
        try:
            # Skip models whose weights already landed in remote storage; only
            # genuinely interrupted syncs are re-triggered.
            remote_path = manager.get_remote_path(info.hf_repo)
            if manager.llm_artifact_gateway.list_files(remote_path):
                logger.info(f"Startup: weights for {model_name} already present, skipping sync")
                continue
        except Exception:
            # If the presence check itself fails, fall through and re-trigger;
            # the background task re-checks the cache before downloading.
            logger.warning(f"Could not check cached weights for {model_name}", exc_info=True)
        manager.ensure_model_weights_available(info.hf_repo)
        logger.info(f"Startup: re-triggered HF weight sync for {model_name}")
Comment on lines +336 to +362
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recover_hf_syncs re-triggers syncs on every restart, indefinitely

Since hf_weights_syncing is never cleared to false after a successful sync (the _on_task_done callback only cleans up in-memory state, not the DB), this startup handler will re-trigger downloads for every endpoint that has ever had hf_weights_syncing=true — even if their weights were successfully synced long ago.

This means on every server restart:

  1. Unnecessary background snapshot_download + upload_files tasks are spawned for already-synced models (the cache-hit path returns quickly, but still issues a list_files I/O call per model).
  2. The init container (add_hf_weights_init_container) will continue to be added to every subsequent deployment of these endpoints, adding startup latency even when weights are present.

Consider either:

  • Adding a step in _on_task_done (on success) that updates the DB to set hf_weights_syncing: false, or
  • Having this handler check whether weights are actually present (via list_files) before re-triggering the sync.
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/api/app.py
Line: 336-362

Comment:
**`recover_hf_syncs` re-triggers syncs on every restart, indefinitely**

Since `hf_weights_syncing` is never cleared to `false` after a successful sync (the `_on_task_done` callback only cleans up in-memory state, not the DB), this startup handler will re-trigger downloads for **every** endpoint that has ever had `hf_weights_syncing=true` — even if their weights were successfully synced long ago.

This means on every server restart:
1. Unnecessary background `snapshot_download` + `upload_files` tasks are spawned for already-synced models (the cache-hit path returns quickly, but still issues a `list_files` I/O call per model).
2. The init container (`add_hf_weights_init_container`) will continue to be added to every subsequent deployment of these endpoints, adding startup latency even when weights are present.

Consider either:
- Adding a step in `_on_task_done` (on success) that updates the DB to set `hf_weights_syncing: false`, or
- Having this handler check whether weights are actually present (via `list_files`) before re-triggering the sync.

How can I resolve this? If you propose a fix, please make it concise.



def healthcheck() -> Response:
    """Liveness/readiness probe: always reports the app as healthy with HTTP 200."""
    ok = Response(status_code=200)
    return ok
Expand Down
9 changes: 6 additions & 3 deletions model-engine/model_engine_server/api/llms_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ def handle_streaming_exception(
@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
async def create_model_endpoint(
wrapped_request: RootModel[CreateLLMModelEndpointV1Request],
request: Request,
auth: User = Depends(verify_authentication),
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
) -> CreateLLMModelEndpointV1Response:
request = wrapped_request.root
llm_request = wrapped_request.root
"""
Creates an LLM endpoint for the current user.
"""
logger.info(f"POST /llm/model-endpoints with {request} for {auth}")
logger.info(f"POST /llm/model-endpoints with {llm_request} for {auth}")
try:
create_model_bundle_use_case = CreateModelBundleV2UseCase(
model_bundle_repository=external_interfaces.model_bundle_repository,
Expand All @@ -168,13 +169,15 @@ async def create_model_endpoint(
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
docker_repository=external_interfaces.docker_repository,
)
model_weights_manager = request.app.state.model_weights_manager
use_case = CreateLLMModelEndpointV1UseCase(
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
model_endpoint_service=external_interfaces.model_endpoint_service,
docker_repository=external_interfaces.docker_repository,
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
model_weights_manager=model_weights_manager,
)
return await use_case.execute(user=auth, request=request)
return await use_case.execute(user=auth, request=llm_request)
except ObjectAlreadyExistsException as exc:
raise HTTPException(
status_code=400,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ class LLMMetadata:
quantize: Optional[Quantization] = None
checkpoint_path: Optional[str] = None
chat_template_override: Optional[str] = None
hf_weights_syncing: bool = False
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf_weights_syncing is never cleared after sync completes

Once set to True during endpoint creation, no code path ever sets this back to False in the endpoint metadata after the background sync finishes. ModelWeightsManager._on_task_done only cleans up in-memory tracking — it doesn't update the database.

This causes two problems:

  1. recover_hf_syncs re-triggers downloads on every server restart for all endpoints that ever had this flag set, even if their weights were successfully synced long ago.
  2. The init container (add_hf_weights_init_container) will be added on every subsequent deployment/update of these endpoints, adding unnecessary startup latency even when weights are already present.

Consider adding a callback in _on_task_done (on success) that updates the endpoint metadata to set hf_weights_syncing: false, or have the recover_hf_syncs startup handler check whether the weights are actually present before re-triggering.

Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/entities/llm_entity.py
Line: 34

Comment:
**`hf_weights_syncing` is never cleared after sync completes**

Once set to `True` during endpoint creation, no code path ever sets this back to `False` in the endpoint metadata after the background sync finishes. `ModelWeightsManager._on_task_done` only cleans up in-memory tracking — it doesn't update the database.

This causes two problems:
1. `recover_hf_syncs` re-triggers downloads on **every server restart** for all endpoints that ever had this flag set, even if their weights were successfully synced long ago.
2. The init container (`add_hf_weights_init_container`) will be added on every subsequent deployment/update of these endpoints, adding unnecessary startup latency even when weights are already present.

Consider adding a callback in `_on_task_done` (on success) that updates the endpoint metadata to set `hf_weights_syncing: false`, or have the `recover_hf_syncs` startup handler check whether the weights are actually present before re-triggering.

How can I resolve this? If you propose a fix, please make it concise.

Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[
"""
pass

    @abstractmethod
    def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
        """
        Upload all files from a local directory to a remote path.

        Implementations walk ``local_path`` recursively and mirror its layout
        under ``remote_path`` on their backing store.

        Args:
            local_path (str): local directory containing files to upload
            remote_path (str): remote destination path (s3://, gs://, or https://)
            **kwargs: backend-specific options forwarded by implementations
        """
        pass

@abstractmethod
def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,8 @@ def load_model_weights_sub_commands_s3(
s5cmd = "./s5cmd"

checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
validate_checkpoint_files(checkpoint_files)
if checkpoint_files:
validate_checkpoint_files(checkpoint_files)
Comment on lines 654 to +656
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipping validation silently changes error behavior for non-sync callers

Previously, validate_checkpoint_files was always called and would raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") if the checkpoint contained no safetensors — catching user misconfiguration early. Now, when list_files returns an empty list (e.g., a wrong/empty S3 prefix that isn't related to weight syncing), validation is silently skipped and the endpoint creation proceeds, only failing later at inference time.

Consider guarding this skip more tightly — e.g., only skip when a hf_weights_syncing flag is passed through, rather than universally skipping for any empty file list:

if checkpoint_files:
    validate_checkpoint_files(checkpoint_files)
elif not hf_weights_syncing:
    raise ObjectHasInvalidValueException(
        f"No files found at checkpoint path: {checkpoint_path}"
    )
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
Line: 654-656

Comment:
**Skipping validation silently changes error behavior for non-sync callers**

Previously, `validate_checkpoint_files` was always called and would raise `ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")` if the checkpoint contained no safetensors — catching user misconfiguration early. Now, when `list_files` returns an empty list (e.g., a wrong/empty S3 prefix that isn't related to weight syncing), validation is silently skipped and the endpoint creation proceeds, only failing later at inference time.

Consider guarding this skip more tightly — e.g., only skip when a `hf_weights_syncing` flag is passed through, rather than universally skipping for any empty file list:
```python
if checkpoint_files:
    validate_checkpoint_files(checkpoint_files)
elif not hf_weights_syncing:
    raise ObjectHasInvalidValueException(
        f"No files found at checkpoint path: {checkpoint_path}"
    )
```

How can I resolve this? If you propose a fix, please make it concise.


# filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
# For models that are not supported by transformers directly, we need to include '*.py' and '*.bin'
Expand Down Expand Up @@ -1322,12 +1323,14 @@ def __init__(
model_endpoint_service: ModelEndpointService,
docker_repository: DockerRepository,
llm_artifact_gateway: LLMArtifactGateway,
model_weights_manager=None,
):
self.authz_module = LiveAuthorizationModule()
self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
self.model_endpoint_service = model_endpoint_service
self.docker_repository = docker_repository
self.llm_artifact_gateway = llm_artifact_gateway
self.model_weights_manager = model_weights_manager

async def execute(
self, user: User, request: CreateLLMModelEndpointV1Request
Expand Down Expand Up @@ -1387,6 +1390,21 @@ async def execute(
"Multinode endpoints are only supported for VLLM models."
)

# Resolve checkpoint path: fires background sync and returns expected path immediately
checkpoint_path = request.checkpoint_path
hf_weights_syncing = False
if (
checkpoint_path is None
and request.source == LLMSource.HUGGING_FACE
and self.model_weights_manager is not None
):
models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
if models_info and models_info.hf_repo:
checkpoint_path = self.model_weights_manager.ensure_model_weights_available(
models_info.hf_repo
)
hf_weights_syncing = True
Comment on lines +1393 to +1406
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf_weights_syncing set unconditionally on cache hit

ensure_model_weights_available returns the path immediately without checking whether weights are already cached (the cache check happens asynchronously in the background task). This means hf_weights_syncing is always set to True — even when the weights are already present in storage.

Consequence: on every HF endpoint creation (even with cached weights), the init container in add_hf_weights_init_container is added to the pod spec. The init container will exit quickly on cache hit (since it polls S3 and finds the files), but it still adds startup latency from the init container spin-up and S3 ListObjects call.

Consider making ensure_model_weights_available synchronously check the cache and return a flag indicating whether syncing is actually needed, so hf_weights_syncing is only True when the background task is genuinely downloading:

checkpoint_path, needs_sync = self.model_weights_manager.ensure_model_weights_available(
    models_info.hf_repo
)
hf_weights_syncing = needs_sync
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
Line: 1393-1406

Comment:
**`hf_weights_syncing` set unconditionally on cache hit**

`ensure_model_weights_available` returns the path immediately without checking whether weights are already cached (the cache check happens asynchronously in the background task). This means `hf_weights_syncing` is always set to `True` — even when the weights are already present in storage.

Consequence: on every HF endpoint creation (even with cached weights), the init container in `add_hf_weights_init_container` is added to the pod spec. The init container will exit quickly on cache hit (since it polls S3 and finds the files), but it still adds startup latency from the init container spin-up and S3 `ListObjects` call.

Consider making `ensure_model_weights_available` synchronously check the cache and return a flag indicating whether syncing is actually needed, so `hf_weights_syncing` is only `True` when the background task is genuinely downloading:

```python
checkpoint_path, needs_sync = self.model_weights_manager.ensure_model_weights_available(
    models_info.hf_repo
)
hf_weights_syncing = needs_sync
```

How can I resolve this? If you propose a fix, please make it concise.


bundle = await self.create_llm_model_bundle_use_case.execute(
user,
endpoint_name=request.name,
Expand All @@ -1397,7 +1415,7 @@ async def execute(
endpoint_type=request.endpoint_type,
num_shards=request.num_shards,
quantize=request.quantize,
checkpoint_path=request.checkpoint_path,
checkpoint_path=checkpoint_path,
chat_template_override=request.chat_template_override,
nodes_per_worker=request.nodes_per_worker,
additional_args=request.model_dump(exclude_none=True),
Expand Down Expand Up @@ -1430,8 +1448,9 @@ async def execute(
inference_framework_image_tag=request.inference_framework_image_tag,
num_shards=request.num_shards,
quantize=request.quantize,
checkpoint_path=request.checkpoint_path,
checkpoint_path=checkpoint_path,
chat_template_override=request.chat_template_override,
hf_weights_syncing=hf_weights_syncing,
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import asyncio
import functools
import tempfile
from typing import Dict, List, Set

from huggingface_hub import snapshot_download
from model_engine_server.common.config import hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway

logger = make_logger(logger_name())

# File patterns excluded from HuggingFace snapshot downloads. Mirrors the
# inclusion/exclusion rules of the internal sync_model_weights.py script:
# optimizer state and non-PyTorch weight formats (msgpack/h5/Flax/TF/Rust)
# are skipped so only servable artifacts are mirrored to remote storage.
HF_IGNORE_PATTERNS: List[str] = [
    "optimizer*",
    "*.msgpack",
    "*.h5",
    "flax_model*",
    "tf_model*",
    "rust_model*",
]


class ModelWeightsManager:
    """Fire-and-forget syncing of HuggingFace model weights into remote storage.

    ``ensure_model_weights_available`` returns the expected remote checkpoint
    path immediately and starts a background asyncio task that downloads the
    weights from the HuggingFace Hub and uploads them via the configured
    ``LLMArtifactGateway`` unless they are already present at that path.
    """

    def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
        self.llm_artifact_gateway = llm_artifact_gateway
        # Strong references keep fire-and-forget tasks alive until completion.
        self._background_tasks: Set[asyncio.Task] = set()
        # Maps hf_repo -> in-flight sync task, so duplicate requests are no-ops.
        self._in_progress: Dict[str, asyncio.Task] = {}

    def get_remote_path(self, hf_repo: str) -> str:
        """Return the remote storage path where weights for ``hf_repo`` are kept."""
        prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
        return f"{prefix}/{hf_repo}"

    def ensure_model_weights_available(self, hf_repo: str) -> str:
        """
        Returns the expected remote path for ``hf_repo`` immediately and starts
        syncing weights from HuggingFace Hub to that path in the background.

        If the weights are already cached the background task exits early.
        Callers receive the checkpoint path right away and can proceed with
        any following actions (e.g. endpoint creation) without blocking.

        A second call for the same ``hf_repo`` while a sync is already in
        progress is a no-op: the existing task is reused and the same remote
        path is returned.

        Args:
            hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.

        Returns:
            The remote path (s3://, gs://, or https://) where the weights will be stored.
        """
        remote_path = self.get_remote_path(hf_repo)
        if hf_repo not in self._in_progress:
            task = asyncio.create_task(self._sync_weights(hf_repo, remote_path))
            self._background_tasks.add(task)
            self._in_progress[hf_repo] = task
            task.add_done_callback(lambda t: self._on_task_done(t, hf_repo))
        return remote_path

    def _on_task_done(self, task: asyncio.Task, hf_repo: str) -> None:
        """Drop bookkeeping for a finished sync task and log any failure."""
        self._background_tasks.discard(task)
        self._in_progress.pop(hf_repo, None)
        if not task.cancelled():
            exc = task.exception()
            if exc:
                logger.error(
                    f"Background weight sync failed for {hf_repo}: {exc}",
                    exc_info=exc,
                )

    async def _sync_weights(self, hf_repo: str, remote_path: str) -> None:
        """Downloads weights from HuggingFace Hub and uploads to remote storage if not cached."""
        loop = asyncio.get_event_loop()
        # list_files is synchronous I/O; run it in the executor so the cache
        # check does not block the event loop, matching the other calls below.
        files = await loop.run_in_executor(
            None,
            functools.partial(self.llm_artifact_gateway.list_files, remote_path),
        )
        if files:
            logger.info(f"Cache hit: {len(files)} files at {remote_path}")
            return

        logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                await loop.run_in_executor(
                    None,
                    functools.partial(
                        snapshot_download,
                        repo_id=hf_repo,
                        local_dir=tmp_dir,
                        ignore_patterns=HF_IGNORE_PATTERNS,
                    ),
                )
                await loop.run_in_executor(
                    None,
                    functools.partial(
                        self.llm_artifact_gateway.upload_files,
                        tmp_dir,
                        remote_path,
                    ),
                )
        except Exception:
            # Gated repos (e.g. meta-llama/*), rate limits, or network errors
            # land here; log at the failure site and re-raise so _on_task_done
            # also records the task failure.
            logger.error(
                f"Failed to sync weights for {hf_repo} to {remote_path}", exc_info=True
            )
            raise

        logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
downloaded_files.append(local_path)
return downloaded_files

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
parsed = parse_attachment_url(remote_path, clean_key=False)
container_client = _get_abs_container_client(parsed.bucket)
for root, _, files in os.walk(local_path):
for file in files:
local_file = os.path.join(root, file)
blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
with open(local_file, "rb") as f:
container_client.upload_blob(name=blob_name, data=f, overwrite=True)

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
downloaded_files.append(local_path)
return downloaded_files

def upload_files(self, local_path: str, remote_path: str, **kwargs) -> None:
parsed = parse_attachment_url(remote_path, clean_key=False)
client = get_gcs_sync_client()
bucket = client.bucket(parsed.bucket)
for root, _, files in os.walk(local_path):
for file in files:
local_file = os.path.join(root, file)
blob_name = os.path.join(parsed.key, os.path.relpath(local_file, local_path))
bucket.blob(blob_name).upload_from_filename(local_file)

def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
parsed_remote = parse_attachment_url(
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@
BASE_PATH_IN_ENDPOINT = "/app"

DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"}

# Key under which LLM metadata is stored in model_endpoint_record.metadata
_LLM_METADATA_KEY = "_llm"

# Python script run by the init container to poll storage until HF weights are present.
_HF_WEIGHTS_POLL_SCRIPT = """\
import boto3, os, sys, time
from urllib.parse import urlparse

cp = os.environ["CHECKPOINT_PATH"]
url = urlparse(cp)
bucket = url.netloc
prefix = url.path.lstrip("/")
s3 = boto3.client("s3")
while True:
resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
if resp.get("Contents"):
print(f"Model weights ready at {cp}", flush=True)
sys.exit(0)
print(f"Waiting for model weights at {cp}...", flush=True)
time.sleep(30)
Comment on lines +79 to +85
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No timeout — pod hangs forever if weight sync fails

The while True loop has no timeout or maximum iteration count. If the background _sync_weights task fails (gated HF model, network error, permission issue, disk full), the init container will poll indefinitely, blocking pod startup forever. Since the background task is fire-and-forget with no error propagation back to the caller, there's no signal to the init container that it should stop.

Consider adding a maximum wait time (e.g., 1-2 hours) after which the container exits with a non-zero status and a descriptive error message:

MAX_WAIT_SECONDS = 7200  # 2 hours
elapsed = 0
while elapsed < MAX_WAIT_SECONDS:
    ...
    time.sleep(30)
    elapsed += 30
print(f"Timed out waiting for model weights at {cp}", flush=True)
sys.exit(1)
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 79-85

Comment:
**No timeout — pod hangs forever if weight sync fails**

The `while True` loop has no timeout or maximum iteration count. If the background `_sync_weights` task fails (gated HF model, network error, permission issue, disk full), the init container will poll indefinitely, blocking pod startup forever. Since the background task is fire-and-forget with no error propagation back to the caller, there's no signal to the init container that it should stop.

Consider adding a maximum wait time (e.g., 1-2 hours) after which the container exits with a non-zero status and a descriptive error message:
```python
MAX_WAIT_SECONDS = 7200  # 2 hours
elapsed = 0
while elapsed < MAX_WAIT_SECONDS:
    ...
    time.sleep(30)
    elapsed += 30
print(f"Timed out waiting for model weights at {cp}", flush=True)
sys.exit(1)
```

How can I resolve this? If you propose a fix, please make it concise.

"""
Comment on lines +70 to +86
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Init container poll script only supports S3

_HF_WEIGHTS_POLL_SCRIPT hardcodes boto3 and s3.list_objects_v2, but the system supports three storage backends: S3, GCS (gs://), and Azure Blob Storage (https://*.blob.core.windows.net). If hf_user_fine_tuned_weights_prefix points to GCS or ABS, this script will fail at runtime — urlparse("gs://bucket/key") gives netloc="bucket" but the S3 API call will raise an error.

The upload_files implementations in GCSLLMArtifactGateway and ABSLLMArtifactGateway correctly handle their respective backends, so the background sync will succeed, but the init container waiting for the weights will never detect them.

Consider either:

  • Dispatching based on the URL scheme (s3:// → boto3, gs:// → google.cloud.storage, https://*.blob.core.windows.net → azure SDK), or
  • Using a simpler polling mechanism (e.g., an HTTP HEAD request if the gateway can provide a presigned/public URL).
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 70-86

Comment:
**Init container poll script only supports S3**

`_HF_WEIGHTS_POLL_SCRIPT` hardcodes `boto3` and `s3.list_objects_v2`, but the system supports three storage backends: S3, GCS (`gs://`), and Azure Blob Storage (`https://*.blob.core.windows.net`). If `hf_user_fine_tuned_weights_prefix` points to GCS or ABS, this script will fail at runtime — `urlparse("gs://bucket/key")` gives `netloc="bucket"` but the S3 API call will raise an error.

The `upload_files` implementations in `GCSLLMArtifactGateway` and `ABSLLMArtifactGateway` correctly handle their respective backends, so the background sync will succeed, but the init container waiting for the weights will never detect them.

Consider either:
- Dispatching based on the URL scheme (`s3://` → boto3, `gs://` → `google.cloud.storage`, `https://*.blob.core.windows.net` → azure SDK), or
- Using a simpler polling mechanism (e.g., an HTTP HEAD request if the gateway can provide a presigned/public URL).

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines +70 to +86
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Init container poll script hardcodes S3 — fails on GCS/Azure

_HF_WEIGHTS_POLL_SCRIPT imports boto3 and calls s3.list_objects_v2, but the system supports three storage backends (S3, GCS, Azure Blob Storage). The upload_files implementations for GCS and ABS correctly handle their respective backends, so the background sync will succeed — but the init container will never detect the uploaded weights on non-S3 backends and will poll indefinitely.

Additionally, the while True loop has no timeout or maximum iteration count. If the background sync task fails for any reason (gated HF model, network error, permissions), the init container blocks pod startup forever.

Consider:

  1. Dispatching on URL scheme (s3:// → boto3, gs:// → google.cloud.storage, https://*.blob.core.windows.net → azure SDK)
  2. Adding a maximum wait time (e.g., 1–2 hours) after which the init container exits non-zero with an error message
Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 70-86

Comment:
**Init container poll script hardcodes S3 — fails on GCS/Azure**

`_HF_WEIGHTS_POLL_SCRIPT` imports `boto3` and calls `s3.list_objects_v2`, but the system supports three storage backends (S3, GCS, Azure Blob Storage). The `upload_files` implementations for GCS and ABS correctly handle their respective backends, so the background sync will succeed — but the init container will never detect the uploaded weights on non-S3 backends and will poll indefinitely.

Additionally, the `while True` loop has no timeout or maximum iteration count. If the background sync task fails for any reason (gated HF model, network error, permissions), the init container blocks pod startup forever.

Consider:
1. Dispatching on URL scheme (`s3://` → boto3, `gs://` → `google.cloud.storage`, `https://*.blob.core.windows.net` → azure SDK)
2. Adding a maximum wait time (e.g., 1–2 hours) after which the init container exits non-zero with an error message

How can I resolve this? If you propose a fix, please make it concise.

LWS_DEFAULT_ENV_VAR = {
"K8S_OWN_POD_NAME",
"K8S_OWN_NAMESPACE",
Expand Down Expand Up @@ -339,6 +361,42 @@ def add_pod_metadata_env_to_container(container: Dict[str, Any]) -> None:
)


def add_hf_weights_init_container(
    deployment_template: Dict[str, Any],
    checkpoint_path: str,
) -> None:
    """Prepend an init container that polls storage until HF weights are present.

    Uses the forwarder image (model-engine gateway image, which has Python and
    boto3) so no additional image pull is required. Authentication relies on
    the pod's service account (IRSA / workload-identity).
    """
    pod_spec = deployment_template["spec"]["template"]["spec"]
    containers = pod_spec["containers"]

    # Prefer the forwarder container image; fall back to the first container.
    forwarder_image = containers[0]["image"]
    for container in containers:
        if container["name"] in ("http-forwarder", "celery-forwarder"):
            forwarder_image = container["image"]
            break

    init_container: Dict[str, Any] = {
        "name": "wait-for-model-weights",
        "image": forwarder_image,
        "env": [{"name": "CHECKPOINT_PATH", "value": checkpoint_path}],
        "command": ["python3", "-c", _HF_WEIGHTS_POLL_SCRIPT],
    }

    # Reuse the AWS config volume mount if the volume is present in the pod spec.
    if any(volume["name"] == "config-volume" for volume in pod_spec.get("volumes", [])):
        init_container["volumeMounts"] = [
            {"name": "config-volume", "mountPath": "/opt/.aws/config", "subPath": "config"}
        ]

    pod_spec.setdefault("initContainers", []).append(init_container)


def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None:
container_envs = []
container_envs.extend(
Expand Down Expand Up @@ -1657,6 +1715,9 @@ async def _create_or_update_resources(
user_container = get_main_container_from_deployment_template(deployment_template)
add_datadog_env_to_container(deployment_template, user_container)
add_pod_metadata_env_to_container(user_container)
llm_metadata = (model_endpoint_record.metadata or {}).get(_LLM_METADATA_KEY, {})
if llm_metadata.get("hf_weights_syncing") and llm_metadata.get("checkpoint_path"):
add_hf_weights_init_container(deployment_template, llm_metadata["checkpoint_path"])
await self._create_deployment(
model_endpoint_record=request.build_endpoint_request.model_endpoint_record,
deployment=deployment_template,
Expand Down
Loading