refactor Flux transformer to use scanned blocks, dynamic checkpointing, and decoupled projections #417
prishajain1 wants to merge 1 commit into
    names_which_can_be_saved=self.config.names_which_can_be_saved,
    names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
Can we make these variable names a little more explicit? Maybe something like saved_transformer_layer_names or savable_transformer_layer_names, and offloaded_transformer_layer_names or offloadable_transformer_layer_names? I will leave it up to you.
    names_which_can_be_saved=self.config.names_which_can_be_saved,
    names_which_can_be_offloaded=self.config.names_which_can_be_offloaded,
    )

    names_which_can_be_saved: []
    names_which_can_be_offloaded: []
Can we rename the file to transformer_flux.py?
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """
Let's keep the comment block. Update it if required.
    vec_shape = (
        batch_size,
        768,  # Sequence length of clip, how to get this programmatically?
        768,
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """
Let's keep the comment block
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """
Let's keep the comment block
        x = (x - mean) * inv_std * (1.0 + scale_msa) + shift_msa
    else:
        raise ValueError(f"Unsupported `norm_type` ({self.norm_type}) provided. Supported ones are: 'layer_norm'.")
Is there a need to change this text comment?
    names_which_can_be_saved=config.names_which_can_be_saved,
    names_which_can_be_offloaded=config.names_which_can_be_offloaded,
Please change variable names
Overview
This PR refactors the Flux model architecture in MaxDiffusion to support scanned blocks (nn.scan) for double and single blocks, implements configurable gradient checkpointing (rematerialization) policies, and updates the weights loader to support loading pretrained checkpoints under the scanned format.
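As a rough illustration of the two ideas above (scanning over stacked blocks and a name-based rematerialization policy), here is a minimal sketch in plain JAX. It is not the code in this PR: it uses `jax.lax.scan`, which Flax's `nn.scan` lifts to modules, and the toy `block` function, the `"mlp_out"` tag, and the `make_forward` helper are all hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

NUM_LAYERS, DIM = 4, 8


def block(x, w):
    # Toy stand-in for one transformer block: a tagged projection plus a residual.
    # The tag "mlp_out" is a hypothetical name, not one used by the PR.
    h = checkpoint_name(jnp.tanh(x @ w), "mlp_out")
    return x + h


def make_forward(names_to_save):
    # Build the remat policy from a configurable list of names
    # instead of hardcoding it inside the model.
    policy = jax.checkpoint_policies.save_only_these_names(*names_to_save)
    remat_block = jax.checkpoint(block, policy=policy)

    def forward(x, stacked_w):
        # scan traces `block` once and iterates over the leading (layer) axis.
        def step(carry, w):
            return remat_block(carry, w), None

        out, _ = jax.lax.scan(step, x, stacked_w)
        return out

    return forward


key = jax.random.PRNGKey(0)
stacked_w = 0.1 * jax.random.normal(key, (NUM_LAYERS, DIM, DIM))
x = jnp.ones((2, DIM))

forward = make_forward(["mlp_out"])
scanned_out = forward(x, stacked_w)

# Reference: the same blocks applied in an unrolled Python loop.
looped_out = x
for i in range(NUM_LAYERS):
    looped_out = block(looped_out, stacked_w[i])

# Gradients flow through the scanned, rematerialized blocks.
grads = jax.grad(lambda w: forward(x, w).sum())(stacked_w)
```

Changing the list passed to `make_forward` only changes which intermediates are kept for the backward pass; the forward output and the gradients should stay the same, which makes the policy safe to expose as a config option.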
Key Changes
- […] (MlpAndOutputBlock wrapper) to eliminate redundant recomputation of attention and projection outputs.
- […] jnp.split across Flux transformer blocks for cleaner layout constraints.
- […] nn.scan to optimize compiler tracing and step execution speed on TPUs.
- […] FLUX_OPTIMIZED to GradientCheckpointType to allow configuring block-specific rematerialization policies dynamically via configuration files instead of being hardcoded.
- […] (util.py) to slice, group, and stack PyTorch checkpoint weights along axis 0 to match the expected format of nn.scan layers.
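The weight-stacking step for scanned layers can be sketched roughly as follows. This is an assumption-laden illustration, not the loader in util.py: the `stack_layer_weights` helper and the key pattern in `flat` are made up for the example, and NumPy arrays stand in for tensors read from a PyTorch checkpoint.

```python
import numpy as np

NUM_LAYERS, DIM = 3, 4

# Stand-in for a flat PyTorch-style state dict with one entry per layer.
# The key pattern is illustrative only.
flat = {
    f"transformer_blocks.{i}.attn.to_q.weight": np.full((DIM, DIM), float(i))
    for i in range(NUM_LAYERS)
}


def stack_layer_weights(state, prefix, suffix, num_layers):
    # Gather the same parameter from every layer in index order,
    # then stack along a new leading axis (axis 0) that scan iterates over.
    per_layer = [state[f"{prefix}.{i}.{suffix}"] for i in range(num_layers)]
    return np.stack(per_layer, axis=0)


stacked = stack_layer_weights(
    flat, "transformer_blocks", "attn.to_q.weight", NUM_LAYERS
)
print(stacked.shape)  # (NUM_LAYERS, DIM, DIM)
```

A real conversion would typically also transpose or reshape each per-layer tensor to the target layout before stacking; that part is omitted here.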