diff --git a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py index a4b78d3..865c6be 100644 --- a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py +++ b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py @@ -26,6 +26,11 @@ _SERVER_IMAGE = flags.DEFINE_string( "server_image", None, "Full path to the server Docker image" ) +_SIDECAR_IMAGE = flags.DEFINE_string( + "sidecar_image", + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0", + "Full path to the sidecar Docker image", +) _TPU_TYPE = flags.DEFINE_enum( "tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type" ) @@ -52,6 +57,7 @@ False, "If true, only print the generated YAML without deploying.", ) +_SIDECAR_DIR = "/tmp/sidecar_dir" @dataclasses.dataclass(frozen=True) @@ -191,6 +197,7 @@ def run_deployment( jobset_name, gcs_bucket, server_image, + sidecar_image, template_file, dry_run, deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset, @@ -202,6 +209,8 @@ def run_deployment( context = { "JOBSET_NAME": jobset_name, "SERVER_IMAGE": server_image, + "SIDECAR_IMAGE": sidecar_image, + "SIDECAR_DIR": _SIDECAR_DIR, "GCS_SCRATCH_LOCATION": gcs_bucket, "NUM_SLICES": num_slices, "INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}", @@ -246,6 +255,7 @@ def main(argv: Sequence[str]) -> None: jobset_name=_JOBSET_NAME.value, gcs_bucket=_GCS_BUCKET.value, server_image=server_image, + sidecar_image=_SIDECAR_IMAGE.value, template_file=_TEMPLATE_FILE.value, dry_run=_DRY_RUN.value, ) diff --git a/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile b/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile new file mode 100644 index 0000000..42a224f --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile @@ -0,0 +1,26 @@ +ARG JAX_VERSION=0.10.0 +# Use the JAX image with the custom-built sidecar as the base. +FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_${JAX_VERSION} + +ARG JAX_VERSION + +# Set the working directory +WORKDIR /app + +# 1. Upgrade pip and build tools +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --upgrade pip setuptools wheel + +# 2. Copy ONLY requirements first to leverage Docker layer caching. +COPY maxtext/src/dependencies/requirements/base_requirements/requirements.txt ./requirements.txt + +# ADD THE CACHE MOUNT HERE +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install -r requirements.txt && \ + uv pip install --upgrade jax==${JAX_VERSION} jaxlib==${JAX_VERSION} + +# 3. Copy ONLY the actual MaxText source code +COPY maxtext/src /app/maxtext/src + +# Ensure MaxText src and Orbax are in PYTHONPATH +ENV PYTHONPATH=/app/maxtext/src:/app/orbax/checkpoint:$PYTHONPATH diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 80fc04c..5e12272 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -48,14 +48,17 @@ class ProxyOptions: use_insecure_credentials: Whether to use insecure gRPC credentials for the proxy server. xla_flags: A list of XLA flags to pass to the proxy server. + sidecar: Whether to use the worker sidecar or not. """ use_insecure_credentials: bool = False xla_flags: list[str] = dataclasses.field(default_factory=list) + sidecar: bool = False @classmethod def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": """Creates a ProxyOptions object from a list of 'key:value' strings.""" use_insecure = False + use_sidecar = False xla_flags = [] for option in options or []: if ":" in option: @@ -63,6 +66,8 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": key_strip = key.strip().lower() if key_strip == "use_insecure_credentials": use_insecure = value.strip().lower() == "true" + elif key_strip == "sidecar": + use_sidecar = value.strip().lower() == "true" elif key_strip == "xla_flags": val_strip = value.strip() if ( @@ -78,7 +83,11 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": if xla_flags: validators.validate_xla_flags(xla_flags) - return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags) + return cls( + use_insecure_credentials=use_insecure, + xla_flags=xla_flags, + sidecar=use_sidecar, + ) def _deploy_pathways_proxy_server( @@ -134,6 +143,9 @@ def _deploy_pathways_proxy_server( ) proxy_args_str = "\n" + proxy_args_str + if proxy_options.sidecar: + proxy_args_str += "\n - --sidecar_name=external" + template = string.Template(yaml_template) substituted_yaml = template.substitute( PROXY_JOB_NAME=proxy_job_name, diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml index 19769db..d01c258 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml @@ -87,6 +87,7 @@ spec: - --server_port=29005 - --resource_manager_address=$$(PATHWAYS_HEAD):29001 - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + - --cloud_pathways_sidecar_shm_directory=${SIDECAR_DIR} env: - name: TPU_MIN_LOG_LEVEL value: "0" @@ -133,8 +134,51 @@ spec: limits: google.com/tpu: "${CHIPS_PER_VM}" volumeMounts: - - mountPath: /tmp - name: shared-tmp + - name: shared-tmp + mountPath: /tmp + - name: cache + mountPath: /tmp/checkpoints + - name: sidecar-shared-memory + mountPath: ${SIDECAR_DIR} + initContainers: + - name: colocated-python-sidecar + image: ${SIDECAR_IMAGE} + imagePullPolicy: Always + env: + - name: GRPC_SERVER_ADDRESS + value: '''0.0.0.0:50051''' + - name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY + value: ${SIDECAR_DIR} + - name: PYTHONUNBUFFERED + value: '1' + # --- High Verbosity Logging Variables --- + - name: LOGLEVEL + value: 'DEBUG' + - name: GLOG_minloglevel + value: '0' # 0 = INFO level base + - name: GLOG_v + value: '5' # Extreme verbosity for all C++ modules + - name: TF_CPP_MIN_LOG_LEVEL + value: '0' + - name: TF_CPP_MIN_VLOG_LEVEL + value: '5' # TF/XLA verbose logging + - name: TPU_MIN_LOG_LEVEL + value: '0' + - name: GLOG_vmodule + value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5' + # ---------------------------------------- + ports: + - containerPort: 50051 + protocol: TCP + resources: {} + restartPolicy: Always + volumeMounts: + - name: shared-tmp + mountPath: /tmp + - name: cache + mountPath: /tmp/checkpoints + - name: sidecar-shared-memory + mountPath: ${SIDECAR_DIR} dnsPolicy: ClusterFirstWithHostNet hostNetwork: true nodeSelector: @@ -146,6 +190,12 @@ spec: hostPath: path: /tmp type: DirectoryOrCreate + - name: cache + csi: + driver: multitier-checkpoint.csi.storage.gke.io + - name: sidecar-shared-memory + emptyDir: + medium: Memory startupPolicy: startupPolicyOrder: InOrder successPolicy: