feat: add ModelWeightsManager to auto-sync HF weights on endpoint creation #761
base: main
Changes from all commits
3e7a2a4
b14deb6
493d1cb
3a18aee
model-engine/model_engine_server/domain/entities/llm_entity.py

```diff
@@ -31,3 +31,4 @@ class LLMMetadata:
     quantize: Optional[Quantization] = None
     checkpoint_path: Optional[str] = None
     chat_template_override: Optional[str] = None
+    hf_weights_syncing: bool = False
```
This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/entities/llm_entity.py
Line: 34
Comment:
**`hf_weights_syncing` is never cleared after sync completes**
Once set to `True` during endpoint creation, no code path ever sets this back to `False` in the endpoint metadata after the background sync finishes. `ModelWeightsManager._on_task_done` only cleans up in-memory tracking — it doesn't update the database.
This causes two problems:
1. `recover_hf_syncs` re-triggers downloads on **every server restart** for all endpoints that ever had this flag set, even if their weights were successfully synced long ago.
2. The init container (`add_hf_weights_init_container`) will be added on every subsequent deployment/update of these endpoints, adding unnecessary startup latency even when weights are already present.
Consider adding a callback in `_on_task_done` (on success) that updates the endpoint metadata to set `hf_weights_syncing: false`, or have the `recover_hf_syncs` startup handler check whether the weights are actually present before re-triggering.
How can I resolve this? If you propose a fix, please make it concise.
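A minimal, self-contained sketch of the suggested fix. The in-memory `endpoint_metadata` dict and `clear_sync_flag` helper are hypothetical stand-ins for the real endpoint record store; the actual fix would persist the update through the endpoint service rather than a dict:

```python
import asyncio

# Hypothetical stand-in for the persisted endpoint metadata store.
endpoint_metadata = {"my-endpoint": {"hf_weights_syncing": True}}

def clear_sync_flag(endpoint_name: str) -> None:
    # On successful sync, persist hf_weights_syncing = False so that
    # recover_hf_syncs and later deployments skip the init container.
    endpoint_metadata[endpoint_name]["hf_weights_syncing"] = False

def on_task_done(task: asyncio.Task, endpoint_name: str) -> None:
    # Only clear the flag on clean success; failures keep it set so a
    # restart can retry the sync.
    if task.cancelled() or task.exception() is not None:
        return
    clear_sync_flag(endpoint_name)

async def fake_sync() -> None:
    await asyncio.sleep(0)  # stands in for download + upload

async def main() -> None:
    task = asyncio.create_task(fake_sync())
    task.add_done_callback(lambda t: on_task_done(t, "my-endpoint"))
    await task
    await asyncio.sleep(0)  # give the done-callback a chance to run

asyncio.run(main())
print(endpoint_metadata["my-endpoint"]["hf_weights_syncing"])  # → False
```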
model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

```diff
@@ -652,7 +652,8 @@ def load_model_weights_sub_commands_s3(
     s5cmd = "./s5cmd"

     checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
-    validate_checkpoint_files(checkpoint_files)
+    if checkpoint_files:
+        validate_checkpoint_files(checkpoint_files)
```
Comment on lines 654 to 656

This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
Line: 654-656
Comment:
**Skipping validation silently changes error behavior for non-sync callers**
Previously, `validate_checkpoint_files` was always called and would raise `ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")` if the checkpoint contained no safetensors — catching user misconfiguration early. Now, when `list_files` returns an empty list (e.g., a wrong/empty S3 prefix that isn't related to weight syncing), validation is silently skipped and the endpoint creation proceeds, only failing later at inference time.
Consider guarding this skip more tightly — e.g., only skip when a `hf_weights_syncing` flag is passed through, rather than universally skipping for any empty file list:
```python
if checkpoint_files:
validate_checkpoint_files(checkpoint_files)
elif not hf_weights_syncing:
raise ObjectHasInvalidValueException(
f"No files found at checkpoint path: {checkpoint_path}"
)
```
```diff
     # filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
     # For models that are not supported by transformers directly, we need to include '*.py' and '*.bin'
```
```diff
@@ -1322,12 +1323,14 @@ def __init__(
         model_endpoint_service: ModelEndpointService,
         docker_repository: DockerRepository,
         llm_artifact_gateway: LLMArtifactGateway,
+        model_weights_manager=None,
     ):
         self.authz_module = LiveAuthorizationModule()
         self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case
         self.model_endpoint_service = model_endpoint_service
         self.docker_repository = docker_repository
         self.llm_artifact_gateway = llm_artifact_gateway
+        self.model_weights_manager = model_weights_manager

     async def execute(
         self, user: User, request: CreateLLMModelEndpointV1Request
```
```diff
@@ -1387,6 +1390,21 @@ async def execute(
                 "Multinode endpoints are only supported for VLLM models."
             )

+        # Resolve checkpoint path: fires background sync and returns expected path immediately
+        checkpoint_path = request.checkpoint_path
+        hf_weights_syncing = False
+        if (
+            checkpoint_path is None
+            and request.source == LLMSource.HUGGING_FACE
+            and self.model_weights_manager is not None
+        ):
+            models_info = SUPPORTED_MODELS_INFO.get(request.model_name)
+            if models_info and models_info.hf_repo:
+                checkpoint_path = self.model_weights_manager.ensure_model_weights_available(
+                    models_info.hf_repo
+                )
+                hf_weights_syncing = True
```
Comment on lines 1393 to 1406

This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
Line: 1393-1406
Comment:
**`hf_weights_syncing` set unconditionally on cache hit**
`ensure_model_weights_available` returns the path immediately without checking whether weights are already cached (the cache check happens asynchronously in the background task). This means `hf_weights_syncing` is always set to `True` — even when the weights are already present in storage.
Consequence: on every HF endpoint creation (even with cached weights), the init container in `add_hf_weights_init_container` is added to the pod spec. The init container will exit quickly on cache hit (since it polls S3 and finds the files), but it still adds startup latency from the init container spin-up and S3 `ListObjects` call.
Consider making `ensure_model_weights_available` synchronously check the cache and return a flag indicating whether syncing is actually needed, so `hf_weights_syncing` is only `True` when the background task is genuinely downloading:
```python
checkpoint_path, needs_sync = self.model_weights_manager.ensure_model_weights_available(
models_info.hf_repo
)
hf_weights_syncing = needs_sync
```
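A self-contained sketch of that shape; `FakeGateway` and the `s3://bucket/hf-weights/` prefix are illustrative stand-ins, not code from this PR:

```python
from typing import List, Tuple

class FakeGateway:
    """Illustrative stand-in for LLMArtifactGateway; returns a fixed file list."""
    def __init__(self, cached_files: List[str]):
        self.cached_files = cached_files

    def list_files(self, path: str) -> List[str]:
        return self.cached_files

def ensure_model_weights_available(gateway: FakeGateway, hf_repo: str) -> Tuple[str, bool]:
    remote_path = f"s3://bucket/hf-weights/{hf_repo}"  # hypothetical prefix
    # Check the cache synchronously so callers learn whether a sync is needed.
    if gateway.list_files(remote_path):
        return remote_path, False  # cache hit: no init container needed
    # (real code would spawn the background download task here)
    return remote_path, True

path, needs_sync = ensure_model_weights_available(FakeGateway([]), "org/model")
hit_path, hit_sync = ensure_model_weights_available(
    FakeGateway(["model.safetensors"]), "org/model"
)
print(needs_sync, hit_sync)  # → True False
```

With this shape, `hf_weights_syncing` is only `True` when a download is genuinely pending, so the init container is skipped on cache hits.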
```diff
         bundle = await self.create_llm_model_bundle_use_case.execute(
             user,
             endpoint_name=request.name,
@@ -1397,7 +1415,7 @@ async def execute(
             endpoint_type=request.endpoint_type,
             num_shards=request.num_shards,
             quantize=request.quantize,
-            checkpoint_path=request.checkpoint_path,
+            checkpoint_path=checkpoint_path,
             chat_template_override=request.chat_template_override,
             nodes_per_worker=request.nodes_per_worker,
             additional_args=request.model_dump(exclude_none=True),
@@ -1430,8 +1448,9 @@ async def execute(
             inference_framework_image_tag=request.inference_framework_image_tag,
             num_shards=request.num_shards,
             quantize=request.quantize,
-            checkpoint_path=request.checkpoint_path,
+            checkpoint_path=checkpoint_path,
             chat_template_override=request.chat_template_override,
+            hf_weights_syncing=hf_weights_syncing,
         )
     )
```
model-engine/model_engine_server/domain/use_cases/model_weights_manager.py (new file, @@ -0,0 +1,100 @@)

```python
import asyncio
import functools
import tempfile
from typing import Dict, List, Set

from huggingface_hub import snapshot_download
from model_engine_server.common.config import hmi_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway

logger = make_logger(logger_name())

# Match the internal sync_model_weights.py inclusion/exclusion patterns
HF_IGNORE_PATTERNS: List[str] = [
    "optimizer*",
    "*.msgpack",
    "*.h5",
    "flax_model*",
    "tf_model*",
    "rust_model*",
]


class ModelWeightsManager:
    def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
        self.llm_artifact_gateway = llm_artifact_gateway
        self._background_tasks: Set[asyncio.Task] = set()
        self._in_progress: Dict[str, asyncio.Task] = {}

    def get_remote_path(self, hf_repo: str) -> str:
        prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
        return f"{prefix}/{hf_repo}"

    def ensure_model_weights_available(self, hf_repo: str) -> str:
        """
        Returns the expected remote path for ``hf_repo`` immediately and starts
        syncing weights from HuggingFace Hub to that path in the background.

        If the weights are already cached the background task exits early.
        Callers receive the checkpoint path right away and can proceed with
        any following actions (e.g. endpoint creation) without blocking.

        A second call for the same ``hf_repo`` while a sync is already in
        progress is a no-op: the existing task is reused and the same remote
        path is returned.

        Args:
            hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.

        Returns:
            The remote path (s3://, gs://, or https://) where the weights will be stored.
        """
        remote_path = self.get_remote_path(hf_repo)
        if hf_repo not in self._in_progress:
            task = asyncio.create_task(self._sync_weights(hf_repo, remote_path))
            self._background_tasks.add(task)
            self._in_progress[hf_repo] = task
            task.add_done_callback(lambda t: self._on_task_done(t, hf_repo))
        return remote_path

    def _on_task_done(self, task: asyncio.Task, hf_repo: str) -> None:
        self._background_tasks.discard(task)
        self._in_progress.pop(hf_repo, None)
        if not task.cancelled():
            exc = task.exception()
            if exc:
                logger.error(
                    f"Background weight sync failed for {hf_repo}: {exc}",
                    exc_info=exc,
                )

    async def _sync_weights(self, hf_repo: str, remote_path: str) -> None:
        """Downloads weights from HuggingFace Hub and uploads to remote storage if not cached."""
        files = self.llm_artifact_gateway.list_files(remote_path)
```
This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/model_weights_manager.py
Line: 46
Comment:
**`list_files()` blocks the event loop**
`snapshot_download` and `upload_files` are correctly offloaded via `run_in_executor`, but `list_files()` on line 46 is a synchronous I/O call (S3 `ListObjects` / GCS `list_blobs` / ABS `list_blob_names`) that runs directly on the async event loop. For consistency with the other two calls, this should also be wrapped in `run_in_executor`:
```suggestion
files = await loop.run_in_executor(
None,
functools.partial(self.llm_artifact_gateway.list_files, remote_path),
)
```
Note: `loop` would need to be obtained before this line — move the `loop = asyncio.get_event_loop()` line above this call.
```python
        if files:
            logger.info(f"Cache hit: {len(files)} files at {remote_path}")
            return

        logger.info(f"Cache miss for {hf_repo}. Downloading from HuggingFace Hub...")
        loop = asyncio.get_event_loop()
        with tempfile.TemporaryDirectory() as tmp_dir:
            await loop.run_in_executor(
                None,
                functools.partial(
                    snapshot_download,
                    repo_id=hf_repo,
                    local_dir=tmp_dir,
                    ignore_patterns=HF_IGNORE_PATTERNS,
                ),
            )
            await loop.run_in_executor(
                None,
                functools.partial(
                    self.llm_artifact_gateway.upload_files,
                    tmp_dir,
                    remote_path,
                ),
            )
```
Comment on lines 79 to 98

This is a comment left during a code review.
Path: model-engine/model_engine_server/domain/use_cases/model_weights_manager.py
Line: 51-70
Comment:
**Unhandled errors from `snapshot_download` will crash endpoint creation**
If `snapshot_download` fails (e.g., gated model requiring HF auth token, rate limiting, network timeout), the exception propagates uncaught and returns a 500 to the caller. Many models in `SUPPORTED_MODELS_INFO` (like `meta-llama/*`) are gated and require authentication. Consider wrapping this in a try/except that logs the error and either raises a user-friendly error or falls back to `checkpoint_path = None` (allowing downstream logic to handle it):
```python
try:
await loop.run_in_executor(...)
await loop.run_in_executor(...)
except Exception as e:
logger.error(f"Failed to download/upload weights for {hf_repo}: {e}")
raise ObjectHasInvalidValueException(
f"Could not download model weights for {hf_repo}. "
"Ensure the model is accessible and try again, or provide a checkpoint_path explicitly."
)
```
```python
        logger.info(f"Weights for {hf_repo} uploaded to {remote_path}")
```
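The in-progress de-duplication described in the docstring can be exercised with a small standalone sketch; the `DedupRunner` class and fake sync are illustrative, not part of the PR:

```python
import asyncio
from typing import Dict, Tuple

class DedupRunner:
    """Minimal illustration of the _in_progress de-duplication pattern."""
    def __init__(self) -> None:
        self._in_progress: Dict[str, asyncio.Task] = {}
        self.syncs_started = 0

    async def _sync(self, repo: str) -> None:
        self.syncs_started += 1
        await asyncio.sleep(0)  # stands in for download + upload

    def ensure(self, repo: str) -> str:
        # A second call for an in-flight repo reuses the existing task.
        if repo not in self._in_progress:
            task = asyncio.create_task(self._sync(repo))
            self._in_progress[repo] = task
            task.add_done_callback(lambda t: self._in_progress.pop(repo, None))
        return f"s3://bucket/hf-weights/{repo}"  # hypothetical remote prefix

async def demo() -> Tuple[bool, int]:
    runner = DedupRunner()
    p1 = runner.ensure("org/model")
    p2 = runner.ensure("org/model")  # no new task: sync already in progress
    await asyncio.gather(*runner._in_progress.values())
    return p1 == p2, runner.syncs_started

same_path, syncs_started = asyncio.run(demo())
print(same_path, syncs_started)  # → True 1
```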
model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py

```diff
@@ -62,6 +62,28 @@
 BASE_PATH_IN_ENDPOINT = "/app"

 DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"}

+# Key under which LLM metadata is stored in model_endpoint_record.metadata
+_LLM_METADATA_KEY = "_llm"
+
+# Python script run by the init container to poll storage until HF weights are present.
+_HF_WEIGHTS_POLL_SCRIPT = """\
+import boto3, os, sys, time
+from urllib.parse import urlparse
+
+cp = os.environ["CHECKPOINT_PATH"]
+url = urlparse(cp)
+bucket = url.netloc
+prefix = url.path.lstrip("/")
+s3 = boto3.client("s3")
+while True:
+    resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
+    if resp.get("Contents"):
+        print(f"Model weights ready at {cp}", flush=True)
+        sys.exit(0)
+    print(f"Waiting for model weights at {cp}...", flush=True)
+    time.sleep(30)
```
Comment on lines 79 to 85

This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 79-85
Comment:
**No timeout — pod hangs forever if weight sync fails**
The `while True` loop has no timeout or maximum iteration count. If the background `_sync_weights` task fails (gated HF model, network error, permission issue, disk full), the init container will poll indefinitely, blocking pod startup forever. Since the background task is fire-and-forget with no error propagation back to the caller, there's no signal to the init container that it should stop.
Consider adding a maximum wait time (e.g., 1-2 hours) after which the container exits with a non-zero status and a descriptive error message:
```python
MAX_WAIT_SECONDS = 7200 # 2 hours
elapsed = 0
while elapsed < MAX_WAIT_SECONDS:
...
time.sleep(30)
elapsed += 30
print(f"Timed out waiting for model weights at {cp}", flush=True)
sys.exit(1)
```
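A dependency-free sketch of that timeout loop; the `weights_present` check and `sleep` are injected so the demo runs instantly, and all names are illustrative. In the real init container, a `False` return would map to `sys.exit(1)`:

```python
import time
from typing import Callable

def wait_for_weights(
    weights_present: Callable[[], bool],
    max_wait_seconds: int = 7200,
    poll_interval: int = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll until weights_present() is True; return False on timeout."""
    elapsed = 0
    while elapsed < max_wait_seconds:
        if weights_present():
            return True
        sleep(poll_interval)
        elapsed += poll_interval
    return False

# Simulate: weights appear on the third poll; inject a no-op sleep for the demo.
calls = {"n": 0}
def fake_check() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3

ok = wait_for_weights(fake_check, max_wait_seconds=120, sleep=lambda s: None)
timed_out = not wait_for_weights(lambda: False, max_wait_seconds=60, sleep=lambda s: None)
print(ok, timed_out)  # → True True
```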
```diff
+"""
```
Comment on lines 70 to 86

This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 70-86
Comment:
**Init container poll script only supports S3**
`_HF_WEIGHTS_POLL_SCRIPT` hardcodes `boto3` and `s3.list_objects_v2`, but the system supports three storage backends: S3, GCS (`gs://`), and Azure Blob Storage (`https://*.blob.core.windows.net`). If `hf_user_fine_tuned_weights_prefix` points to GCS or ABS, this script will fail at runtime — `urlparse("gs://bucket/key")` gives `netloc="bucket"` but the S3 API call will raise an error.
The `upload_files` implementations in `GCSLLMArtifactGateway` and `ABSLLMArtifactGateway` correctly handle their respective backends, so the background sync will succeed, but the init container waiting for the weights will never detect them.
Consider either:
- Dispatching based on the URL scheme (`s3://` → boto3, `gs://` → `google.cloud.storage`, `https://*.blob.core.windows.net` → azure SDK), or
- Using a simpler polling mechanism (e.g., an HTTP HEAD request if the gateway can provide a presigned/public URL).
Comment on lines 70 to 86

This is a comment left during a code review.
Path: model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py
Line: 70-86
Comment:
**Init container poll script hardcodes S3 — fails on GCS/Azure**
`_HF_WEIGHTS_POLL_SCRIPT` imports `boto3` and calls `s3.list_objects_v2`, but the system supports three storage backends (S3, GCS, Azure Blob Storage). The `upload_files` implementations for GCS and ABS correctly handle their respective backends, so the background sync will succeed — but the init container will never detect the uploaded weights on non-S3 backends and will poll indefinitely.
Additionally, the `while True` loop has no timeout or maximum iteration count. If the background sync task fails for any reason (gated HF model, network error, permissions), the init container blocks pod startup forever.
Consider:
1. Dispatching on URL scheme (`s3://` → boto3, `gs://` → `google.cloud.storage`, `https://*.blob.core.windows.net` → azure SDK)
2. Adding a maximum wait time (e.g., 1–2 hours) after which the init container exits non-zero with an error message
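A sketch of the scheme dispatch the reviewers suggest. The backend labels are illustrative, and the real script would import and call the matching SDK in each branch rather than return a string:

```python
from urllib.parse import urlparse

def storage_backend(checkpoint_path: str) -> str:
    """Map a checkpoint URL to the storage backend that should poll it."""
    url = urlparse(checkpoint_path)
    if url.scheme == "s3":
        return "s3"   # boto3 list_objects_v2
    if url.scheme == "gs":
        return "gcs"  # google.cloud.storage list_blobs
    if url.scheme == "https" and url.netloc.endswith(".blob.core.windows.net"):
        return "abs"  # azure.storage.blob list_blob_names
    raise ValueError(f"Unsupported checkpoint path: {checkpoint_path}")

print(storage_backend("s3://bucket/prefix"))                      # → s3
print(storage_backend("gs://bucket/prefix"))                      # → gcs
print(storage_backend("https://acct.blob.core.windows.net/c/p"))  # → abs
```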
```diff
 LWS_DEFAULT_ENV_VAR = {
     "K8S_OWN_POD_NAME",
     "K8S_OWN_NAMESPACE",
@@ -339,6 +361,42 @@ def add_pod_metadata_env_to_container(container: Dict[str, Any]) -> None:
     )


+def add_hf_weights_init_container(
+    deployment_template: Dict[str, Any],
+    checkpoint_path: str,
+) -> None:
+    """Prepend an init container that polls storage until HF weights are present.
+
+    Uses the forwarder image (model-engine gateway image, which has Python and
+    boto3) so no additional image pull is required. Authentication relies on
+    the pod's service account (IRSA / workload-identity).
+    """
+    containers = deployment_template["spec"]["template"]["spec"]["containers"]
+    # Prefer the forwarder container image; fall back to the first container.
+    forwarder_image = next(
+        (c["image"] for c in containers if c["name"] in ("http-forwarder", "celery-forwarder")),
+        containers[0]["image"],
+    )
+
+    init_container: Dict[str, Any] = {
+        "name": "wait-for-model-weights",
+        "image": forwarder_image,
+        "env": [{"name": "CHECKPOINT_PATH", "value": checkpoint_path}],
+        "command": ["python3", "-c", _HF_WEIGHTS_POLL_SCRIPT],
+    }
+
+    # Reuse the AWS config volume mount if the volume is present in the pod spec
+    volumes = deployment_template["spec"]["template"]["spec"].get("volumes", [])
+    if any(v["name"] == "config-volume" for v in volumes):
+        init_container["volumeMounts"] = [
+            {"name": "config-volume", "mountPath": "/opt/.aws/config", "subPath": "config"}
+        ]
+
+    deployment_template["spec"]["template"]["spec"].setdefault("initContainers", []).append(
+        init_container
+    )
+
+
 def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None:
     container_envs = []
     container_envs.extend(
```
```diff
@@ -1657,6 +1715,9 @@ async def _create_or_update_resources(
         user_container = get_main_container_from_deployment_template(deployment_template)
         add_datadog_env_to_container(deployment_template, user_container)
         add_pod_metadata_env_to_container(user_container)
+        llm_metadata = (model_endpoint_record.metadata or {}).get(_LLM_METADATA_KEY, {})
+        if llm_metadata.get("hf_weights_syncing") and llm_metadata.get("checkpoint_path"):
+            add_hf_weights_init_container(deployment_template, llm_metadata["checkpoint_path"])
         await self._create_deployment(
             model_endpoint_record=request.build_endpoint_request.model_endpoint_record,
             deployment=deployment_template,
```
This is a comment left during a code review.

**`recover_hf_syncs` re-triggers syncs on every restart, indefinitely**

Since `hf_weights_syncing` is never cleared to `false` after a successful sync (the `_on_task_done` callback only cleans up in-memory state, not the DB), this startup handler will re-trigger downloads for every endpoint that has ever had `hf_weights_syncing=true`, even if their weights were successfully synced long ago.

This means on every server restart:
- `snapshot_download` + `upload_files` tasks are spawned for already-synced models (the cache-hit path returns quickly, but still issues a `list_files` I/O call per model).
- The init container (`add_hf_weights_init_container`) will continue to be added to every subsequent deployment of these endpoints, adding startup latency even when weights are present.

Consider either:
- adding a callback in `_on_task_done` (on success) that updates the DB to set `hf_weights_syncing: false`, or
- checking whether the weights are actually present (`list_files`) before re-triggering the sync.