From 5db6f311b520fe1cb6b61035cd3add0e680ffce1 Mon Sep 17 00:00:00 2001 From: Rakesh Paul Date: Sun, 5 Apr 2026 03:02:57 +0530 Subject: [PATCH] Update DGXCloudExecutor for improved workload management - Renamed app_id and app_secret to client_id and client_secret for clarity. - Introduced new methods for deleting workloads and checking workspace status. - Enhanced data movement functionality to use a tarball when within character limits, falling back to individual file deployment otherwise. - Updated RayCluster and RayJob to integrate DGXCloudExecutor and its corresponding classes. Fixes #478 Signed-off-by: Rakesh Paul --- nemo_run/core/execution/dgxcloud.py | 286 +++++++++++++++------- nemo_run/run/ray/cluster.py | 3 + nemo_run/run/ray/dgxcloud.py | 362 ++++++++++++++++++++++++++++ nemo_run/run/ray/job.py | 3 + 4 files changed, 572 insertions(+), 82 deletions(-) create mode 100644 nemo_run/run/ray/dgxcloud.py diff --git a/nemo_run/core/execution/dgxcloud.py b/nemo_run/core/execution/dgxcloud.py index 24724503..27f04022 100644 --- a/nemo_run/core/execution/dgxcloud.py +++ b/nemo_run/core/execution/dgxcloud.py @@ -15,9 +15,11 @@ import base64 import glob +import gzip import json import logging import os +import shutil import subprocess import tempfile import time @@ -70,8 +72,8 @@ class DGXCloudExecutor(Executor): base_url: str kube_apiserver_url: str - app_id: str - app_secret: str + client_id: str + client_secret: str project_name: str container_image: str pvc_nemo_run_dir: str @@ -83,13 +85,14 @@ class DGXCloudExecutor(Executor): pvcs: list[dict[str, Any]] = field(default_factory=list) distributed_framework: str = "PyTorch" custom_spec: dict[str, Any] = field(default_factory=dict) + MAX_ARGS_CHARS: int = 9500 def get_auth_token(self) -> Optional[str]: url = f"{self.base_url}/token" payload = { - "grantType": "app_token", - "appId": self.app_id, - "appSecret": self.app_secret, + "grantType": "client_credentials", + "clientId": self.client_id, + "clientSecret": self.client_secret, } n_attempts = 0 @@ -138,18 +141,46 @@ def copy_directory_data_command(self, local_dir_path: str, dest_path: str) -> st cmd = f"rm -rf {dest_path} && mkdir -p {dest_path} && echo {encoded_data} | base64 -d > {dest_path}/archive.tar.gz && tar -xzf {dest_path}/archive.tar.gz -C {dest_path} && rm {dest_path}/archive.tar.gz" return cmd - def create_data_mover_workload(self, token: str, project_id: str, cluster_id: str): - """ - Creates a CPU only workload to move job directory into PVC using the provided project/cluster IDs. - """ + def delete_workload(self, token: str, workload_id: str): + url = f"{self.base_url}/workloads/workspaces/{workload_id}" + headers = self._default_headers(token=token) - cmd = self.copy_directory_data_command(self.job_dir, self.pvc_job_dir) + response = requests.delete(url, headers=headers) - url = f"{self.base_url}/workloads/workspaces" + logger.debug( + "Delete interactive workspace; response code=%s, content=%s", + response.status_code, + response.text.strip(), + ) + return response + + def _workspace_status(self, workload_id: str) -> Optional[DGXCloudState]: + """Query workspace-specific status endpoint for data-mover workloads.""" + url = f"{self.base_url}/workloads/workspaces/{workload_id}" + token = self.get_auth_token() + if not token: + return None headers = self._default_headers(token=token) + response = requests.get(url, headers=headers) + if response.status_code != 200: + return None + data = response.json() + phase = data.get("actualPhase") or data.get("phase") + return DGXCloudState(phase) if phase else None + def _run_workspace_and_wait( + self, + token: str, + project_id: str, + cluster_id: str, + name: str, + cmd: str, + sleep: float = 10, + timeout: int = 300, + ) -> None: + """Create a workspace workload, poll until done, then delete it.""" payload = { - "name": "data-mover", + "name": name, "useGivenNameAsPrefix": True, "projectId": project_id, "clusterId": cluster_id, @@ -160,76 +191,94 @@ def create_data_mover_workload(self, token: str, project_id: str, cluster_id: st "storage": {"pvc": self.pvcs}, }, } - - response = requests.post(url, json=payload, headers=headers) - - logger.debug( - "Created workload; response code=%s, content=%s", - response.status_code, - response.text.strip(), - ) - - return response - - def delete_workload(self, token: str, workload_id: str): - url = f"{self.base_url}/workloads/workspaces/{workload_id}" headers = self._default_headers(token=token) - - response = requests.delete(url, headers=headers) - - logger.debug( - "Delete interactive workspace; response code=%s, content=%s", - response.status_code, - response.text.strip(), - ) - return response + resp = requests.post(f"{self.base_url}/workloads/workspaces", json=payload, headers=headers) + if resp.status_code not in (200, 202): + raise RuntimeError(f"Workload '{name}' failed: {resp.status_code} {resp.text}") + wid = resp.json()["workloadId"] + logger.info(" workload %s (%s) created", name, wid[:12]) + + elapsed = 0 + while elapsed < timeout: + time.sleep(sleep) + elapsed += sleep + status = self._workspace_status(wid) + if status == DGXCloudState.COMPLETED: + self.delete_workload(token, wid) + return + if status in (DGXCloudState.FAILED, DGXCloudState.STOPPED, DGXCloudState.DEGRADED): + self.delete_workload(token, wid) + raise RuntimeError(f"Workload {wid} ended with: {status}") + raise RuntimeError(f"Workload {wid} timed out after {timeout}s") def move_data(self, token: str, project_id: str, cluster_id: str, sleep: float = 10) -> None: - """ - Moves job directory into PVC and deletes the workload after completion - """ - - resp = self.create_data_mover_workload(token, project_id, cluster_id) - if resp.status_code not in [200, 202]: - raise RuntimeError( - f"Failed to create data mover workload, status_code={resp.status_code}, reason={resp.text}" - ) - - resp_json = resp.json() - workload_id = resp_json["workloadId"] - status = DGXCloudState(resp_json["actualPhase"]) + """Move job directory into PVC. - logger.info(f"Successfully created data movement workload {workload_id} on DGXCloud") - - while status in [ - DGXCloudState.PENDING, - DGXCloudState.CREATING, - DGXCloudState.INITIALIZING, - DGXCloudState.RUNNING, - ]: - time.sleep(sleep) - status = self.status(workload_id) - logger.debug( - f"Polling data movement workload {workload_id}'s status. Current status is: {status}" - ) + Uses the fast single-command tarball when it fits within the API's + 10 000-char limit. Falls back to per-file deployment otherwise. + """ + cmd = self.copy_directory_data_command(self.job_dir, self.pvc_job_dir) - if status is not DGXCloudState.COMPLETED: - raise RuntimeError(f"Failed to move data to PVC. Workload status is {status}") + if len(cmd) <= self.MAX_ARGS_CHARS: + self._run_workspace_and_wait(token, project_id, cluster_id, "data-mover", cmd, sleep) + return - resp = self.delete_workload(token, workload_id) - if resp.status_code >= 200 and resp.status_code < 300: - logger.info( - "Successfully deleted data movement workload %s on DGXCloud with response code %d", - workload_id, - resp.status_code, - ) - else: - logger.error( - "Failed to delete data movement workload %s, response code=%d, reason=%s", - workload_id, - resp.status_code, - resp.text, - ) + logger.info( + "Tarball is %d chars (limit %d), deploying files individually", + len(cmd), + self.MAX_ARGS_CHARS, + ) + for root, _, filenames in os.walk(self.job_dir): + for fn in filenames: + if fn.endswith(".tar.gz"): + continue + abs_path = os.path.join(root, fn) + rel_path = os.path.relpath(abs_path, self.job_dir) + dest = os.path.join(self.pvc_job_dir, rel_path) + with open(abs_path, "rb") as f: + data = f.read() + + compressed = gzip.compress(data, compresslevel=9) + encoded = base64.b64encode(compressed).decode() + overhead = len(f"mkdir -p $(dirname {dest}) && echo | base64 -d | gunzip > {dest}") + chunk_b64_limit = self.MAX_ARGS_CHARS - overhead - 50 + + if len(encoded) <= chunk_b64_limit: + file_cmd = f"mkdir -p $(dirname {dest}) && echo {encoded} | base64 -d | gunzip > {dest}" + logger.info( + " deploying %s (%d→%d bytes)", rel_path, len(data), len(compressed) + ) + self._run_workspace_and_wait( + token, project_id, cluster_id, "data-mover", file_cmd, sleep + ) + else: + chunk_size = (chunk_b64_limit * 3) // 4 + raw_chunks = [ + compressed[i : i + chunk_size] + for i in range(0, len(compressed), chunk_size) + ] + logger.info( + " deploying %s in %d chunks (%d→%d bytes)", + rel_path, + len(raw_chunks), + len(data), + len(compressed), + ) + for ci, chunk in enumerate(raw_chunks): + b64 = base64.b64encode(chunk).decode() + if ci == 0: + file_cmd = ( + f"mkdir -p $(dirname {dest}) && echo {b64} | base64 -d > {dest}.gz" + ) + else: + file_cmd = f"echo {b64} | base64 -d >> {dest}.gz" + self._run_workspace_and_wait( + token, project_id, cluster_id, "data-mover", file_cmd, sleep + ) + gunzip_cmd = f"gunzip -f {dest}.gz" + self._run_workspace_and_wait( + token, project_id, cluster_id, "data-mover", gunzip_cmd, sleep + ) def create_training_job( self, token: str, project_id: str, cluster_id: str, name: str @@ -272,7 +321,7 @@ def create_training_job( common_spec = { "command": f"/bin/bash {self.pvc_job_dir}/launch_script.sh", "image": self.container_image, - "compute": {"gpuDevicesRequest": self.gpus_per_node}, + "compute": {"gpuDevicesRequest": self.gpus_per_node, "largeShmRequest": True}, "storage": {"pvc": self.pvcs}, "environmentVariables": [ {"name": key, "value": value} for key, value in self.env_vars.items() @@ -321,6 +370,17 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]: if not project_id or not cluster_id: raise RuntimeError("Unable to determine project/cluster IDs for job submission") + # Copy experiment-level files referenced in cmd into job_dir + # so they are included in the data mover transfer to the PVC + cmd_str = " ".join(cmd) + for fname in os.listdir(self.experiment_dir): + fpath = os.path.join(self.experiment_dir, fname) + if os.path.isfile(fpath) and fpath in cmd_str: + shutil.copy2(fpath, os.path.join(self.job_dir, fname)) + + # Rewrite local paths in cmd to point to the PVC job directory + cmd = [c.replace(self.experiment_dir, self.pvc_job_dir) for c in cmd] + # prepare launch script and move data to PVC launch_script = f""" ln -s {self.pvc_job_dir}/ /nemo_run @@ -390,9 +450,37 @@ def fetch_logs( stderr: Optional[bool] = None, stdout: Optional[bool] = None, ) -> Iterable[str]: - while self.status(job_id) != DGXCloudState.RUNNING: - logger.info("Waiting for job to start...") + state = self.status(job_id) + while state != DGXCloudState.RUNNING: + logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown") + if state in ( + DGXCloudState.COMPLETED, + DGXCloudState.FAILED, + DGXCloudState.STOPPED, + DGXCloudState.DEGRADED, + ): + logger.warning("Job reached terminal state %s before logs were available", state) + return time.sleep(15) + state = self.status(job_id) + + if not self.launched_from_cluster: + logger.info("Job %s is RUNNING. Logs are available in the Run:AI UI.", job_id[:12]) + terminal = ( + DGXCloudState.COMPLETED, + DGXCloudState.FAILED, + DGXCloudState.STOPPED, + DGXCloudState.DEGRADED, + ) + while True: + time.sleep(30) + state = self.status(job_id) + logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown") + if state in terminal: + yield f"Job finished with status: {state.value}" + return + + logger.info("Job %s is RUNNING, waiting for log files...", job_id[:12]) cmd = ["tail"] @@ -405,12 +493,21 @@ def fetch_logs( self.pvc_job_dir = os.path.join(self.pvc_nemo_run_dir, job_subdir) files = [] + poll_count = 0 while len(files) < self.nodes: files = list(glob.glob(f"{self.pvc_job_dir}/log_*.out")) files = [f for f in files if "log-allranks_0" not in f] - logger.info( - f"Waiting for {self.nodes + 1 - len(files)} log files to be created in {self.pvc_job_dir}..." - ) + if poll_count == 0 or poll_count % 10 == 0: + logger.info( + "Log files: %d/%d ready (watching %s)", + len(files), + self.nodes, + self.pvc_job_dir, + ) + poll_count += 1 + if poll_count > 100: + logger.warning("Timed out waiting for log files after 5 minutes") + return time.sleep(3) cmd.extend(files) @@ -526,6 +623,30 @@ def assign( ) self.experiment_id = exp_id + def deploy_script_to_pvc( + self, + script_content: str, + dest_path: str, + token: Optional[str] = None, + project_id: Optional[str] = None, + cluster_id: Optional[str] = None, + ) -> None: + """Write a script to the PVC via a short-lived busybox workspace.""" + if not token: + token = self.get_auth_token() + if not token: + raise RuntimeError("Failed to get auth token for script deployment") + if not project_id or not cluster_id: + project_id, cluster_id = self.get_project_and_cluster_id(token) + + encoded = base64.b64encode(gzip.compress(script_content.encode(), compresslevel=9)).decode() + cmd = ( + f"mkdir -p $(dirname {dest_path}) && " + f"echo {encoded} | base64 -d | gunzip > {dest_path} && " + f"chmod +x {dest_path}" + ) + self._run_workspace_and_wait(token, project_id, cluster_id, "script-deploy", cmd) + def get_launcher_prefix(self) -> Optional[list[str]]: launcher = self.get_launcher() if launcher.nsys_profile: @@ -574,6 +695,7 @@ def package(self, packager: Packager, job_name: str): ctx.run( f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True ) + os.remove(local_pkg) def macro_values(self) -> Optional[ExecutorMacros]: return None diff --git a/nemo_run/run/ray/cluster.py b/nemo_run/run/ray/cluster.py index a15bdc55..e4cea5fd 100644 --- a/nemo_run/run/ray/cluster.py +++ b/nemo_run/run/ray/cluster.py @@ -17,9 +17,11 @@ from typing import Optional, Type from nemo_run.core.execution.base import Executor +from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.core.frontend.console.api import configure_logging +from nemo_run.run.ray.dgxcloud import DGXCloudRayCluster from nemo_run.run.ray.lepton import LeptonRayCluster from nemo_run.run.ray.slurm import SlurmRayCluster @@ -46,6 +48,7 @@ def __post_init__(self): backend_map: dict[Type[Executor], Type] = { SlurmExecutor: SlurmRayCluster, LeptonExecutor: LeptonRayCluster, + DGXCloudExecutor: DGXCloudRayCluster, } if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayCluster is not None: diff --git a/nemo_run/run/ray/dgxcloud.py b/nemo_run/run/ray/dgxcloud.py new file mode 100644 index 00000000..450c26f7 --- /dev/null +++ b/nemo_run/run/ray/dgxcloud.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DGX Cloud Ray backend for NeMo Run. + +Ray orchestration on DGX Cloud works differently from Lepton or KubeRay +backends. There is no dedicated RayCluster CRD; instead a *distributed +workload* is submitted where every pod runs a bootstrap script that +self-organises into a Ray head + worker topology. Pod rank is derived +from the hostname suffix (``...-worker-N`` -> N); worker-0 becomes the +Ray head and writes its IP to a per-job file on the shared PVC so the +remaining workers can discover and join it. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Optional + +import requests + +from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Ray bootstrap template +# --------------------------------------------------------------------------- + +RAY_BOOTSTRAP_TEMPLATE = """#!/bin/bash +set -euo pipefail + +RAY_PORT=6379 +NUM_GPUS={gpus_per_node} +NUM_NODES={num_nodes} + +JOB_PREFIX=$(echo $HOSTNAME | sed 's/-worker-[0-9]*$//') +HEAD_IP_FILE="{head_ip_dir}/.ray_head_ip_$JOB_PREFIX" +DONE_FILE="{head_ip_dir}/.ray_done_$JOB_PREFIX" + +MY_RANK=$(echo $HOSTNAME | grep -oE '[0-9]+$') +MY_IP=$(hostname -i 2>/dev/null || echo "127.0.0.1") + +echo "Ray bootstrap: pod=$HOSTNAME rank=$MY_RANK/$NUM_NODES ip=$MY_IP" + +if [ "$MY_RANK" -eq 0 ]; then + rm -f $DONE_FILE + ray start --head --port=$RAY_PORT --num-gpus=$NUM_GPUS --dashboard-host=0.0.0.0 + + mkdir -p $(dirname $HEAD_IP_FILE) + echo "$MY_IP" > $HEAD_IP_FILE + export RAY_ADDRESS="$MY_IP:$RAY_PORT" + + echo "Waiting for $NUM_NODES Ray node(s)..." + for _i in $(seq 1 120); do + CONNECTED=$(python3 -c "import ray; ray.init(address='auto',ignore_reinit_error=True); print(len(ray.nodes()))" 2>/dev/null || echo 0) + [ "$CONNECTED" -ge "$NUM_NODES" ] && break + sleep 5 + done + echo "$CONNECTED/$NUM_NODES Ray nodes connected." + + {training_command} + EXIT_CODE=$? + + echo "$EXIT_CODE" > $DONE_FILE + ray stop + rm -f $HEAD_IP_FILE + exit $EXIT_CODE +else + echo "Waiting for Ray head IP at $HEAD_IP_FILE..." + for _w in $(seq 1 120); do + [ -f "$HEAD_IP_FILE" ] && break + sleep 3 + done + [ ! -f "$HEAD_IP_FILE" ] && echo "ERROR: head IP not found" && exit 1 + + HEAD_IP=$(cat $HEAD_IP_FILE) + echo "Joining Ray head at $HEAD_IP:$RAY_PORT" + ray start --address="$HEAD_IP:$RAY_PORT" --num-gpus=$NUM_GPUS + + echo "Worker $MY_RANK running. Waiting for job completion..." + while [ ! -f "$DONE_FILE" ]; do + sleep 10 + done + EXIT_CODE=$(cat $DONE_FILE) + echo "Head signaled completion (exit=$EXIT_CODE). Shutting down." + ray stop + exit $EXIT_CODE +fi +""" + + +def build_ray_bootstrap_script( + training_command: str, + gpus_per_node: int, + num_nodes: int, + head_ip_dir: str = "/workspace/nemo_run", +) -> str: + """Generate a bash script that bootstraps Ray across Run:AI distributed pods. + + Pod rank is derived from the hostname suffix (``...-worker-N`` -> N). + Worker-0 starts the Ray head and writes its IP to a per-job file on the + shared PVC; other workers read the IP and join. The file name includes + the job prefix from the hostname so concurrent runs don't collide. + """ + return RAY_BOOTSTRAP_TEMPLATE.format( + gpus_per_node=gpus_per_node, + num_nodes=num_nodes, + training_command=training_command, + head_ip_dir=head_ip_dir, + ) + + +# --------------------------------------------------------------------------- +# DGXCloudRayCluster +# --------------------------------------------------------------------------- + + +@dataclass(kw_only=True) +class DGXCloudRayCluster: + """Placeholder cluster for DGX Cloud. + + On DGX Cloud the Ray cluster is bootstrapped *inside* the distributed + workload itself (no separate CRD). The ``create`` / ``delete`` methods + are intentional no-ops; all real work happens in ``DGXCloudRayJob``. + """ + + EXECUTOR_CLS = DGXCloudExecutor + + name: str + executor: DGXCloudExecutor + + def create( + self, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ) -> None: + logger.info( + "DGXCloudRayCluster.create() is a no-op; " + "Ray is bootstrapped inside the distributed workload." + ) + + def wait_until_running(self, timeout: int = 600) -> bool: + logger.info("DGXCloudRayCluster.wait_until_running() is a no-op.") + return True + + def status(self, display: bool = False) -> dict[str, Any]: + return {"state": "Implicit", "cluster_name": self.name, "ray_ready": True} + + def port_forward(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("Port forwarding is not supported for DGXCloudRayCluster.") + + def stop_forwarding(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError("Port forwarding is not supported for DGXCloudRayCluster.") + + def delete(self, wait: bool = False, **kwargs: Any) -> bool: + logger.info("DGXCloudRayCluster.delete() is a no-op.") + return True + + +# --------------------------------------------------------------------------- +# DGXCloudRayJob +# --------------------------------------------------------------------------- + + +@dataclass(kw_only=True) +class DGXCloudRayJob: + """Submit and monitor a Ray job on DGX Cloud via the Run:AI REST API. + + Instead of using NeMo Run's ``DGXCloudExecutor.launch()`` (which + assumes Torchrun), this class: + + 1. Builds a Ray bootstrap script from the training command. + 2. Deploys the script to the PVC via a short-lived workspace workload. + 3. Creates a *distributed* workload where every pod runs the bootstrap. + 4. Polls until the workload reaches a terminal state. + """ + + name: str + executor: DGXCloudExecutor + workload_id: Optional[str] = None + poll_interval: int = 30 + _token: Optional[str] = field(init=False, default=None, repr=False) + _project_id: Optional[str] = field(init=False, default=None, repr=False) + _cluster_id: Optional[str] = field(init=False, default=None, repr=False) + + def _ensure_auth(self) -> None: + if not self._token: + self._token = self.executor.get_auth_token() + if not self._token: + raise RuntimeError("Failed to get auth token") + if not self._project_id or not self._cluster_id: + self._project_id, self._cluster_id = self.executor.get_project_and_cluster_id( + self._token + ) + if not self._project_id or not self._cluster_id: + raise RuntimeError("Unable to determine project/cluster IDs") + + # ------------------------------------------------------------------ + # Public API (matches the interface expected by RayJob) + # ------------------------------------------------------------------ + + def start( + self, + command: str, + workdir: str, + runtime_env_yaml: Optional[str] = None, + pre_ray_start_commands: Optional[list[str]] = None, + dryrun: bool = False, + ) -> Optional[str]: + """Build a Ray bootstrap script and submit a distributed workload.""" + self._ensure_auth() + ex = self.executor + + job_name = self.name.replace("_", "-").replace(".", "-").lower() + if len(job_name) > 35: + logger.warning("Job name exceeds 35 characters, truncating.") + job_name = job_name[:34] + + ray_script = build_ray_bootstrap_script( + training_command=command, + gpus_per_node=ex.gpus_per_node, + num_nodes=ex.nodes, + head_ip_dir=ex.pvc_nemo_run_dir, + ) + + if dryrun: + logger.info("Dry run — Ray bootstrap script:\n%s", ray_script) + return None + + script_pvc_path = f"{ex.pvc_nemo_run_dir}/ray_bootstrap_{job_name}.sh" + logger.info("Deploying Ray bootstrap script to %s", script_pvc_path) + ex.deploy_script_to_pvc( + script_content=ray_script, + dest_path=script_pvc_path, + token=self._token, + project_id=self._project_id, + cluster_id=self._cluster_id, + ) + + logger.info("Submitting distributed workload for Ray job '%s'", job_name) + payload = { + "name": job_name, + "useGivenNameAsPrefix": True, + "projectId": self._project_id, + "clusterId": self._cluster_id, + "spec": { + "command": f"/bin/bash {script_pvc_path}", + "image": ex.container_image, + "compute": { + "gpuDevicesRequest": ex.gpus_per_node, + "largeShmRequest": True, + }, + "storage": {"pvc": ex.pvcs}, + "environmentVariables": [{"name": k, "value": v} for k, v in ex.env_vars.items()], + "distributedFramework": ex.distributed_framework, + "minReplicas": ex.nodes, + "maxReplicas": ex.nodes, + "numWorkers": ex.nodes, + **ex.custom_spec, + }, + } + + headers = ex._default_headers(token=self._token) + resp = requests.post(f"{ex.base_url}/workloads/distributed", json=payload, headers=headers) + if resp.status_code not in (200, 202): + raise RuntimeError( + f"Distributed workload creation failed: {resp.status_code} {resp.text}" + ) + self.workload_id = resp.json()["workloadId"] + logger.info("Ray job submitted — workload ID: %s", self.workload_id) + return self.workload_id + + def status(self, display: bool = True) -> Optional[dict[str, Any]]: + if not self.workload_id: + logger.warning("No workload ID; call start() first.") + return None + + state = self.executor.status(self.workload_id) + info = { + "workload_id": self.workload_id, + "state": state.value if state else "Unknown", + "name": self.name, + } + if display: + logger.info( + "Ray Job Status (DGX Cloud)\n" + " Name: %s\n" + " Workload ID: %s\n" + " State: %s", + self.name, + self.workload_id, + info["state"], + ) + return info + + def stop(self, wait: bool = False, **kwargs: Any) -> None: + if not self.workload_id: + logger.warning("No workload ID to cancel.") + return + self.executor.cancel(self.workload_id) + if wait: + terminal = { + DGXCloudState.COMPLETED, + DGXCloudState.FAILED, + DGXCloudState.STOPPED, + DGXCloudState.DEGRADED, + } + for _ in range(60): + state = self.executor.status(self.workload_id) + if state in terminal: + logger.info("Workload %s reached %s", self.workload_id, state) + return + time.sleep(5) + logger.warning("Timed out waiting for workload %s to stop", self.workload_id) + + def logs(self, follow: bool = False, **kwargs: Any) -> None: + if not self.workload_id: + logger.warning("No workload ID; call start() first.") + return + for line in self.executor.fetch_logs(self.workload_id, stream=follow): + print(line, end="" if follow else "\n") + + def wait(self, poll: int | None = None) -> str: + """Block until the distributed workload reaches a terminal state. + + Returns the final phase as a string. + """ + if not self.workload_id: + raise RuntimeError("No workload ID; call start() first.") + + poll = poll or self.poll_interval + terminal = { + DGXCloudState.COMPLETED, + DGXCloudState.FAILED, + DGXCloudState.STOPPED, + DGXCloudState.DEGRADED, + } + while True: + time.sleep(poll) + state = self.executor.status(self.workload_id) + if state: + logger.info("Ray job %s — status: %s", self.name, state.value) + if state in terminal: + if state != DGXCloudState.COMPLETED: + raise RuntimeError(f"Ray job '{self.name}' ended with phase: {state.value}") + return state.value diff --git a/nemo_run/run/ray/job.py b/nemo_run/run/ray/job.py index 8c608a3f..9105e9ca 100644 --- a/nemo_run/run/ray/job.py +++ b/nemo_run/run/ray/job.py @@ -17,9 +17,11 @@ from typing import Any, Optional, Type from nemo_run.core.execution.base import Executor +from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.slurm import SlurmExecutor from nemo_run.core.frontend.console.api import configure_logging +from nemo_run.run.ray.dgxcloud import DGXCloudRayJob from nemo_run.run.ray.lepton import LeptonRayJob from nemo_run.run.ray.slurm import SlurmRayJob @@ -51,6 +53,7 @@ def __post_init__(self) -> None: # noqa: D401 – simple implementation backend_map: dict[Type[Executor], Type[Any]] = { LeptonExecutor: LeptonRayJob, SlurmExecutor: SlurmRayJob, + DGXCloudExecutor: DGXCloudRayJob, } if _KUBERAY_AVAILABLE and KubeRayExecutor is not None and KubeRayJob is not None: