diff --git a/pathwaysutils/experimental/gke/jobset.py b/pathwaysutils/experimental/gke/jobset.py index 17345f3..a29441f 100644 --- a/pathwaysutils/experimental/gke/jobset.py +++ b/pathwaysutils/experimental/gke/jobset.py @@ -9,14 +9,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pathways JobSet generator and builder (Skeleton).""" +"""Pathways JobSet generator and builder (Head Job Config).""" + +import json +import logging from typing import Any, Mapping from kubernetes import client +# GKE sidecar containers restartPolicy compatibility placeholder. + +_logger = logging.getLogger(__name__) + # Core constants. PATHWAYS_HEAD_JOB_NAME = "pathways-head" PATHWAYS_WORKER_JOB_NAME = "pathways-worker" +DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE = ( + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server" +) +DEFAULT_PATHWAYS_PROXY_IMAGE = ( + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" +) + +PATHWAYS_PROXY_PORT = 29000 +PATHWAYS_RM_PORT = 29001 + MACHINE_TYPE_TO_TPU_VERSION_MAP = { "tpu7x-standard-4t": "tpu7x", "tpu7x": "tpu7x", @@ -48,18 +65,34 @@ } +def _deserialize_dict( + api_client: client.ApiClient, data_dict: Mapping[str, Any], klass: Any +) -> Any: + class FakeResponse: + + def __init__(self, data): + self.data = data + + return api_client.deserialize(FakeResponse(json.dumps(data_dict)), klass) + + class PathwaysJobSet: - """Generates JobSet configuration for Pathways (Skeleton).""" + """Generates JobSet configuration for Pathways (with Head Job Config).""" def __init__( self, name: str, namespace: str, + pathways_dir: str, tpu_type: str, + topology: str, num_slices: int, user_pod_template: Mapping[str, Any] | None = None, + main_container_name: str = "main", max_restarts: int = 0, + pathways_version: str = "latest", jobset_api_version: str = "v1alpha2", + elastic_slices: int = 0, labels: Mapping[str, str] | None = None, annotations: Mapping[str, str] | None = None, ): @@ -68,11 +101,16 @@ def __init__( Args: name: Name of the JobSet. namespace: Namespace of the JobSet. + pathways_dir: GCS path for Pathways scratch space. tpu_type: TPU type (e.g., "v5e"). + topology: TPU topology (e.g., "2x2"). num_slices: Number of slices. user_pod_template: Optional user pod template for the head job. + main_container_name: Name of the main container in user_pod_template. max_restarts: Maximum number of restarts for the JobSet. + pathways_version: Version tag for Pathways images. jobset_api_version: API version of JobSet. + elastic_slices: Number of elastic slices. labels: Optional labels for the JobSet. annotations: Optional annotations for the JobSet. """ @@ -88,8 +126,19 @@ def __init__( if not tpu_version: raise ValueError(f"Unsupported TPU type: {tpu_type}") - # Build minimal head template (placeholder) - self._head_job_template = self._build_minimal_job_template("head") + instance_type = f"{tpu_version}:{topology}" + image_tag = pathways_version + + # Build head template. + self._head_job_template = self._build_head_job_template( + pathways_dir=pathways_dir, + num_slices=num_slices, + instance_type=instance_type, + image_tag=image_tag, + user_pod_template=user_pod_template, + main_container_name=main_container_name, + elastic_slices=elastic_slices, + ) # Build minimal worker template (placeholder) self._worker_job_template = self._build_minimal_job_template("worker") @@ -115,6 +164,207 @@ def _build_minimal_job_template(self, role: str) -> client.V1JobTemplateSpec: ) return client.V1JobTemplateSpec(spec=job_spec) + def _build_head_job_template( + self, + pathways_dir: str, + num_slices: int, + instance_type: str, + image_tag: str, + user_pod_template: Mapping[str, Any] | None, + main_container_name: str, + elastic_slices: int, + ) -> client.V1JobTemplateSpec: + """Builds the head job template for the JobSet. + + Args: + pathways_dir: GCS path for Pathways scratch space. + num_slices: Number of slices. + instance_type: TPU instance type (e.g., "tpuv5:2x2"). + image_tag: Version tag for Pathways images. + user_pod_template: Optional user pod template for the head job. + main_container_name: Name of the main container in user_pod_template. + elastic_slices: Number of elastic slices. + + Returns: + The head job template. + """ + rm_image = f"{DEFAULT_PATHWAYS_RM_AND_WORKER_IMAGE}:{image_tag}" + proxy_image = f"{DEFAULT_PATHWAYS_PROXY_IMAGE}:{image_tag}" + + rm_args = [ + f"--server_port={PATHWAYS_RM_PORT}", + f"--gcs_scratch_location={pathways_dir}", + "--node_type=resource_manager", + f"--instance_count={num_slices}", + f"--instance_type={instance_type}", + ] + rm_env = [ + client.V1EnvVar( + name="REPLICATED_JOB_NAME", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path="metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']" + ) + ), + ), + client.V1EnvVar( + name="JOBSET_NAME", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.annotations['jobset.sigs.k8s.io/jobset-name']" + ) + ) + ), + ), + client.V1EnvVar( + name="HOST_ADDRESS", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.labels['jobset.sigs.k8s.io/coordinator']" + ) + ) + ), + ), + client.V1EnvVar(name="TPU_SKIP_MDS_QUERY", value="true"), + ] + rm_container = client.V1Container( + name="pathways-rm", + image=rm_image, + image_pull_policy="Always", + args=rm_args, + env=rm_env, + ports=[ + client.V1ContainerPort( + container_port=PATHWAYS_RM_PORT, protocol="TCP" + ), + client.V1ContainerPort(container_port=29002, protocol="TCP"), + ], + resources=client.V1ResourceRequirements( + limits={"cpu": "8", "memory": "32G"} + ), + ) + + proxy_args = [ + f"--server_port={PATHWAYS_PROXY_PORT}", + f"--resource_manager_address=$(PATHWAYS_HEAD):{PATHWAYS_RM_PORT}", + f"--gcs_scratch_location={pathways_dir}", + ] + if elastic_slices > 0: + proxy_args.append(f"--num_elastic_slices={elastic_slices}") + + proxy_env = [ + client.V1EnvVar( + name="PATHWAYS_HEAD", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.labels['jobset.sigs.k8s.io/coordinator']" + ) + ) + ), + ) + ] + proxy_container = client.V1Container( + name="pathways-proxy", + image=proxy_image, + image_pull_policy="Always", + args=proxy_args, + env=proxy_env, + ports=[ + client.V1ContainerPort( + container_port=PATHWAYS_PROXY_PORT, protocol="TCP" + ) + ], + resources=client.V1ResourceRequirements( + limits={"cpu": "16", "memory": "100G"} + ), + ) + + api_client = client.ApiClient() + + if user_pod_template: + user_template_obj = _deserialize_dict( + api_client, user_pod_template, client.V1PodTemplateSpec + ) + head_pod_spec = user_template_obj.spec + head_pod_spec.host_network = True + head_pod_spec.dns_policy = "ClusterFirstWithHostNet" + + rm_container.restart_policy = "Always" + proxy_container.restart_policy = "Always" + + init_containers = head_pod_spec.init_containers or [] + init_containers.extend([rm_container, proxy_container]) + head_pod_spec.init_containers = init_containers + + # Inject JAX env vars into main container. + jax_env = [ + client.V1EnvVar( + name="PATHWAYS_HEAD", + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=( + "metadata.labels['jobset.sigs.k8s.io/coordinator']" + ) + ) + ), + ), + client.V1EnvVar(name="JAX_PLATFORMS", value="proxy"), + client.V1EnvVar(name="XCLOUD_ENVIRONMENT", value="GCP"), + client.V1EnvVar( + name="JAX_BACKEND_TARGET", + value=f"grpc://$(PATHWAYS_HEAD):{PATHWAYS_PROXY_PORT}", + ), + ] + containers = head_pod_spec.containers or [] + for c in containers: + if c.name == main_container_name: + env = c.env or [] + env.extend(jax_env) + c.env = env + break + head_pod_spec.containers = containers + + annotations = user_pod_template.get("metadata", {}).get("annotations", {}) + labels = user_pod_template.get("metadata", {}).get("labels", {}) + else: + # Headless mode. + head_pod_spec = client.V1PodSpec( + host_network=True, + dns_policy="ClusterFirstWithHostNet", + containers=[rm_container, proxy_container], + ) + annotations = {} + labels = {} + + if not head_pod_spec.restart_policy: + head_pod_spec.restart_policy = "Never" + + # Default annotations + job_annotations = { + "alpha.jobset.sigs.k8s.io/exclusive-topology": "kubernetes.io/hostname" + } + job_annotations.update(annotations) + + head_job_template = client.V1JobTemplateSpec( + metadata=client.V1ObjectMeta(annotations=job_annotations), + spec=client.V1JobSpec( + backoff_limit=0, + completion_mode="Indexed", + completions=1, + parallelism=1, + template=client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta( + annotations=job_annotations, labels=labels + ), + spec=head_pod_spec, + ), + ), + ) + return head_job_template + def _compile_config(self) -> dict[str, Any]: """Compiles the JobSet configuration into a dictionary.""" with client.ApiClient() as api_client: diff --git a/pathwaysutils/test/experimental/gke/jobset_test.py b/pathwaysutils/test/experimental/gke/jobset_test.py index 2e4e197..2a1a92d 100644 --- a/pathwaysutils/test/experimental/gke/jobset_test.py +++ b/pathwaysutils/test/experimental/gke/jobset_test.py @@ -1,5 +1,6 @@ from absl.testing import absltest from absl.testing import parameterized +from kubernetes import client from pathwaysutils.experimental.gke import jobset @@ -10,51 +11,126 @@ def test_invalid_tpu_type(self): jobset.PathwaysJobSet( name="test-jobset", namespace="default", + pathways_dir="gs://test-bucket", tpu_type="invalid-tpu", + topology="4x4", num_slices=1, ) - def test_basic_jobset_structure(self): + def test_headless_head_job(self): js = jobset.PathwaysJobSet( name="test-jobset", namespace="default", + pathways_dir="gs://test-bucket", tpu_type="v5e", + topology="4x8", num_slices=2, - labels={"app": "pathways"}, - annotations={"example.com/annotation": "value"}, + elastic_slices=2, ) config = js.to_dict() - self.assertEqual(config["apiVersion"], "jobset.sigs.k8s.io/v1alpha2") - self.assertEqual(config["kind"], "JobSet") - self.assertEqual(config["metadata"]["name"], "test-jobset") - self.assertEqual(config["metadata"]["namespace"], "default") - self.assertEqual(config["metadata"]["labels"]["app"], "pathways") + replicated_jobs = config["spec"]["replicatedJobs"] + self.assertLen(replicated_jobs, 2) + + head_job = next(j for j in replicated_jobs if j["name"] == "pathways-head") + self.assertEqual(head_job["replicas"], 1) + + pod_spec = head_job["template"]["spec"]["template"]["spec"] + self.assertTrue(pod_spec["hostNetwork"]) + self.assertEqual(pod_spec["dnsPolicy"], "ClusterFirstWithHostNet") + self.assertEqual(pod_spec["restartPolicy"], "Never") + + # In headless mode, RM and Proxy are in containers list + containers = pod_spec["containers"] + self.assertLen(containers, 2) + rm_container = next(c for c in containers if c["name"] == "pathways-rm") + proxy_container = next( + c for c in containers if c["name"] == "pathways-proxy" + ) + + self.assertEqual( + rm_container["image"], + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest", + ) self.assertEqual( - config["metadata"]["annotations"]["example.com/annotation"], "value" + proxy_container["image"], + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest", ) + self.assertIn("--num_elastic_slices=2", proxy_container["args"]) - self.assertEqual(config["spec"]["failurePolicy"]["maxRestarts"], 0) + def test_non_headless_head_job(self): + user_pod_template = { + "metadata": {"annotations": {"example.com/annotation": "value"}}, + "spec": { + "containers": [{ + "name": "jax-tpu", + "image": "ubuntu:latest", + "command": ["sleep", "infinity"], + }] + }, + } + js = jobset.PathwaysJobSet( + name="test-jobset", + namespace="default", + pathways_dir="gs://test-bucket", + tpu_type="v5e", + topology="4x8", + num_slices=2, + user_pod_template=user_pod_template, + main_container_name="jax-tpu", + ) + config = js.to_dict() replicated_jobs = config["spec"]["replicatedJobs"] - self.assertLen(replicated_jobs, 2) + head_job = next(j for j in replicated_jobs if j["name"] == "pathways-head") - head_job = replicated_jobs[0] - self.assertEqual(head_job["name"], "pathways-head") - self.assertEqual(head_job["replicas"], 1) + pod_spec = head_job["template"]["spec"]["template"]["spec"] + self.assertTrue(pod_spec["hostNetwork"]) + self.assertEqual(pod_spec["dnsPolicy"], "ClusterFirstWithHostNet") - # In K8s API models, V1JobTemplateSpec -> V1JobSpec -> V1PodTemplateSpec - # -> V1PodSpec. When serialized, they match this structure. - head_pod_spec = head_job["template"]["spec"]["template"]["spec"] - self.assertEqual(head_pod_spec["containers"][0]["name"], "placeholder-head") + # RM and Proxy should be in initContainers + init_containers = pod_spec["initContainers"] + self.assertLen(init_containers, 2) + rm_container = next( + c for c in init_containers if c["name"] == "pathways-rm" + ) + proxy_container = next( + c for c in init_containers if c["name"] == "pathways-proxy" + ) - worker_job = replicated_jobs[1] - self.assertEqual(worker_job["name"], "pathways-worker") - self.assertEqual(worker_job["replicas"], 2) - worker_pod_spec = worker_job["template"]["spec"]["template"]["spec"] + self.assertEqual(rm_container["restartPolicy"], "Always") + self.assertEqual(proxy_container["restartPolicy"], "Always") + + # Main container should have JAX env vars injected + main_container = next( + c for c in pod_spec["containers"] if c["name"] == "jax-tpu" + ) + env_names = [e["name"] for e in main_container["env"]] + self.assertIn("PATHWAYS_HEAD", env_names) + self.assertIn("JAX_PLATFORMS", env_names) + self.assertIn("XCLOUD_ENVIRONMENT", env_names) + self.assertIn("JAX_BACKEND_TARGET", env_names) + + # Verify annotations are propagated self.assertEqual( - worker_pod_spec["containers"][0]["name"], "placeholder-worker" + head_job["template"]["metadata"]["annotations"][ + "example.com/annotation" + ], + "value", ) + self.assertEqual( + head_job["template"]["spec"]["template"]["metadata"]["annotations"][ + "example.com/annotation" + ], + "value", + ) + + def test_monkeypatch_restart_policy(self): + # Construct V1Container with restart_policy to test monkeypatch. + c = client.V1Container( + name="test", restart_policy="Always" + ) # pytype: disable=wrong-keyword-args + self.assertEqual(getattr(c, "restart_policy"), "Always") if __name__ == "__main__":