From 70b345700c8c3eb4bdf89a9360be4d83dfb86dff Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 20 May 2026 20:41:04 +0530 Subject: [PATCH] feat: add g++ compiler to base image and implement hardware-aware FeedForward sharding --- .../base_requirements/requirements.txt | 1 + .../generated_requirements/requirements.txt | 1 + maxdiffusion_dependencies.Dockerfile | 2 +- src/maxdiffusion/models/attention_flax.py | 22 ++++++++++++++++--- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt index 8b217c8e5..f3c0c40f9 100644 --- a/dependencies/requirements/base_requirements/requirements.txt +++ b/dependencies/requirements/base_requirements/requirements.txt @@ -2,6 +2,7 @@ absl-py accelerate aqtp +av chex datasets einops diff --git a/dependencies/requirements/generated_requirements/requirements.txt b/dependencies/requirements/generated_requirements/requirements.txt index b21196755..4f65992a5 100644 --- a/dependencies/requirements/generated_requirements/requirements.txt +++ b/dependencies/requirements/generated_requirements/requirements.txt @@ -15,6 +15,7 @@ astroid>=4.0.4 astunparse>=1.6.3 attrs>=25.4.0 auditwheel>=6.6.0 +av>=17.0.1 black>=25.12.0 build>=1.4.0 certifi>=2026.1.4 diff --git a/maxdiffusion_dependencies.Dockerfile b/maxdiffusion_dependencies.Dockerfile index 6168f1285..9a9598271 100644 --- a/maxdiffusion_dependencies.Dockerfile +++ b/maxdiffusion_dependencies.Dockerfile @@ -17,7 +17,7 @@ ENV DEBIAN_FRONTEND=noninteractive RUN python -m pip install --upgrade pip uv --no-warn-script-location # Install system dependencies -RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y apt-utils git curl gnupg procps iproute2 ethtool g++ && rm -rf /var/lib/apt/lists/* # Add the Google Cloud SDK package repository RUN curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \ diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 6ff4b4fed..d740fbc4f 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1134,6 +1134,22 @@ def __init__( ): inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim + + tpu_type = get_tpu_type() + is_ironwood = tpu_type == TpuType.TPU_7X + + # Hardware-aware sharding specs: Ironwood (v7x) keeps the embedding dimension (embed) + # replicated (None) to minimize cross-device communication, while other hardware (default) + # shards it to prevent OOM issues. + if is_ironwood: + net0_kernel_spec = (None, "mlp") + net2_kernel_spec = ("mlp", None) + net2_bias_spec = (None,) + else: + net0_kernel_spec = ("embed", "mlp") + net2_kernel_spec = ("mlp", "embed") + net2_bias_spec = ("embed",) + self.net_0 = nnx.Linear( dim, inner_dim, @@ -1142,7 +1158,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net0_kernel_spec), bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) self.act = get_activation(activation_fn) @@ -1154,8 +1170,8 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), net2_kernel_spec), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, net2_bias_spec), ) def __call__(self, hidden_states: Array) -> Array: