Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 204 additions & 82 deletions nemo_run/core/execution/dgxcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import base64
import glob
import gzip
import json
import logging
import os
import shutil
import subprocess
import tempfile
import time
Expand Down Expand Up @@ -70,8 +72,8 @@ class DGXCloudExecutor(Executor):

base_url: str
kube_apiserver_url: str
app_id: str
app_secret: str
client_id: str
client_secret: str
project_name: str
container_image: str
pvc_nemo_run_dir: str
Expand All @@ -83,13 +85,14 @@ class DGXCloudExecutor(Executor):
pvcs: list[dict[str, Any]] = field(default_factory=list)
distributed_framework: str = "PyTorch"
custom_spec: dict[str, Any] = field(default_factory=dict)
MAX_ARGS_CHARS: int = 9500

def get_auth_token(self) -> Optional[str]:
url = f"{self.base_url}/token"
payload = {
"grantType": "app_token",
"appId": self.app_id,
"appSecret": self.app_secret,
"grantType": "client_credentials",
"clientId": self.client_id,
"clientSecret": self.client_secret,
}

n_attempts = 0
Expand Down Expand Up @@ -138,18 +141,46 @@ def copy_directory_data_command(self, local_dir_path: str, dest_path: str) -> st
cmd = f"rm -rf {dest_path} && mkdir -p {dest_path} && echo {encoded_data} | base64 -d > {dest_path}/archive.tar.gz && tar -xzf {dest_path}/archive.tar.gz -C {dest_path} && rm {dest_path}/archive.tar.gz"
return cmd

def create_data_mover_workload(self, token: str, project_id: str, cluster_id: str):
"""
Creates a CPU only workload to move job directory into PVC using the provided project/cluster IDs.
"""
def delete_workload(self, token: str, workload_id: str):
    """Delete a workspace workload on DGX Cloud and return the raw HTTP response.

    Args:
        token: Bearer token used to authenticate the API call.
        workload_id: ID of the workspace workload to remove.

    Returns:
        The ``requests.Response`` from the DELETE call; callers inspect
        ``status_code`` to decide whether the deletion succeeded.
    """
    endpoint = f"{self.base_url}/workloads/workspaces/{workload_id}"
    resp = requests.delete(endpoint, headers=self._default_headers(token=token))

    logger.debug(
        "Delete interactive workspace; response code=%s, content=%s",
        resp.status_code,
        resp.text.strip(),
    )
    return resp

def _workspace_status(self, workload_id: str) -> Optional[DGXCloudState]:
    """Query workspace-specific status endpoint for data-mover workloads.

    Returns ``None`` when no auth token is available, the request is not
    a 200, or the response carries no phase field; otherwise the phase
    parsed into a ``DGXCloudState``.
    """
    token = self.get_auth_token()
    if not token:
        return None

    resp = requests.get(
        f"{self.base_url}/workloads/workspaces/{workload_id}",
        headers=self._default_headers(token=token),
    )
    if resp.status_code != 200:
        return None

    # Some responses use "actualPhase", others "phase"; accept either.
    body = resp.json()
    phase = body.get("actualPhase") or body.get("phase")
    if not phase:
        return None
    return DGXCloudState(phase)

def _run_workspace_and_wait(
self,
token: str,
project_id: str,
cluster_id: str,
name: str,
cmd: str,
sleep: float = 10,
timeout: int = 300,
) -> None:
"""Create a workspace workload, poll until done, then delete it."""
payload = {
"name": "data-mover",
"name": name,
"useGivenNameAsPrefix": True,
"projectId": project_id,
"clusterId": cluster_id,
Expand All @@ -160,76 +191,94 @@ def create_data_mover_workload(self, token: str, project_id: str, cluster_id: st
"storage": {"pvc": self.pvcs},
},
}

response = requests.post(url, json=payload, headers=headers)

logger.debug(
"Created workload; response code=%s, content=%s",
response.status_code,
response.text.strip(),
)

return response

def delete_workload(self, token: str, workload_id: str):
url = f"{self.base_url}/workloads/workspaces/{workload_id}"
headers = self._default_headers(token=token)

response = requests.delete(url, headers=headers)

logger.debug(
"Delete interactive workspace; response code=%s, content=%s",
response.status_code,
response.text.strip(),
)
return response
resp = requests.post(f"{self.base_url}/workloads/workspaces", json=payload, headers=headers)
if resp.status_code not in (200, 202):
raise RuntimeError(f"Workload '{name}' failed: {resp.status_code} {resp.text}")
wid = resp.json()["workloadId"]
logger.info(" workload %s (%s) created", name, wid[:12])

elapsed = 0
while elapsed < timeout:
time.sleep(sleep)
elapsed += sleep
status = self._workspace_status(wid)
if status == DGXCloudState.COMPLETED:
self.delete_workload(token, wid)
return
if status in (DGXCloudState.FAILED, DGXCloudState.STOPPED, DGXCloudState.DEGRADED):
self.delete_workload(token, wid)
raise RuntimeError(f"Workload {wid} ended with: {status}")
raise RuntimeError(f"Workload {wid} timed out after {timeout}s")

def move_data(self, token: str, project_id: str, cluster_id: str, sleep: float = 10) -> None:
    """Move job directory into PVC.

    Uses the fast single-command tarball when it fits within the API's
    10 000-char limit. Falls back to per-file deployment otherwise.

    Args:
        token: Auth token for the DGX Cloud API.
        project_id: Project to create the data-mover workloads in.
        cluster_id: Cluster to create the data-mover workloads on.
        sleep: Seconds between status polls of each data-mover workload.

    Raises:
        RuntimeError: Propagated from ``_run_workspace_and_wait`` when a
            data-mover workload fails, stops, degrades, or times out.
    """
    cmd = self.copy_directory_data_command(self.job_dir, self.pvc_job_dir)

    # Fast path: the whole directory ships as a single base64 tarball command.
    if len(cmd) <= self.MAX_ARGS_CHARS:
        self._run_workspace_and_wait(token, project_id, cluster_id, "data-mover", cmd, sleep)
        return

    logger.info(
        "Tarball is %d chars (limit %d), deploying files individually",
        len(cmd),
        self.MAX_ARGS_CHARS,
    )
    for root, _, filenames in os.walk(self.job_dir):
        for fn in filenames:
            # Skip archives left over from the tarball fast path.
            if fn.endswith(".tar.gz"):
                continue
            abs_path = os.path.join(root, fn)
            rel_path = os.path.relpath(abs_path, self.job_dir)
            dest = os.path.join(self.pvc_job_dir, rel_path)
            with open(abs_path, "rb") as f:
                data = f.read()

            compressed = gzip.compress(data, compresslevel=9)
            encoded = base64.b64encode(compressed).decode()
            # Budget for the base64 payload after subtracting the shell
            # scaffolding around it, plus a 50-char safety margin.
            overhead = len(f"mkdir -p $(dirname {dest}) && echo | base64 -d | gunzip > {dest}")
            chunk_b64_limit = self.MAX_ARGS_CHARS - overhead - 50

            if len(encoded) <= chunk_b64_limit:
                # Whole file fits in one workload command.
                file_cmd = f"mkdir -p $(dirname {dest}) && echo {encoded} | base64 -d | gunzip > {dest}"
                logger.info(
                    " deploying %s (%d→%d bytes)", rel_path, len(data), len(compressed)
                )
                self._run_workspace_and_wait(
                    token, project_id, cluster_id, "data-mover", file_cmd, sleep
                )
            else:
                # Split the compressed bytes so each chunk's base64 stays
                # under the limit (3 raw bytes encode to 4 base64 chars).
                chunk_size = (chunk_b64_limit * 3) // 4
                raw_chunks = [
                    compressed[i : i + chunk_size]
                    for i in range(0, len(compressed), chunk_size)
                ]
                logger.info(
                    " deploying %s in %d chunks (%d→%d bytes)",
                    rel_path,
                    len(raw_chunks),
                    len(data),
                    len(compressed),
                )
                for ci, chunk in enumerate(raw_chunks):
                    b64 = base64.b64encode(chunk).decode()
                    if ci == 0:
                        # First chunk creates the destination .gz file.
                        file_cmd = (
                            f"mkdir -p $(dirname {dest}) && echo {b64} | base64 -d > {dest}.gz"
                        )
                    else:
                        # Subsequent chunks append to it in order.
                        file_cmd = f"echo {b64} | base64 -d >> {dest}.gz"
                    self._run_workspace_and_wait(
                        token, project_id, cluster_id, "data-mover", file_cmd, sleep
                    )
                # Final workload decompresses the reassembled file in place.
                gunzip_cmd = f"gunzip -f {dest}.gz"
                self._run_workspace_and_wait(
                    token, project_id, cluster_id, "data-mover", gunzip_cmd, sleep
                )

def create_training_job(
self, token: str, project_id: str, cluster_id: str, name: str
Expand Down Expand Up @@ -272,7 +321,7 @@ def create_training_job(
common_spec = {
"command": f"/bin/bash {self.pvc_job_dir}/launch_script.sh",
"image": self.container_image,
"compute": {"gpuDevicesRequest": self.gpus_per_node},
"compute": {"gpuDevicesRequest": self.gpus_per_node, "largeShmRequest": True},
"storage": {"pvc": self.pvcs},
"environmentVariables": [
{"name": key, "value": value} for key, value in self.env_vars.items()
Expand Down Expand Up @@ -321,6 +370,17 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
if not project_id or not cluster_id:
raise RuntimeError("Unable to determine project/cluster IDs for job submission")

# Copy experiment-level files referenced in cmd into job_dir
# so they are included in the data mover transfer to the PVC
cmd_str = " ".join(cmd)
for fname in os.listdir(self.experiment_dir):
fpath = os.path.join(self.experiment_dir, fname)
if os.path.isfile(fpath) and fpath in cmd_str:
shutil.copy2(fpath, os.path.join(self.job_dir, fname))

# Rewrite local paths in cmd to point to the PVC job directory
cmd = [c.replace(self.experiment_dir, self.pvc_job_dir) for c in cmd]

# prepare launch script and move data to PVC
launch_script = f"""
ln -s {self.pvc_job_dir}/ /nemo_run
Expand Down Expand Up @@ -390,9 +450,37 @@ def fetch_logs(
stderr: Optional[bool] = None,
stdout: Optional[bool] = None,
) -> Iterable[str]:
while self.status(job_id) != DGXCloudState.RUNNING:
logger.info("Waiting for job to start...")
state = self.status(job_id)
while state != DGXCloudState.RUNNING:
logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown")
if state in (
DGXCloudState.COMPLETED,
DGXCloudState.FAILED,
DGXCloudState.STOPPED,
DGXCloudState.DEGRADED,
):
logger.warning("Job reached terminal state %s before logs were available", state)
return
time.sleep(15)
state = self.status(job_id)

if not self.launched_from_cluster:
logger.info("Job %s is RUNNING. Logs are available in the Run:AI UI.", job_id[:12])
terminal = (
DGXCloudState.COMPLETED,
DGXCloudState.FAILED,
DGXCloudState.STOPPED,
DGXCloudState.DEGRADED,
)
while True:
time.sleep(30)
state = self.status(job_id)
logger.info("Job %s — status: %s", job_id[:12], state.value if state else "Unknown")
if state in terminal:
yield f"Job finished with status: {state.value}"
return

logger.info("Job %s is RUNNING, waiting for log files...", job_id[:12])

cmd = ["tail"]

Expand All @@ -405,12 +493,21 @@ def fetch_logs(
self.pvc_job_dir = os.path.join(self.pvc_nemo_run_dir, job_subdir)

files = []
poll_count = 0
while len(files) < self.nodes:
files = list(glob.glob(f"{self.pvc_job_dir}/log_*.out"))
files = [f for f in files if "log-allranks_0" not in f]
logger.info(
f"Waiting for {self.nodes + 1 - len(files)} log files to be created in {self.pvc_job_dir}..."
)
if poll_count == 0 or poll_count % 10 == 0:
logger.info(
"Log files: %d/%d ready (watching %s)",
len(files),
self.nodes,
self.pvc_job_dir,
)
poll_count += 1
if poll_count > 100:
logger.warning("Timed out waiting for log files after 5 minutes")
return
time.sleep(3)

cmd.extend(files)
Expand Down Expand Up @@ -526,6 +623,30 @@ def assign(
)
self.experiment_id = exp_id

def deploy_script_to_pvc(
    self,
    script_content: str,
    dest_path: str,
    token: Optional[str] = None,
    project_id: Optional[str] = None,
    cluster_id: Optional[str] = None,
) -> None:
    """Write a script to the PVC via a short-lived busybox workspace.

    Args:
        script_content: Full text of the script to place on the PVC.
        dest_path: Absolute PVC path the script is written to (made executable).
        token: Optional pre-fetched auth token; fetched on demand if missing.
        project_id: Optional project ID; resolved with the cluster ID if missing.
        cluster_id: Optional cluster ID; resolved with the project ID if missing.

    Raises:
        RuntimeError: If no auth token can be obtained.
    """
    if not token:
        token = self.get_auth_token()
        if not token:
            raise RuntimeError("Failed to get auth token for script deployment")
    if not project_id or not cluster_id:
        project_id, cluster_id = self.get_project_and_cluster_id(token)

    # Compress + base64 the script so it survives transport as a shell arg.
    payload = gzip.compress(script_content.encode(), compresslevel=9)
    encoded = base64.b64encode(payload).decode()
    deploy_cmd = f"mkdir -p $(dirname {dest_path}) && echo {encoded} | base64 -d | gunzip > {dest_path} && chmod +x {dest_path}"
    self._run_workspace_and_wait(token, project_id, cluster_id, "script-deploy", deploy_cmd)

def get_launcher_prefix(self) -> Optional[list[str]]:
launcher = self.get_launcher()
if launcher.nsys_profile:
Expand Down Expand Up @@ -574,6 +695,7 @@ def package(self, packager: Packager, job_name: str):
ctx.run(
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
)
os.remove(local_pkg)

def macro_values(self) -> Optional[ExecutorMacros]:
return None
Expand Down
Loading
Loading