Skip to content

refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417

Open
prishajain1 wants to merge 1 commit into
mainfrom
prisha/flux_training
Open

refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections#417
prishajain1 wants to merge 1 commit into
mainfrom
prisha/flux_training

Conversation

@prishajain1

Copy link
Copy Markdown
Collaborator

Overview

This PR refactors the Flux model architecture in MaxDiffusion to support scanned blocks (nn.scan) for double and single blocks, implements configurable gradient checkpointing (rematerialization) policies, and updates the weights loader to support loading pretrained checkpoints under the scanned format.

Key Changes

  • Decoupled Fused Projections: Decoupled the projection layers (implementing the MlpAndOutputBlock wrapper) to eliminate redundant recomputation of attention and projection outputs.
  • QKV Slicing Refactoring: Refactored the QKV projection slicing logic to use jnp.split across Flux transformer blocks for cleaner layout constraints.
  • Scanned Block Architecture: Migrated Flux Double and Single Transformer Blocks to use nn.scan to optimize compiler tracing and step execution speed on TPUs.
  • Dynamic Gradient Checkpointing: Added FLUX_OPTIMIZED to GradientCheckpointType to allow configuring block-specific rematerialization policies dynamically via configuration files instead of being hardcoded.
  • Stacked Weights Loading: Updated the weights loader (util.py) to slice, group, and stack PyTorch checkpoint weights along axis 0 to match the expected format of nn.scan layers.

@prishajain1 prishajain1 requested a review from entrpn as a code owner June 12, 2026 06:20
@github-actions

Copy link
Copy Markdown

@prishajain1 prishajain1 marked this pull request as draft June 12, 2026 06:20
@prishajain1 prishajain1 force-pushed the prisha/flux_training branch from 6dfd5ea to 4696256 Compare June 12, 2026 06:22
@prishajain1 prishajain1 force-pushed the prisha/flux_training branch from 4696256 to 11ddfef Compare June 12, 2026 06:29
@prishajain1 prishajain1 marked this pull request as ready for review June 12, 2026 06:31
@github-actions

Copy link
Copy Markdown

🤖 Hi @prishajain1, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @prishajain1, but I was unable to process your request. Please see the logs for more details.

Comment on lines +220 to +221
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make these variable names a little more explicit? Maybe something like saved_transformer_layer_names or savable_transformer_layer_names and offloaded_transformer_layer_names or offloadable_transformer_layer_names? I will leave it upto you

Comment on lines +290 to 292
names_which_can_be_saved=self.config.names_which_can_be_saved,
names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Comment on lines +236 to +237
names_which_can_be_saved: []
names_which_can_be_offloaded: []

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename the file to transformer_flux.py?

Comment on lines -154 to -166
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

Reference: https://arxiv.org/abs/2403.03206

Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the comment block. Update if required

vec_shape = (
batch_size,
768, # Sequence length of clip, how to get this programmatically?
768,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Comment on lines -54 to -61
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the comment block

Comment on lines -104 to -111
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the comment block


x = (x - mean) * inv_std * (1.0 + scale_msa) + shift_msa
else:
raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to change this text comment?

Comment on lines +318 to +319
names_which_can_be_saved=config.names_which_can_be_saved,
names_which_can_be_offloaded=config.names_which_can_be_offloaded,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change variable names

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants