From 11ddfef984f40e55b14f50673a7d71b63265a33e Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Fri, 12 Jun 2026 11:59:05 +0530 Subject: [PATCH] Flux training: Implement scanned blocks, dynamic gradient checkpointing, and weight loading improvements --- .../checkpointing/flux_checkpointer.py | 10 + src/maxdiffusion/configs/base_flux_dev.yml | 47 ++- src/maxdiffusion/generate_flux.py | 3 + src/maxdiffusion/models/attention_flax.py | 64 ++- .../transformers/transformer_flux_flax.py | 375 +++++++++++------- src/maxdiffusion/models/flux/util.py | 87 +++- .../models/gradient_checkpoint.py | 41 +- src/maxdiffusion/models/normalization_flax.py | 91 ++--- src/maxdiffusion/pyconfig.py | 37 +- src/maxdiffusion/trainers/flux_trainer.py | 79 +++- 10 files changed, 569 insertions(+), 265 deletions(-) diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index 78ad000b6..70b54d08d 100644 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -214,6 +214,11 @@ def load_diffusers_checkpoint(self): dtype=self.config.activations_dtype, weights_dtype=self.config.weights_dtype, precision=max_utils.get_precision(self.config), + use_base2_exp=self.config.use_base2_exp, + use_experimental_scheduler=self.config.use_experimental_scheduler, + remat_policy=self.config.remat_policy, + names_which_can_be_saved=self.config.names_which_can_be_saved, + names_which_can_be_offloaded=self.config.names_which_can_be_offloaded, ) transformer_eval_params = transformer.init_weights( rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True @@ -279,6 +284,11 @@ def load_checkpoint(self, step=None, scheduler_class=None): weights_dtype=self.config.weights_dtype, precision=max_utils.get_precision(self.config), from_pt=self.config.from_pt, + use_base2_exp=self.config.use_base2_exp, + use_experimental_scheduler=self.config.use_experimental_scheduler, + 
remat_policy=self.config.remat_policy, + names_which_can_be_saved=self.config.names_which_can_be_saved, + names_which_can_be_offloaded=self.config.names_which_can_be_offloaded, ) pipeline = FluxPipeline( diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index ae81047a7..e71844594 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -63,6 +63,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te +use_base2_exp: False +use_experimental_scheduler: False # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. # However, when padding tokens are significant, this will lead to worse quality and should be set to True. @@ -73,18 +75,18 @@ mask_padding_tokens: True # in cross attention q. attention_sharding_uniform: True -flash_block_sizes: {} +#flash_block_sizes: {} # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem. 
-# flash_block_sizes: { -# "block_q" : 1536, -# "block_kv_compute" : 1536, -# "block_kv" : 1536, -# "block_q_dkv" : 1536, -# "block_kv_dkv" : 1536, -# "block_kv_dkv_compute" : 1536, -# "block_q_dq" : 1536, -# "block_kv_dq" : 1536 -# } +flash_block_sizes: { + "block_q" : 1536, + "block_kv_compute" : 1536, + "block_kv" : 1536, + "block_q_dkv" : 1536, + "block_kv_dkv" : 1536, + "block_kv_dkv_compute" : 1536, + "block_q_dq" : 1536, + "block_kv_dq" : 1536 +} # GroupNorm groups norm_num_groups: 32 @@ -147,9 +149,11 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], + ['batch', ['data','fsdp']], ['activation_batch', ['data','fsdp']], - ['activation_heads', 'tensor'], + ['activation_heads', 'fsdp'], + ['activation_length', 'context'], + ['activation_kv_length', 'context'], ['activation_kv', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -188,7 +192,7 @@ dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic # 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000) # 3. Optionally override dimensions # -# synthetic_num_samples: null # null for infinite, or set a number +synthetic_num_samples: 1000 # null for infinite, or set a number # # Optional dimension overrides: # resolution: 512 @@ -218,6 +222,21 @@ transform_images_num_proc: 4 reuse_example_batch: False enable_data_shuffling: True +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. 
+# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. FLUX_OPTIMIZED (the value set below) is not described in this list; see GradientCheckpointType in src/maxdiffusion/models/gradient_checkpoint.py. +remat_policy: "FLUX_OPTIMIZED" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation, img_qkv_proj, txt_qkv_proj, lin1_norm_hidden_states, lin2_hidden_states +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] +flash_min_seq_length: 0 + # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 # enables one replica to read the ckpt then broadcast to the rest diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 267e70277..8d2153c65 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -314,6 +314,9 @@ def run(config): dtype=config.activations_dtype, weights_dtype=config.weights_dtype, precision=get_precision(config), + remat_policy=config.remat_policy, + names_which_can_be_saved=config.names_which_can_be_saved, + names_which_can_be_offloaded=config.names_which_can_be_offloaded, ) num_channels_latents = transformer.in_channels // 4 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index c8d94bef4..84a4a2030 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1824,6 +1824,8 @@ class FlaxFluxAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, EMBED) precision: jax.lax.Precision = None qkv_bias: bool = False + use_base2_exp: bool = False + use_experimental_scheduler: bool = False def setup(self): if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None: @@ -1843,6 +1845,8 @@ def setup(self): flash_block_sizes=self.flash_block_sizes, dtype=self.dtype, float32_qk_product=False, + use_base2_exp=self.use_base2_exp, + use_experimental_scheduler=self.use_experimental_scheduler, ) kernel_axes = ("embed", "heads") @@ -1923,41 +1927,59 @@ def __call__( attention_mask=None, 
image_rotary_emb=None, ): - qkv_proj = self.qkv(hidden_states) B, L = hidden_states.shape[:2] - H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 - qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - query_proj, key_proj, value_proj = qkv_proj + # Deduce dimensions cleanly from class attributes + H, D = self.heads, self.dim_head - query_proj = self.query_norm(query_proj) + qkv_proj = self.qkv(hidden_states) + qkv_proj = checkpoint_name(qkv_proj, "img_qkv_proj") + + qkv_proj = qkv_proj.reshape(B, L, 3, H, D) + query_proj, key_proj, value_proj = jnp.split(qkv_proj, 3, axis=2) + query_proj = query_proj.squeeze(2) + key_proj = key_proj.squeeze(2) + value_proj = value_proj.squeeze(2) + query_proj = self.query_norm(query_proj) key_proj = self.key_norm(key_proj) if encoder_hidden_states is not None: + B_enc, L_txt = encoder_hidden_states.shape[:2] encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states) - B, L = encoder_hidden_states.shape[:2] - H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3 - encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj + encoder_qkv_proj = checkpoint_name(encoder_qkv_proj, "txt_qkv_proj") + encoder_qkv_proj = encoder_qkv_proj.reshape(B_enc, L_txt, 3, H, D) + enc_query_proj, enc_key_proj, enc_value_proj = jnp.split(encoder_qkv_proj, 3, axis=2) + enc_query_proj = enc_query_proj.squeeze(2) + enc_key_proj = enc_key_proj.squeeze(2) + enc_value_proj = enc_value_proj.squeeze(2) - encoder_query_proj = self.encoder_query_norm(encoder_query_proj) + encoder_query_proj = self.encoder_query_norm(enc_query_proj) + encoder_key_proj = self.encoder_key_norm(enc_key_proj) - encoder_key_proj = self.encoder_key_norm(encoder_key_proj) + query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=1) + key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=1) + value_proj = 
jnp.concatenate((enc_value_proj, value_proj), axis=1) - query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2) - key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2) - value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2) - - query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) - key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) - value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + # query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) + # key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) + # value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) + + query_proj = query_proj.swapaxes(1, 2) + key_proj = key_proj.swapaxes(1, 2) query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb) + query_proj = query_proj.swapaxes(1, 2) + key_proj = key_proj.swapaxes(1, 2) + + query_proj = query_proj.reshape(B, -1, H * D) + key_proj = key_proj.reshape(B, -1, H * D) + value_proj = value_proj.reshape(B, -1, H * D) - query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1) - key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1) - value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1) + if encoder_hidden_states is not None: + query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names) + key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) + value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask) context_attn_output = None diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py 
b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 814e21eab..60de0f275 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -29,6 +29,8 @@ from .... import common_types from ....common_types import BlockSizes from ....utils import BaseOutput +from ...gradient_checkpoint import GradientCheckpointType, SKIP_GRADIENT_CHECKPOINT_KEY +from jax.ad_checkpoint import checkpoint_name AxisNames = common_types.AxisNames BATCH = common_types.BATCH @@ -50,6 +52,44 @@ class Transformer2DModelOutput(BaseOutput): sample: jnp.ndarray +class MlpAndOutputBlock(nn.Module): + dim: int + mlp_ratio: float = 4.0 + dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + def setup(self): + self.mlp_hidden_dim = int(self.dim * self.mlp_ratio) + self.lin_mlp = nn.Dense( + self.mlp_hidden_dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + self.mlp_act = nn.gelu + self.linear2 = nn.Dense( + self.dim, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) + + def __call__(self, x, attn_output, gate, residual): + mlp = self.lin_mlp(x) + attn_mlp = jnp.concatenate([attn_output, self.mlp_act(mlp)], axis=2) + attn_mlp = nn.with_logical_constraint(attn_mlp, ("activation_batch", None, "mlp")) + hidden_states = self.linear2(attn_mlp) + hidden_states = checkpoint_name(hidden_states, "lin2_hidden_states") + hidden_states = gate * hidden_states + hidden_states = residual + hidden_states + return hidden_states + + class 
FluxSingleTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. @@ -75,6 +115,8 @@ class FluxSingleTransformerBlock(nn.Module): dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None + use_base2_exp: bool = False + use_experimental_scheduler: bool = False def setup(self): self.mlp_hidden_dim = int(self.dim * self.mlp_ratio) @@ -83,8 +125,8 @@ def setup(self): self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision ) - self.linear1 = nn.Dense( - self.dim * 3 + self.mlp_hidden_dim, + self.lin_qkv = nn.Dense( + self.dim * 3, kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), dtype=self.dtype, @@ -92,15 +134,14 @@ def setup(self): precision=self.precision, ) - self.mlp_act = nn.gelu - self.linear2 = nn.Dense( - self.dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + self.mlp_and_out = nn.remat(MlpAndOutputBlock, prevent_cse=True)( + dim=self.dim, + mlp_ratio=self.mlp_ratio, dtype=self.dtype, - param_dtype=self.weights_dtype, + weights_dtype=self.weights_dtype, precision=self.precision, ) + self.attn = FlaxFluxAttention( query_dim=self.dim, heads=self.num_attention_heads, @@ -110,26 +151,35 @@ def setup(self): attention_kernel=self.attention_kernel, mesh=self.mesh, flash_block_sizes=self.flash_block_sizes, + use_base2_exp=self.use_base2_exp, + use_experimental_scheduler=self.use_experimental_scheduler, ) def __call__(self, hidden_states, temb, image_rotary_emb=None): residual = hidden_states + + # FIX: Constrain inputs using valid config parameters (None skips sequence length axis parsing) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", None, "mlp")) + 
norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - qkv, mlp = jnp.split(self.linear1(norm_hidden_states), [3 * self.dim], axis=-1) - mlp = nn.with_logical_constraint(mlp, ("activation_batch", "activation_length", "activation_embed")) - qkv = nn.with_logical_constraint(qkv, ("activation_batch", "activation_length", "activation_embed")) + + qkv = self.lin_qkv(norm_hidden_states) + qkv = checkpoint_name(qkv, "lin1_norm_hidden_states") + qkv = nn.with_logical_constraint(qkv, ("activation_batch", None, "mlp")) B, L = hidden_states.shape[:2] H, D, K = self.num_attention_heads, qkv.shape[-1] // (self.num_attention_heads * 3), 3 - qkv_proj = qkv.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4) - q, k, v = qkv_proj + + qkv_proj = qkv.reshape(B, L, K, H, D) + q, k, v = jnp.split(qkv_proj, 3, axis=2) + q = q.squeeze(2).swapaxes(1, 2) + k = k.squeeze(2).swapaxes(1, 2) + v = v.squeeze(2).swapaxes(1, 2) q = self.attn.query_norm(q) k = self.attn.key_norm(k) if image_rotary_emb is not None: - # since this function returns image_rotary_emb and passes it between layers, - # we do not want to modify it image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2) q, k = apply_rope(q, k, image_rotary_emb_reordered) @@ -138,12 +188,10 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): v = v.transpose(0, 2, 1, 3).reshape(v.shape[0], v.shape[2], -1) attn_output = self.attn.attention_op.apply_attention(q, k, v) + attn_output = checkpoint_name(attn_output, "attn_output") + + hidden_states = self.mlp_and_out(norm_hidden_states, attn_output, gate, residual) - attn_mlp = jnp.concatenate([attn_output, self.mlp_act(mlp)], axis=2) - attn_mlp = nn.with_logical_constraint(attn_mlp, ("activation_batch", "activation_length", "activation_embed")) - hidden_states = self.linear2(attn_mlp) - hidden_states = gate * hidden_states - hidden_states = residual + hidden_states if hidden_states.dtype == jnp.float16: hidden_states = jnp.clip(hidden_states, 
-65504, 65504) @@ -151,24 +199,11 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): class FluxTransformerBlock(nn.Module): - r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://arxiv.org/abs/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. - """ - dim: int num_attention_heads: int attention_head_dim: int qk_norm: str = "rms_norm" - eps: int = 1e-6 + eps: float = 1e-6 flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None @@ -178,8 +213,11 @@ class FluxTransformerBlock(nn.Module): mlp_ratio: float = 4.0 qkv_bias: bool = False attention_kernel: str = "dot_product" + use_base2_exp: bool = False + use_experimental_scheduler: bool = False def setup(self): + # Adaptive layer norms (they hold the "lin" parameter projections); in __call__ each returns the normalized input plus gate/shift/scale terms computed from temb. self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) @@ -193,15 +231,13 @@ def setup(self): attention_kernel=self.attention_kernel, mesh=self.mesh, flash_block_sizes=self.flash_block_sizes, + use_base2_exp=self.use_base2_exp, + use_experimental_scheduler=self.use_experimental_scheduler, ) - self.img_norm2 = nn.LayerNorm( - use_bias=False, - use_scale=False, - epsilon=self.eps, - dtype=self.dtype, - param_dtype=self.weights_dtype, - ) + # REMOVED: the self.img_norm2 and self.txt_norm2 modules, with the aim of stopping HBM memory spilling. + # The equivalent parameter-free LayerNorm (mean/variance/rsqrt with self.eps) is computed inline in __call__. 
+ self.img_mlp = nn.Sequential([ nn.Dense( int(self.dim * self.mlp_ratio), @@ -224,13 +260,6 @@ def setup(self): ), ]) - self.txt_norm2 = nn.LayerNorm( - use_bias=False, - use_scale=False, - epsilon=self.eps, - dtype=self.dtype, - param_dtype=self.weights_dtype, - ) self.txt_mlp = nn.Sequential([ nn.Dense( int(self.dim * self.mlp_ratio), @@ -253,74 +282,99 @@ def setup(self): ), ]) - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.img_norm1(hidden_states, emb=temb) + # Enforce active partitioning based on your FSDP setup config + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", None, "mlp")) + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, ("activation_batch", None, "mlp")) + # 1. First Adaptive Normalization Pass + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.img_norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.txt_norm1( encoder_hidden_states, emb=temb ) - # Attention. + # 2. 
Attention Mechanics attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, ) + # --- IMAGE STREAM OPTIMIZATION (img_norm2) --- attn_output = gate_msa * attn_output hidden_states = hidden_states + attn_output - norm_hidden_states = self.img_norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + # Fully fused LayerNorm + scale_mlp + shift_mlp compilation block + img_mean = jnp.mean(hidden_states, axis=-1, keepdims=True) + img_var = jnp.mean(jnp.square(hidden_states - img_mean), axis=-1, keepdims=True) + img_inv_std = jax.lax.rsqrt(img_var + self.eps) + + norm_hidden_states = (hidden_states - img_mean) * img_inv_std * (1 + scale_mlp) + shift_mlp + norm_hidden_states = nn.with_logical_constraint(norm_hidden_states, ("activation_batch", None, "mlp")) ff_output = self.img_mlp(norm_hidden_states) - ff_output = gate_mlp * ff_output + hidden_states = hidden_states + gate_mlp * ff_output - hidden_states = hidden_states + ff_output - # Process attention outputs for the `encoder_hidden_states`. 
+ # --- TEXT STREAM OPTIMIZATION (txt_norm2) --- context_attn_output = c_gate_msa * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output - norm_encoder_hidden_states = self.txt_norm2(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + # Fully fused LayerNorm + c_scale_mlp + c_shift_mlp compilation block + txt_mean = jnp.mean(encoder_hidden_states, axis=-1, keepdims=True) + txt_var = jnp.mean(jnp.square(encoder_hidden_states - txt_mean), axis=-1, keepdims=True) + txt_inv_std = jax.lax.rsqrt(txt_var + self.eps) + + norm_encoder_hidden_states = (encoder_hidden_states - txt_mean) * txt_inv_std * (1 + c_scale_mlp) + c_shift_mlp + norm_encoder_hidden_states = nn.with_logical_constraint(norm_encoder_hidden_states, ("activation_batch", None, "mlp")) context_ff_output = self.txt_mlp(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output - if encoder_hidden_states.dtype == jnp.float16: + + # Safe numerical clipping limits for half precision math execution + if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + hidden_states = hidden_states.clip(-65504, 65504) + return hidden_states, encoder_hidden_states -@flax_register_to_config -class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): - r""" - The Transformer model introduced in Flux. +class ScannedDoubleBlockWrapper(nn.Module): + block_kwargs: dict - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + @nn.compact + def __call__(self, carry, _): + hidden_states, encoder_hidden_states, temb, image_rotary_emb = carry - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods - implemented for all models (such as downloading or saving). 
+ # Instantiate the pure block (no remat here) + block = FluxTransformerBlock(**self.block_kwargs) - This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its - general usage and behavior. + h_out, e_out = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + return (h_out, e_out, temb, image_rotary_emb), None - Parameters: - patch_size (`int`): Patch size to turn the input data into small patches. - in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. - num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. - """ +class ScannedSingleBlockWrapper(nn.Module): + block_kwargs: dict + + @nn.compact + def __call__(self, carry, _): + hidden_states, temb, image_rotary_emb = carry + + # Instantiate the pure block + block = FluxSingleTransformerBlock(**self.block_kwargs) + h_out = block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb) + return (h_out, temb, image_rotary_emb), None + +@flax_register_to_config +class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + The Transformer model introduced in Flux. 
+ """ patch_size: int = 1 in_channels: int = 64 num_layers: int = 19 @@ -341,7 +395,12 @@ class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin): qkv_bias: bool = True theta: int = 1000 attention_kernel: str = "dot_product" - eps = 1e-6 + eps: float = 1e-6 + remat_policy: str = "None" + names_which_can_be_saved: tuple = () + names_which_can_be_offloaded: tuple = () + use_base2_exp: bool = False + use_experimental_scheduler: bool = False def setup(self): self.out_channels = self.in_channels @@ -377,43 +436,85 @@ def setup(self): precision=self.precision, ) - double_blocks = [] - for _ in range(self.num_layers): - double_block = FluxTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, - ) - double_blocks.append(double_block) - self.double_blocks = double_blocks - - single_blocks = [] - for _ in range(self.num_single_layers): - single_block = FluxSingleTransformerBlock( - dim=self.inner_dim, - num_attention_heads=self.num_attention_heads, - attention_head_dim=self.attention_head_dim, - attention_kernel=self.attention_kernel, - flash_min_seq_length=self.flash_min_seq_length, - flash_block_sizes=self.flash_block_sizes, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision, - mlp_ratio=self.mlp_ratio, - ) - single_blocks.append(single_block) - - self.single_blocks = single_blocks + self.gradient_checkpoint = GradientCheckpointType.from_str(self.remat_policy) + + # 2. 
Apply the policy to the Module classes + # RematDoubleBlock = self.gradient_checkpoint.apply_linen(FluxTransformerBlock) + # RematSingleBlock = self.gradient_checkpoint.apply_linen(FluxSingleTransformerBlock) + + # 1. Prepare the kwargs for the double blocks + double_kwargs = { + "dim": self.inner_dim, + "num_attention_heads": self.num_attention_heads, + "attention_head_dim": self.attention_head_dim, + "attention_kernel": self.attention_kernel, + "flash_min_seq_length": self.flash_min_seq_length, + "flash_block_sizes": self.flash_block_sizes, + "mesh": self.mesh, + "dtype": self.dtype, + "weights_dtype": self.weights_dtype, + "precision": self.precision, + "mlp_ratio": self.mlp_ratio, + "qkv_bias": self.qkv_bias, + "use_base2_exp": self.use_base2_exp, + "use_experimental_scheduler": self.use_experimental_scheduler, + } + + double_policy = self.gradient_checkpoint.to_jax_policy( + names_which_can_be_saved=self.names_which_can_be_saved, + names_which_can_be_offloaded=self.names_which_can_be_offloaded, + block_type="double", + ) + + if double_policy == SKIP_GRADIENT_CHECKPOINT_KEY: + RemattedDoubleWrapper = ScannedDoubleBlockWrapper + else: + RemattedDoubleWrapper = nn.remat(ScannedDoubleBlockWrapper, prevent_cse=True, policy=double_policy) + + self.scanned_double_blocks = nn.scan( + RemattedDoubleWrapper, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + length=self.num_layers, + metadata_params={"partition_name": None}, + )(block_kwargs=double_kwargs) + + # 3. 
Define pure kwargs for single blocks + single_kwargs = { + "dim": self.inner_dim, + "num_attention_heads": self.num_attention_heads, + "attention_head_dim": self.attention_head_dim, + "attention_kernel": self.attention_kernel, + "flash_min_seq_length": self.flash_min_seq_length, + "flash_block_sizes": self.flash_block_sizes, + "mesh": self.mesh, + "dtype": self.dtype, + "weights_dtype": self.weights_dtype, + "precision": self.precision, + "mlp_ratio": self.mlp_ratio, + "use_base2_exp": self.use_base2_exp, + "use_experimental_scheduler": self.use_experimental_scheduler, + } + + # 4. Force strict checkpointing on the Single Wrapper + single_policy = self.gradient_checkpoint.to_jax_policy( + names_which_can_be_saved=self.names_which_can_be_saved, + names_which_can_be_offloaded=self.names_which_can_be_offloaded, + block_type="single", + ) + + if single_policy == SKIP_GRADIENT_CHECKPOINT_KEY: + RemattedSingleWrapper = ScannedSingleBlockWrapper + else: + RemattedSingleWrapper = nn.remat(ScannedSingleBlockWrapper, prevent_cse=True, policy=single_policy) + + self.scanned_single_blocks = nn.scan( + RemattedSingleWrapper, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + length=self.num_single_layers, + metadata_params={"partition_name": None}, + )(block_kwargs=single_kwargs) self.norm_out = AdaLayerNormContinuous( self.inner_dim, @@ -478,7 +579,6 @@ def __call__( ): hidden_states = self.img_in(hidden_states) timestep = self.timestep_embedding(timestep, 256) - timestep = nn.with_logical_constraint(timestep, ("activation_batch", None)) if self.guidance_embeds: @@ -504,17 +604,18 @@ def __call__( image_rotary_emb = self.pe_embedder(ids) image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, (None, None)) - for double_block in self.double_blocks: - hidden_states, encoder_hidden_states = double_block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + carry 
= (hidden_states, encoder_hidden_states, temb, image_rotary_emb) + carry, _ = self.scanned_double_blocks(carry, None) + hidden_states, encoder_hidden_states, _, _ = carry + hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1) hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) - for single_block in self.single_blocks: - hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb) + + # Execute the scanned single blocks (self.num_single_layers iterations) + carry = (hidden_states, temb, image_rotary_emb) + carry, _ = self.scanned_single_blocks(carry, None) + hidden_states, _, _ = carry + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] hidden_states = self.norm_out(hidden_states, temb) @@ -532,24 +633,23 @@ def init_weights(self, rngs, max_sequence_length, eval_only=True): batch_size = 1 * num_devices batch_image_shape = ( batch_size, - 16, # 16 to match jflux.get_noise + 16, 2 * resolution // scale_factor, 2 * resolution // scale_factor, ) - # bs, encoder_input, seq_length text_shape = ( batch_size, max_sequence_length, - 4096, # Sequence length of text encoder, how to get this programmatically? + 4096, ) text_ids_shape = ( batch_size, max_sequence_length, - 3, # Hardcoded to match jflux.prepare + 3, ) vec_shape = ( batch_size, - 768, # Sequence length of clip, how to get this programmatically? 
+ 768, ) img = jnp.zeros(batch_image_shape, dtype=self.dtype) bs, _, h, w = img.shape @@ -561,11 +661,8 @@ def init_weights(self, rngs, max_sequence_length, eval_only=True): txt = jnp.zeros(text_shape, dtype=self.dtype) txt_ids = jnp.zeros(text_ids_shape, dtype=self.dtype) - t_vec = jnp.full(bs, 0, dtype=self.dtype) - vec = jnp.zeros(vec_shape, dtype=self.dtype) - guidance_vec = jnp.full(bs, 4.0, dtype=self.dtype) if eval_only: diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index a4f665c6c..4a44bb172 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -154,9 +154,17 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool tensors[k] = torch2jax(f.get_tensor(k)) flax_state_dict = {} cpu = jax.local_devices(backend="cpu")[0] + + double_blocks_tensors = {} + single_blocks_tensors = {} + for pt_key, tensor in tensors.items(): - renamed_pt_key = rename_key(pt_key) - if "double_blocks" in renamed_pt_key: + if pt_key.startswith("double_blocks."): + parts = pt_key.split(".") + layer_idx = int(parts[1]) + pt_key_without_idx = "double_blocks." 
+ ".".join(parts[2:]) + renamed_pt_key = rename_key(pt_key_without_idx) + renamed_pt_key = renamed_pt_key.replace("double_blocks", "scanned_double_blocks.FluxTransformerBlock_0") renamed_pt_key = renamed_pt_key.replace("img_mlp_", "img_mlp.layers_") renamed_pt_key = renamed_pt_key.replace("txt_mlp_", "txt_mlp.layers_") renamed_pt_key = renamed_pt_key.replace("img_mod", "img_norm1") @@ -168,14 +176,65 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool renamed_pt_key = renamed_pt_key.replace("txt_attn.proj", "attn.e_proj") renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.key_norm", "attn.encoder_key_norm") renamed_pt_key = renamed_pt_key.replace("txt_attn.norm.query_norm", "attn.encoder_query_norm") - elif "guidance_in" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") - renamed_pt_key = renamed_pt_key.replace("in_layer", "linear_1") - renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") - elif "single_blocks" in renamed_pt_key: + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes, scan_layers=True) + if flax_key not in double_blocks_tensors: + double_blocks_tensors[flax_key] = {} + double_blocks_tensors[flax_key][layer_idx] = flax_tensor + continue + + elif pt_key.startswith("single_blocks."): + parts = pt_key.split(".") + layer_idx = int(parts[1]) + pt_key_without_idx = "single_blocks." 
+ ".".join(parts[2:]) + renamed_pt_key = rename_key(pt_key_without_idx) + renamed_pt_key = renamed_pt_key.replace("single_blocks", "scanned_single_blocks.FluxSingleTransformerBlock_0") renamed_pt_key = renamed_pt_key.replace("modulation", "norm") renamed_pt_key = renamed_pt_key.replace("norm.key_norm", "attn.key_norm") renamed_pt_key = renamed_pt_key.replace("norm.query_norm", "attn.query_norm") + + if "linear1" in renamed_pt_key: + if tensor.ndim == 2: + qkv_tensor = tensor[:9216, :] + mlp_tensor = tensor[9216:, :] + else: + qkv_tensor = tensor[:9216] + mlp_tensor = tensor[9216:] + qkv_pt_key = renamed_pt_key.replace("linear1", "lin_qkv") + mlp_pt_key = renamed_pt_key.replace("linear1", "mlp_and_out.lin_mlp") + + flax_key_qkv, flax_tensor_qkv = rename_key_and_reshape_tensor( + tuple(qkv_pt_key.split(".")), qkv_tensor, eval_shapes, scan_layers=True + ) + flax_key_mlp, flax_tensor_mlp = rename_key_and_reshape_tensor( + tuple(mlp_pt_key.split(".")), mlp_tensor, eval_shapes, scan_layers=True + ) + + if flax_key_qkv not in single_blocks_tensors: + single_blocks_tensors[flax_key_qkv] = {} + single_blocks_tensors[flax_key_qkv][layer_idx] = flax_tensor_qkv + + if flax_key_mlp not in single_blocks_tensors: + single_blocks_tensors[flax_key_mlp] = {} + single_blocks_tensors[flax_key_mlp][layer_idx] = flax_tensor_mlp + continue + + elif "linear2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("linear2", "mlp_and_out.linear2") + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes, scan_layers=True) + if flax_key not in single_blocks_tensors: + single_blocks_tensors[flax_key] = {} + single_blocks_tensors[flax_key][layer_idx] = flax_tensor + continue + + renamed_pt_key = rename_key(pt_key) + if "guidance_in" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("guidance_in", "time_text_embed.FlaxTimestepEmbedding_1") + renamed_pt_key = renamed_pt_key.replace("in_layer", 
"linear_1") + renamed_pt_key = renamed_pt_key.replace("out_layer", "linear_2") elif "vector_in" in renamed_pt_key or "time_in" in renamed_pt_key: renamed_pt_key = renamed_pt_key.replace("vector_in", "time_text_embed.PixArtAlphaTextProjection_0") renamed_pt_key = renamed_pt_key.replace("time_in", "time_text_embed.FlaxTimestepEmbedding_0") @@ -184,9 +243,23 @@ def load_flow_model(name: str, eval_shapes: dict, device: str, hf_download: bool elif "final_layer" in renamed_pt_key: renamed_pt_key = renamed_pt_key.replace("final_layer.linear", "proj_out") renamed_pt_key = renamed_pt_key.replace("final_layer.adaLN_modulation_1", "norm_out.Dense_0") + pt_tuple_key = tuple(renamed_pt_key.split(".")) flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, eval_shapes) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + + # Stack double blocks + for flax_key, layers in double_blocks_tensors.items(): + sorted_indices = sorted(layers.keys()) + stacked_tensor = jnp.stack([layers[i] for i in sorted_indices], axis=0) + flax_state_dict[flax_key] = jax.device_put(stacked_tensor, device=cpu) + + # Stack single blocks + for flax_key, layers in single_blocks_tensors.items(): + sorted_indices = sorted(layers.keys()) + stacked_tensor = jnp.stack([layers[i] for i in sorted_indices], axis=0) + flax_state_dict[flax_key] = jax.device_put(stacked_tensor, device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) del tensors diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 18e5c7e65..366f93a65 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -21,6 +21,8 @@ from jax import checkpoint_policies as cp from flax import nnx +import flax.linen as linen_nn + SKIP_GRADIENT_CHECKPOINT_KEY = "skip" @@ -42,6 +44,7 @@ class GradientCheckpointType(Enum): 
OFFLOAD_MATMUL_WITHOUT_BATCH = auto() CUSTOM = auto() HIDDEN_STATE_WITH_OFFLOAD = auto() + FLUX_OPTIMIZED = auto() @classmethod def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": @@ -58,7 +61,12 @@ def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": s = "none" return GradientCheckpointType[s.upper()] - def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_offloaded: list = []): + def to_jax_policy( + self, + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + block_type: Optional[str] = None, + ): """ Converts the gradient checkpoint type to a jax policy """ @@ -86,6 +94,13 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_ ) case GradientCheckpointType.MATMUL_WITHOUT_BATCH: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + case GradientCheckpointType.FLUX_OPTIMIZED: + if block_type == "double": + return cp.save_any_names_but_these("img_qkv_proj", "txt_qkv_proj") + elif block_type == "single": + return cp.save_only_these_names("attn_output", "lin1_norm_hidden_states") + else: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims def apply( self, @@ -109,3 +124,27 @@ def apply( if policy == SKIP_GRADIENT_CHECKPOINT_KEY: return module return nnx.remat(module, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) # pylint: disable=invalid-name + + def apply_linen( + self, + module_class: type[linen_nn.Module], + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + static_argnums=(), + prevent_cse: bool = False, + ) -> type[linen_nn.Module]: + """ + Applies a gradient checkpoint policy to a Linen module class. + If no policy is needed, it will return the module class as is. + + Args: + module_class: the Linen Module class to apply the policy to + + Returns: + The rematerialized Module class (or the original if no policy). 
+ """ + policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded) + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module_class + + return linen_nn.remat(module_class, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index 24f423f14..6a63cf10a 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -51,14 +51,6 @@ def __call__(self, x, conditioning_embedding): class AdaLayerNormZero(nn.Module): - r""" - Norm layer adaptive layer norm zero (adaLN-Zero). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. - """ - embedding_dim: int norm_type: str = "layer_norm" bias: bool = True @@ -68,6 +60,10 @@ class AdaLayerNormZero(nn.Module): @nn.compact def __call__(self, x, emb): + emb = nn.silu(emb) + + # Pretrained Flux checks: The dual block variant projects to 6 * dim + # to unpack: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp emb = nn.Dense( 6 * self.embedding_dim, use_bias=self.bias, @@ -77,38 +73,30 @@ def __call__(self, x, emb): param_dtype=self.weights_dtype, precision=self.precision, name="lin", - )(nn.silu(emb)) - (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = jnp.split(emb[:, None, :], 6, axis=-1) - shift_msa = nn.with_logical_constraint(shift_msa, ("activation_batch", "activation_embed")) - scale_msa = nn.with_logical_constraint(scale_msa, ("activation_batch", "activation_embed")) - gate_msa = nn.with_logical_constraint(gate_msa, ("activation_batch", "activation_embed")) - shift_mlp = nn.with_logical_constraint(shift_mlp, ("activation_batch", "activation_embed")) - scale_mlp = nn.with_logical_constraint(scale_mlp, ("activation_batch", "activation_embed")) - gate_mlp = nn.with_logical_constraint(gate_mlp, ("activation_batch", 
"activation_embed")) + )(emb) + + emb = emb[:, None, :] + + # Explicit MaxDiffusion 3D axis alignment mapping to your 'mlp' layout rule + emb = nn.with_logical_constraint(emb, ("activation_batch", None, "mlp")) + + # Slicing the 6 chunks safely within your fsdp:8, tensor:1 configuration + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(emb, 6, axis=-1) if self.norm_type == "layer_norm": - x = nn.LayerNorm( - epsilon=1e-6, - use_bias=False, - use_scale=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - )(x) + # Fused mathematical reduction loop + mean = jnp.mean(x, axis=-1, keepdims=True) + variance = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) + inv_std = jax.lax.rsqrt(variance + 1e-6) + + x = (x - mean) * inv_std * (1.0 + scale_msa) + shift_msa else: - raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.") - x = x * (1 + scale_msa) + shift_msa + raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided.") + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class AdaLayerNormZeroSingle(nn.Module): - r""" - Norm layer adaptive layer norm zero (adaLN-Zero). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. 
- """ - embedding_dim: int norm_type: str = "layer_norm" bias: bool = True @@ -119,6 +107,8 @@ class AdaLayerNormZeroSingle(nn.Module): @nn.compact def __call__(self, x, emb): emb = nn.silu(emb) + + # Matches your config layout precisely emb = nn.Dense( 3 * self.embedding_dim, use_bias=self.bias, @@ -129,24 +119,27 @@ def __call__(self, x, emb): precision=self.precision, name="lin", )(emb) - shift_msa, scale_msa, gate_msa = jnp.split(emb[:, None, :], 3, axis=-1) - shift_msa = nn.with_logical_constraint(shift_msa, ("activation_batch", "activation_embed")) - scale_msa = nn.with_logical_constraint(scale_msa, ("activation_batch", "activation_embed")) - gate_msa = nn.with_logical_constraint(gate_msa, ("activation_batch", "activation_embed")) + + # 1. Expand layout safely to a 3D Tensor + emb = emb[:, None, :] + + # 2. FIX: Apply verified MaxDiffusion logical rules to match the 3D footprint + # We map the channels to 'mlp' because that matches the output layout dimension of the dense layer + emb = nn.with_logical_constraint(emb, ("activation_batch", None, "mlp")) + + # 3. Slicing now happens safely within known sharding rules + shift_msa, scale_msa, gate_msa = jnp.split(emb, 3, axis=-1) + if self.norm_type == "layer_norm": - x = ( - nn.LayerNorm( - epsilon=1e-6, - use_bias=False, - use_scale=False, - dtype=self.dtype, - param_dtype=self.weights_dtype, - )(x) - * (1 + scale_msa) - + shift_msa - ) + # Fused optimization math keeping exact pretrained weight compatibility + mean = jnp.mean(x, axis=-1, keepdims=True) + variance = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) + inv_std = jax.lax.rsqrt(variance + 1e-6) + + x = (x - mean) * inv_std * (1.0 + scale_msa) + shift_msa else: - raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. 
Supported ones are: 'layer_norm'.") + raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided.") + return x, gate_msa diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 6ef6ccbfc..132087723 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -196,34 +196,15 @@ def calculate_global_batch_sizes(per_device_batch_size): @staticmethod def user_init(raw_keys): """Transformations between the config data and configs used at runtime""" - # Set default for use_batched_text_encoder if not present - if "use_batched_text_encoder" not in raw_keys: - raw_keys["use_batched_text_encoder"] = False - - # Validate that use_batched_text_encoder is a boolean - if not isinstance(raw_keys["use_batched_text_encoder"], bool): - raise TypeError( - f"Expected config.use_batched_text_encoder to be a boolean, but got {type(raw_keys['use_batched_text_encoder'])}" - ) - # Set defaults for dtypes if they weren't explicitly provided - if "vae_dtype" not in raw_keys: - raw_keys["vae_dtype"] = "float32" - if "vae_weights_dtype" not in raw_keys: - raw_keys["vae_weights_dtype"] = "float32" - if "scheduler_dtype" not in raw_keys: - raw_keys["scheduler_dtype"] = "float32" - - # Cast all dtype configs to jax.numpy.dtype - for dtype_key in [ - "weights_dtype", - "activations_dtype", - "scheduler_dtype", - "vae_dtype", - "vae_weights_dtype", - "text_encoder_dtype", - ]: - if dtype_key in raw_keys: - raw_keys[dtype_key] = jax.numpy.dtype(raw_keys[dtype_key]) + if "remat_policy" not in raw_keys: + raw_keys["remat_policy"] = "None" + if "names_which_can_be_saved" not in raw_keys: + raw_keys["names_which_can_be_saved"] = [] + if "names_which_can_be_offloaded" not in raw_keys: + raw_keys["names_which_can_be_offloaded"] = [] + + raw_keys["weights_dtype"] = jax.numpy.dtype(raw_keys["weights_dtype"]) + raw_keys["activations_dtype"] = jax.numpy.dtype(raw_keys["activations_dtype"]) if raw_keys["run_name"] == "": raw_keys["run_name"] = 
os.environ.get("JOBSET_NAME") # using XPK default run_name = raw_keys["run_name"] diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index c042d9540..1c20aec1a 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -43,10 +43,13 @@ write_metrics, ) + from maxdiffusion.maxdiffusion_utils import calculate_flux_tflops from ..schedulers import (FlaxEulerDiscreteScheduler) +import tensorflow as tf + class FluxTrainer(FluxCheckpointer): _profiler: max_utils.Profiler | None = None @@ -108,6 +111,12 @@ def start_training(self): # don't need this anymore, clear some memory. del pipeline.t5_encoder + del pipeline.clip_encoder + del pipeline.clip_tokenizer + del pipeline.t5_tokenizer + del pipeline.vae + del state_shardings[VAE_STATE_SHARDINGS_KEY] + jax.clear_caches() # evaluate shapes @@ -121,10 +130,10 @@ def start_training(self): checkpoint_item_name=FLUX_STATE_KEY, is_training=True, ) + flux_state = jax.device_put(flux_state, flux_state_mesh_shardings) train_states[FLUX_STATE_KEY] = flux_state state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings - # self.post_training_steps(pipeline, params, train_states, msg="before_training") # Create scheduler noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params) @@ -140,7 +149,7 @@ def start_training(self): p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings) # Start training train_states = self.training_loop( - p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler + p_train_step, pipeline, params, train_states, data_iterator, data_shardings, flux_learning_rate_scheduler ) # 6. 
save final checkpoint # Hook @@ -247,6 +256,28 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size = self.total_train_batch_size mesh = self.mesh + feature_description = { + "pixel_values": tf.io.FixedLenFeature([], tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], tf.string), + "img_ids": tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample_train(features): + pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.bfloat16) + input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.bfloat16) + text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) + prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) + img_ids = tf.io.parse_tensor(features["img_ids"], out_type=tf.float32) + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "text_embeds": text_embeds, + "prompt_embeds": prompt_embeds, + "img_ids": img_ids, + } + # If using synthetic data if config.dataset_type == "synthetic": return make_data_iterator( @@ -255,6 +286,8 @@ def load_dataset(self, pipeline, params, train_states): jax.process_count(), mesh, total_train_batch_size, + feature_description=feature_description, + prepare_sample_fn=prepare_sample_train, pipeline=pipeline, # Pass pipeline to extract dimensions is_training=True, ) @@ -283,6 +316,28 @@ def load_dataset(self, pipeline, params, train_states): prepare_latent_imgage_ids=prepare_latent_image_ids_p, ) + feature_description = { + "pixel_values": tf.io.FixedLenFeature([], tf.string), + "input_ids": tf.io.FixedLenFeature([], tf.string), + "text_embeds": tf.io.FixedLenFeature([], tf.string), + "prompt_embeds": tf.io.FixedLenFeature([], tf.string), + "img_ids": tf.io.FixedLenFeature([], tf.string), + } + + def prepare_sample_train(features): + pixel_values = tf.io.parse_tensor(features["pixel_values"], 
out_type=tf.bfloat16) + input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.bfloat16) + text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32) + prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32) + img_ids = tf.io.parse_tensor(features["img_ids"], out_type=tf.float32) + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "text_embeds": text_embeds, + "prompt_embeds": prompt_embeds, + "img_ids": img_ids, + } + data_iterator = make_data_iterator( config, jax.process_index(), @@ -291,6 +346,8 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size, tokenize_fn=tokenize_fn, image_transforms_fn=image_transforms_fn, + feature_description=feature_description, + prepare_sample_fn=prepare_sample_train, ) return data_iterator @@ -315,17 +372,24 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da None, ), out_shardings=(state_shardings["flux_state_shardings"], None, None), - donate_argnums=(0,), + donate_argnums=(0, 1), ) max_logging.log("Precompiling...") s = time.time() dummy_batch = self.get_shaped_batch(self.config, pipeline) - p_train_step = p_train_step.lower(train_states[FLUX_STATE_KEY], dummy_batch, train_rngs) + abstract_flux_state = jax.tree_util.tree_map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), train_states[FLUX_STATE_KEY] + ) + + abstract_rngs = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), train_rngs) + p_train_step = p_train_step.lower(abstract_flux_state, dummy_batch, abstract_rngs) p_train_step = p_train_step.compile() max_logging.log(f"Compile time: {(time.time() - s )}") return p_train_step - def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): + def training_loop( + self, p_train_step, pipeline, params, train_states, data_iterator, data_shardings, unet_learning_rate_scheduler + ): writer = 
max_utils.initialize_summary_writer(self.config) flux_state = train_states[FLUX_STATE_KEY] num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) @@ -360,7 +424,10 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera self._profiler.start() example_batch = load_next_batch(data_iterator, example_batch, self.config) - example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()} + example_batch = { + key: jax.device_put(jnp.asarray(value, dtype=self.config.activations_dtype), data_shardings[key]) + for key, value in example_batch.items() + } if self.config.profiler == "nsys": with self.mesh: