From 61b6dd6d4fd4910d8869a7a5d74225351aa102b7 Mon Sep 17 00:00:00 2001 From: Alexandre Galashov Date: Tue, 24 Mar 2026 15:34:52 -0700 Subject: [PATCH] Add cross attention to DiT PiperOrigin-RevId: 888888559 --- .../lib/architecture/arch_typing.py | 1 + hackable_diffusion/lib/architecture/dit.py | 19 +- .../lib/architecture/dit_blocks.py | 52 ++++- .../lib/architecture/dit_blocks_test.py | 94 ++++++++ .../lib/architecture/dit_test.py | 202 ++++++++++++++++++ 5 files changed, 360 insertions(+), 8 deletions(-) diff --git a/hackable_diffusion/lib/architecture/arch_typing.py b/hackable_diffusion/lib/architecture/arch_typing.py index 4cd639c..54f8c6f 100644 --- a/hackable_diffusion/lib/architecture/arch_typing.py +++ b/hackable_diffusion/lib/architecture/arch_typing.py @@ -57,6 +57,7 @@ class ConditioningMechanism(enum.StrEnum): ADAPTIVE_NORM = "adaptive_norm" CROSS_ATTENTION = "cross_attention" + CROSS_ATTENTION_MASK = "cross_attention_mask" CONCATENATE = "concatenate" SUM = "sum" CUSTOM = "custom" diff --git a/hackable_diffusion/lib/architecture/dit.py b/hackable_diffusion/lib/architecture/dit.py index 74511f0..948bb5e 100644 --- a/hackable_diffusion/lib/architecture/dit.py +++ b/hackable_diffusion/lib/architecture/dit.py @@ -109,6 +109,18 @@ def __call__( if adaptive_norm_emb is None: raise ValueError("adaptive_norm_emb must be provided.") + # Extract conditioning embeddings to use with cross attention. + cross_attention_emb = conditioning_embeddings.get( + ConditioningMechanism.CROSS_ATTENTION + ) + if cross_attention_emb is not None and cross_attention_emb.ndim == 2: + cross_attention_emb = cross_attention_emb[:, jnp.newaxis, :] + + # Extract cross-attention mask. + cross_attention_mask = conditioning_embeddings.get( + ConditioningMechanism.CROSS_ATTENTION_MASK + ) + # TODO(agalashov): This assumes that x is already tokenized, which is not # true for images. 
if self.use_padding_mask: @@ -126,7 +138,12 @@ def __call__( cond = adaptive_norm_emb for i in range(1, self.num_blocks + 1): tokens_emb = self.block.copy(name=f"Block_{i}")( - tokens_emb, cond, is_training=is_training, mask=padding_mask + tokens_emb, + cond, + is_training=is_training, + mask=padding_mask, + cross_cond=cross_attention_emb, + cross_mask=cross_attention_mask, ) tokens_emb = self.conditional_norm(tokens_emb, c=nn.silu(cond)) diff --git a/hackable_diffusion/lib/architecture/dit_blocks.py b/hackable_diffusion/lib/architecture/dit_blocks.py index 2d02f49..9f8cce1 100644 --- a/hackable_diffusion/lib/architecture/dit_blocks.py +++ b/hackable_diffusion/lib/architecture/dit_blocks.py @@ -21,8 +21,8 @@ from hackable_diffusion.lib.architecture import attention from hackable_diffusion.lib.architecture import mlp_blocks from hackable_diffusion.lib.architecture import normalization +from hackable_diffusion.lib.hd_typing import typechecked # pylint: disable=g-multiple-import,g-importing-member import jax.numpy as jnp -import kauldron.ktyping as kt ################################################################################ # MARK: Type Aliases @@ -50,7 +50,7 @@ class PositionalEmbedding(nn.Module): init_stddev: float = 0.02 @nn.compact - @kt.typechecked + @typechecked def __call__(self, x: Num["batch *data_shape"]) -> Float["batch *data_shape"]: pos_embed = self.param( "PositionalEmbeddingTensor", @@ -134,25 +134,49 @@ def setup(self): use_scale=False, ).conditional_norm_factory() - @kt.typechecked + # Cross-Attention Module + self.cross_attn = attention.MultiHeadAttention( + num_heads=self.num_heads, + head_dim=self.head_dim, + # Note: We typically do NOT use RoPE for cross-attention because + # the spatial relationship between text and SMILES is not 1-to-1 linear. 
+ use_rope=False, + zero_init_output=False, + dtype=self.dtype, + normalize_qk=True, + ) + self.gate_cross = nn.Dense( + self.hidden_size, + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + name="Dense_Gate_Cross", + ) + + @typechecked @nn.compact def __call__( self, x: Float["*batch seq_dim emb_dim"], - cond: Float["*#batch cond_dim"], + cond: Float["*batch cond_dim"], *, is_training: bool, mask: Bool["batch seq_dim"] | None = None, + cross_cond: Float["*batch seq_cross cross_cond_dim"] | None = None, + cross_mask: Bool["batch seq_cross"] | None = None, ) -> Float["*batch seq_dim emb_dim"]: """Calls the DiT block. Args: x: The input tensor. - cond: The conditioning tensor. + cond: The conditioning tensor for adaptive normalization. is_training: Whether the block is in training mode. mask: The self-attention padding mask. If the mask is provided, it is assumed that the input sequence contains padding tokens that should be masked out when computing the self-attention. + cross_cond: The cross-attention conditioning tensor. + cross_mask: The cross-attention padding mask. If the mask is provided, it + is assumed that the cross-conditioning sequence contains padding tokens that should + be masked out when computing the cross-attention. Returns: The output tensor. @@ -170,6 +194,20 @@ def __call__( # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim]. x = x + gate_msa[..., None, :] * attn_out + # 2.
Cross-Attention Branch (NEW) + if cross_cond is not None: + x_cross_modulated = self.conditional_norm(x, c=nn.silu(cond)) + # Pass cross_cond as `c`, and cross_mask as `mask` (which applies to the keys/cross_cond) + cross_out = self.cross_attn( + x_cross_modulated, c=cross_cond, mask=cross_mask + ) + if self.dropout_rate > 0.0: + cross_out = nn.Dropout(rate=self.dropout_rate)( + cross_out, deterministic=not is_training + ) + gate_cross = self.gate_cross(nn.silu(cond)) + x = x + gate_cross[..., None, :] * cross_out + # MLP Branch x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond)) mlp_out = self.mlp(x_mlp_modulated, is_training=is_training) @@ -204,7 +242,7 @@ class Patchify(nn.Module): embedding_dim: int @nn.compact - @kt.typechecked + @typechecked def __call__( self, x: Float["*batch height width channels"] ) -> Float["*batch seq_dim emb_dim"]: @@ -258,7 +296,7 @@ def setup(self): ).conditional_norm_factory() @nn.compact - @kt.typechecked + @typechecked def __call__( self, x: Float["*batch seq_dim emb_dim"], cond: Float["*#batch cond_dim"] ) -> Float["*batch height width channels"]: diff --git a/hackable_diffusion/lib/architecture/dit_blocks_test.py b/hackable_diffusion/lib/architecture/dit_blocks_test.py index 7e67e59..d7a58d2 100644 --- a/hackable_diffusion/lib/architecture/dit_blocks_test.py +++ b/hackable_diffusion/lib/architecture/dit_blocks_test.py @@ -118,6 +118,100 @@ def test_zero_init_is_identity(self): output = module.apply(variables, x, cond, is_training=False) self.assertTrue(jnp.allclose(output, x, atol=1e-5)) + def test_output_shape_with_cross_cond(self): + input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jnp.ones(input_shape) + cond = jnp.ones(cond_shape) + cross_cond = jnp.ones(cross_cond_shape) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + 
output = module.apply( + variables, x, cond, is_training=False, cross_cond=cross_cond + ) + self.assertEqual(output.shape, input_shape) + + def test_variable_shapes_with_cross_cond(self): + input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jnp.ones(input_shape) + cond = jnp.ones(cond_shape) + cross_cond = jnp.ones(cross_cond_shape) + mlp_hidden = int(self.d * 4.0) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + variables_shapes = test_utils.get_pytree_shapes(variables) + + expected_variables_shapes = { + 'params': { + 'Dense_Gate_MSA': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'Dense_Gate_MLP': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'Dense_Gate_Cross': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'ConditionalNorm': { + 'Dense_0': { + 'kernel': (self.c, self.d * 2), + 'bias': (self.d * 2,), + }, + }, + 'MLP': { + 'Dense_Hidden_0': { + 'kernel': (self.d, mlp_hidden), + 'bias': (mlp_hidden,), + }, + 'Dense_Output': { + 'kernel': (mlp_hidden, self.d), + 'bias': (self.d,), + }, + }, + 'attn': { + 'Dense_Q': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_K': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_V': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_Output': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'norm_qk_scale': (1, 1, 1, 1), + }, + 'cross_attn': { + 'Dense_Q': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_K': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_V': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_Output': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'norm_qk_scale': (1, 1, 1, 1), + }, + } + } + self.assertDictEqual(expected_variables_shapes, variables_shapes) + + def test_zero_init_is_identity_with_cross_cond(self): + 
input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jax.random.normal(self.key, input_shape) + cond = jnp.zeros(cond_shape) + cross_cond = jax.random.normal(self.key, cross_cond_shape) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + output = module.apply( + variables, x, cond, is_training=False, cross_cond=cross_cond + ) + self.assertTrue(jnp.allclose(output, x, atol=1e-5)) + class PositionalEmbeddingTest(parameterized.TestCase): diff --git a/hackable_diffusion/lib/architecture/dit_test.py b/hackable_diffusion/lib/architecture/dit_test.py index 2f3eb70..76058fd 100644 --- a/hackable_diffusion/lib/architecture/dit_test.py +++ b/hackable_diffusion/lib/architecture/dit_test.py @@ -256,6 +256,208 @@ def test_missing_adaptive_norm_raises(self): is_training=self.is_training, ) + def test_output_shape_with_cross_attention(self): + """Verifies output shape when cross-attention conditioning is provided.""" + cross_seq_len = 10 + cross_dim = 24 + input_shape = (self.batch_size, self.sequence_length, self.embedding_dim) + x = jnp.ones(input_shape) + conditioning_embeddings = { + ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + (self.batch_size, self.cond_dim) + ), + ConditioningMechanism.CROSS_ATTENTION: jnp.ones( + (self.batch_size, cross_seq_len, cross_dim) + ), + } + model = dit.DiT( + num_blocks=2, + block=dit_blocks.DiTBlockAdaLNZero( + hidden_size=self.embedding_dim, num_heads=4 + ), + ) + variables = model.init( + self.key, + x=x, + conditioning_embeddings=conditioning_embeddings, + is_training=self.is_training, + ) + output = model.apply( + variables, + x=x, + conditioning_embeddings=conditioning_embeddings, + is_training=self.is_training, + ) + self.assertEqual(output.shape, input_shape) + + def test_output_shape_with_cross_attention_and_mask(self): + """Verifies output 
shape when both cross-attention and mask are provided.""" + cross_seq_len = 10 + cross_dim = 24 + input_shape = (self.batch_size, self.sequence_length, self.embedding_dim) + x = jnp.ones(input_shape) + conditioning_embeddings = { + ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + (self.batch_size, self.cond_dim) + ), + ConditioningMechanism.CROSS_ATTENTION: jnp.ones( + (self.batch_size, cross_seq_len, cross_dim) + ), + ConditioningMechanism.CROSS_ATTENTION_MASK: jnp.ones( + (self.batch_size, cross_seq_len), dtype=jnp.bool_ + ), + } + model = dit.DiT( + num_blocks=2, + block=dit_blocks.DiTBlockAdaLNZero( + hidden_size=self.embedding_dim, num_heads=4 + ), + ) + variables = model.init( + self.key, + x=x, + conditioning_embeddings=conditioning_embeddings, + is_training=self.is_training, + ) + output = model.apply( + variables, + x=x, + conditioning_embeddings=conditioning_embeddings, + is_training=self.is_training, + ) + self.assertEqual(output.shape, input_shape) + + def test_variable_shapes_with_cross_attention(self): + """Verifies that cross-attention params are correctly initialized.""" + cross_seq_len = 10 + cross_dim = 24 + input_shape = (self.batch_size, self.sequence_length, self.embedding_dim) + x = jnp.ones(input_shape) + conditioning_embeddings = { + ConditioningMechanism.ADAPTIVE_NORM: jnp.ones( + (self.batch_size, self.cond_dim) + ), + ConditioningMechanism.CROSS_ATTENTION: jnp.ones( + (self.batch_size, cross_seq_len, cross_dim) + ), + } + model = dit.DiT( + num_blocks=1, + block=dit_blocks.DiTBlockAdaLNZero( + hidden_size=self.embedding_dim, num_heads=4 + ), + ) + variables = model.init( + self.key, + x=x, + conditioning_embeddings=conditioning_embeddings, + is_training=self.is_training, + ) + variables_shapes = test_utils.get_pytree_shapes(variables) + block_params = variables_shapes['params']['Block_1'] + + # Cross-attention gate should be present. 
+ self.assertIn('Dense_Gate_Cross', block_params) + self.assertEqual( + block_params['Dense_Gate_Cross']['kernel'], + (self.cond_dim, self.embedding_dim), + ) + + # Cross-attention module should be present with correct shapes. + self.assertIn('cross_attn', block_params) + cross_attn = block_params['cross_attn'] + # Q projects from embedding_dim + self.assertEqual( + cross_attn['Dense_Q']['kernel'], + (self.embedding_dim, self.embedding_dim), + ) + # K and V project from cross_dim + self.assertEqual( + cross_attn['Dense_K']['kernel'], + (cross_dim, self.embedding_dim), + ) + self.assertEqual( + cross_attn['Dense_V']['kernel'], + (cross_dim, self.embedding_dim), + ) + + def test_cross_attention_mask_zeros_out_tokens(self): + """Verifies that masking all cross-attention tokens changes the output.""" + cross_seq_len = 10 + cross_dim = 24 + input_shape = (self.batch_size, self.sequence_length, self.embedding_dim) + x = jax.random.normal(self.key, input_shape) + adaptive_norm = jax.random.normal( + jax.random.PRNGKey(2), (self.batch_size, self.cond_dim) + ) + cross_emb = jax.random.normal( + jax.random.PRNGKey(1), (self.batch_size, cross_seq_len, cross_dim) + ) + + model = dit.DiT( + num_blocks=1, + block=dit_blocks.DiTBlockAdaLNZero( + hidden_size=self.embedding_dim, num_heads=4 + ), + ) + + # Init with cross-attention to get correct params. + conditioning_with_cross = { + ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm, + ConditioningMechanism.CROSS_ATTENTION: cross_emb, + ConditioningMechanism.CROSS_ATTENTION_MASK: jnp.ones( + (self.batch_size, cross_seq_len), dtype=jnp.bool_ + ), + } + variables = model.init( + self.key, + x=x, + conditioning_embeddings=conditioning_with_cross, + is_training=False, + ) + + # The cross-attention gate (Dense_Gate_Cross) is zero-initialized, so + # replace it with ones so the gate is active and masking has an effect. 
+ params = variables['params'] + gate_cross = params['Block_1']['Dense_Gate_Cross'] + gate_cross = jax.tree.map(jnp.ones_like, gate_cross) + params['Block_1']['Dense_Gate_Cross'] = gate_cross + variables = {'params': params} + + # Run with all tokens masked out (False = masked). + conditioning_all_masked = { + ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm, + ConditioningMechanism.CROSS_ATTENTION: cross_emb, + ConditioningMechanism.CROSS_ATTENTION_MASK: jnp.zeros( + (self.batch_size, cross_seq_len), dtype=jnp.bool_ + ), + } + # Run with all tokens unmasked. + conditioning_all_unmasked = { + ConditioningMechanism.ADAPTIVE_NORM: adaptive_norm, + ConditioningMechanism.CROSS_ATTENTION: cross_emb, + ConditioningMechanism.CROSS_ATTENTION_MASK: jnp.ones( + (self.batch_size, cross_seq_len), dtype=jnp.bool_ + ), + } + + output_all_masked = model.apply( + variables, + x=x, + conditioning_embeddings=conditioning_all_masked, + is_training=False, + ) + output_all_unmasked = model.apply( + variables, + x=x, + conditioning_embeddings=conditioning_all_unmasked, + is_training=False, + ) + + # The two outputs should differ since masking changes cross-attention. + self.assertFalse(jnp.allclose(output_all_masked, output_all_unmasked)) + if __name__ == '__main__': absltest.main() +