Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/maxdiffusion/checkpointing/flux_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ def load_diffusers_checkpoint(self):
dtype=self.config.activations_dtype,
weights_dtype=self.config.weights_dtype,
precision=max_utils.get_precision(self.config),
use_base2_exp=self.config.use_base2_exp,
use_experimental_scheduler=self.config.use_experimental_scheduler,
remat_policy=self.config.remat_policy,
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
Comment on lines +220 to +221

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make these variable names a little more explicit? Maybe something like saved_transformer_layer_names or savable_transformer_layer_names and offloaded_transformer_layer_names or offloadable_transformer_layer_names? I will leave it upto you

)
transformer_eval_params = transformer.init_weights(
rngs=self.rng, max_sequence_length=self.config.max_sequence_length, eval_only=True
Expand Down Expand Up @@ -279,6 +284,11 @@ def load_checkpoint(self, step=None, scheduler_class=None):
weights_dtype=self.config.weights_dtype,
precision=max_utils.get_precision(self.config),
from_pt=self.config.from_pt,
use_base2_exp=self.config.use_base2_exp,
use_experimental_scheduler=self.config.use_experimental_scheduler,
remat_policy=self.config.remat_policy,
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
)
Comment on lines +290 to 292

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above


pipeline = FluxPipeline(
Expand Down
47 changes: 33 additions & 14 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
use_base2_exp: False
use_experimental_scheduler: False
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
Expand All @@ -73,18 +75,18 @@ mask_padding_tokens: True
# in cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
# flash_block_sizes: {
# "block_q" : 1536,
# "block_kv_compute" : 1536,
# "block_kv" : 1536,
# "block_q_dkv" : 1536,
# "block_kv_dkv" : 1536,
# "block_kv_dkv_compute" : 1536,
# "block_q_dq" : 1536,
# "block_kv_dq" : 1536
# }
flash_block_sizes: {
"block_q" : 1536,
"block_kv_compute" : 1536,
"block_kv" : 1536,
"block_q_dkv" : 1536,
"block_kv_dkv" : 1536,
"block_kv_dkv_compute" : 1536,
"block_q_dq" : 1536,
"block_kv_dq" : 1536
}
# GroupNorm groups
norm_num_groups: 32

Expand Down Expand Up @@ -147,9 +149,11 @@ mesh_axes: ['data', 'fsdp', 'context', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['batch', ['data','fsdp']],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_heads', 'fsdp'],
['activation_length', 'context'],
['activation_kv_length', 'context'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
Expand Down Expand Up @@ -188,7 +192,7 @@ dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override dimensions
#
# synthetic_num_samples: null # null for infinite, or set a number
synthetic_num_samples: 1000 # null for infinite, or set a number
#
# Optional dimension overrides:
# resolution: 512
Expand Down Expand Up @@ -218,6 +222,21 @@ transform_images_num_proc: 4
reuse_example_batch: False
enable_data_shuffling: True

# Defines the type of gradient checkpoint to enable.
# NONE - means no gradient checkpoint
# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
# except for ones that involve batch dimension - that means that all attention and projection
# layers will have gradient checkpoint, but not the backward with respect to the parameters.
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
# CUSTOM - set names to offload and save.
remat_policy: "FLUX_OPTIMIZED"
# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
# xq_out, xk_out, ffn_activation
names_which_can_be_saved: []
names_which_can_be_offloaded: []
Comment on lines +236 to +237

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

flash_min_seq_length: 0

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
Expand Down
3 changes: 3 additions & 0 deletions src/maxdiffusion/generate_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def run(config):
dtype=config.activations_dtype,
weights_dtype=config.weights_dtype,
precision=get_precision(config),
remat_policy=config.remat_policy,
names_which_can_be_saved=config.names_which_can_be_saved,
names_which_can_be_offloaded=config.names_which_can_be_offloaded,
Comment on lines +318 to +319

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change variable names

)

num_channels_latents = transformer.in_channels // 4
Expand Down
64 changes: 43 additions & 21 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,8 @@ class FlaxFluxAttention(nn.Module):
out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
precision: jax.lax.Precision = None
qkv_bias: bool = False
use_base2_exp: bool = False
use_experimental_scheduler: bool = False

def setup(self):
if self.attention_kernel in {"flash", "cudnn_flash_te"} and self.mesh is None:
Expand All @@ -1843,6 +1845,8 @@ def setup(self):
flash_block_sizes=self.flash_block_sizes,
dtype=self.dtype,
float32_qk_product=False,
use_base2_exp=self.use_base2_exp,
use_experimental_scheduler=self.use_experimental_scheduler,
)

kernel_axes = ("embed", "heads")
Expand Down Expand Up @@ -1923,41 +1927,59 @@ def __call__(
attention_mask=None,
image_rotary_emb=None,
):
qkv_proj = self.qkv(hidden_states)
B, L = hidden_states.shape[:2]
H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3
qkv_proj = qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
query_proj, key_proj, value_proj = qkv_proj
# Deduce dimensions cleanly from class attributes
H, D = self.heads, self.dim_head

query_proj = self.query_norm(query_proj)
qkv_proj = self.qkv(hidden_states)
qkv_proj = checkpoint_name(qkv_proj, "img_qkv_proj")

qkv_proj = qkv_proj.reshape(B, L, 3, H, D)
query_proj, key_proj, value_proj = jnp.split(qkv_proj, 3, axis=2)
query_proj = query_proj.squeeze(2)
key_proj = key_proj.squeeze(2)
value_proj = value_proj.squeeze(2)

query_proj = self.query_norm(query_proj)
key_proj = self.key_norm(key_proj)

if encoder_hidden_states is not None:
B_enc, L_txt = encoder_hidden_states.shape[:2]
encoder_qkv_proj = self.encoder_qkv(encoder_hidden_states)
B, L = encoder_hidden_states.shape[:2]
H, D, K = self.heads, encoder_qkv_proj.shape[-1] // (self.heads * 3), 3
encoder_qkv_proj = encoder_qkv_proj.reshape(B, L, K, H, D).transpose(2, 0, 3, 1, 4)
encoder_query_proj, encoder_key_proj, encoder_value_proj = encoder_qkv_proj
encoder_qkv_proj = checkpoint_name(encoder_qkv_proj, "txt_qkv_proj")
encoder_qkv_proj = encoder_qkv_proj.reshape(B_enc, L_txt, 3, H, D)
enc_query_proj, enc_key_proj, enc_value_proj = jnp.split(encoder_qkv_proj, 3, axis=2)
enc_query_proj = enc_query_proj.squeeze(2)
enc_key_proj = enc_key_proj.squeeze(2)
enc_value_proj = enc_value_proj.squeeze(2)

encoder_query_proj = self.encoder_query_norm(encoder_query_proj)
encoder_query_proj = self.encoder_query_norm(enc_query_proj)
encoder_key_proj = self.encoder_key_norm(enc_key_proj)

encoder_key_proj = self.encoder_key_norm(encoder_key_proj)
query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=1)
key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=1)
value_proj = jnp.concatenate((enc_value_proj, value_proj), axis=1)

query_proj = jnp.concatenate((encoder_query_proj, query_proj), axis=2)
key_proj = jnp.concatenate((encoder_key_proj, key_proj), axis=2)
value_proj = jnp.concatenate((encoder_value_proj, value_proj), axis=2)

query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
# query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
# key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
# value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)

query_proj = query_proj.swapaxes(1, 2)
key_proj = key_proj.swapaxes(1, 2)
query_proj, key_proj = apply_rope(query_proj, key_proj, image_rotary_emb)
query_proj = query_proj.swapaxes(1, 2)
key_proj = key_proj.swapaxes(1, 2)

query_proj = query_proj.reshape(B, -1, H * D)
key_proj = key_proj.reshape(B, -1, H * D)
value_proj = value_proj.reshape(B, -1, H * D)

query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
key_proj = key_proj.transpose(0, 2, 1, 3).reshape(key_proj.shape[0], key_proj.shape[2], -1)
value_proj = value_proj.transpose(0, 2, 1, 3).reshape(value_proj.shape[0], value_proj.shape[2], -1)
if encoder_hidden_states is not None:
query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=attention_mask)
context_attn_output = None
Expand Down
Loading
Loading