From 63c72323ba6763e09c8163e754b9729b49f9a486 Mon Sep 17 00:00:00 2001 From: Alexandre Galashov Date: Tue, 24 Mar 2026 09:49:43 -0700 Subject: [PATCH] Add cross attention to DiT blocks PiperOrigin-RevId: 888717187 --- .../lib/architecture/dit_blocks.py | 52 ++++++++-- .../lib/architecture/dit_blocks_test.py | 94 +++++++++++++++++++ 2 files changed, 139 insertions(+), 7 deletions(-) diff --git a/hackable_diffusion/lib/architecture/dit_blocks.py b/hackable_diffusion/lib/architecture/dit_blocks.py index 2d02f49..9f8cce1 100644 --- a/hackable_diffusion/lib/architecture/dit_blocks.py +++ b/hackable_diffusion/lib/architecture/dit_blocks.py @@ -21,8 +21,8 @@ from hackable_diffusion.lib.architecture import attention from hackable_diffusion.lib.architecture import mlp_blocks from hackable_diffusion.lib.architecture import normalization +from hackable_diffusion.lib.hd_typing import typechecked # pylint: disable=g-multiple-import,g-importing-member import jax.numpy as jnp -import kauldron.ktyping as kt ################################################################################ # MARK: Type Aliases @@ -50,7 +50,7 @@ class PositionalEmbedding(nn.Module): init_stddev: float = 0.02 @nn.compact - @kt.typechecked + @typechecked def __call__(self, x: Num["batch *data_shape"]) -> Float["batch *data_shape"]: pos_embed = self.param( "PositionalEmbeddingTensor", @@ -134,25 +134,49 @@ def setup(self): use_scale=False, ).conditional_norm_factory() - @kt.typechecked + # Cross-Attention Module + self.cross_attn = attention.MultiHeadAttention( + num_heads=self.num_heads, + head_dim=self.head_dim, + # Note: We typically do NOT use RoPE for cross-attention because + # the spatial relationship between text and SMILES is not 1-to-1 linear. 
+ use_rope=False, + zero_init_output=False, + dtype=self.dtype, + normalize_qk=True, + ) + self.gate_cross = nn.Dense( + self.hidden_size, + kernel_init=nn.initializers.zeros_init(), + bias_init=nn.initializers.zeros_init(), + name="Dense_Gate_Cross", + ) + + @typechecked @nn.compact def __call__( self, x: Float["*batch seq_dim emb_dim"], - cond: Float["*#batch cond_dim"], + cond: Float["*batch cond_dim"], *, is_training: bool, mask: Bool["batch seq_dim"] | None = None, + cross_cond: Float["*batch seq_cross cross_cond_dim"] | None = None, + cross_mask: Bool["batch seq_cross"] | None = None, ) -> Float["*batch seq_dim emb_dim"]: """Calls the DiT block. Args: x: The input tensor. - cond: The conditioning tensor. + cond: The conditioning tensor for adaptive normalization. is_training: Whether the block is in training mode. mask: The self-attention padding mask. If the mask is provided, it is assumed that the input sequence contains padding tokens that should be masked out when computing the self-attention. + cross_cond: The cross-attention conditioning tensor. + cross_mask: The cross-attention padding mask. If the mask is provided, it + is assumed that the `cross_cond` sequence contains padding tokens that + should be masked out when computing the cross-attention. Returns: The output tensor. @@ -170,6 +194,20 @@ def __call__( # Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim]. x = x + gate_msa[..., None, :] * attn_out + # 2. 
Cross-Attention Branch + if cross_cond is not None: + x_cross_modulated = self.conditional_norm(x, c=nn.silu(cond)) + # Pass cross_cond as `c`, and cross_mask as `mask` (which applies to the keys/cross_cond) + cross_out = self.cross_attn( + x_cross_modulated, c=cross_cond, mask=cross_mask + ) + if self.dropout_rate > 0.0: + cross_out = nn.Dropout(rate=self.dropout_rate)( + cross_out, deterministic=not is_training + ) + gate_cross = self.gate_cross(nn.silu(cond)) + x = x + gate_cross[..., None, :] * cross_out + # MLP Branch x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond)) mlp_out = self.mlp(x_mlp_modulated, is_training=is_training) @@ -204,7 +242,7 @@ class Patchify(nn.Module): embedding_dim: int @nn.compact - @kt.typechecked + @typechecked def __call__( self, x: Float["*batch height width channels"] ) -> Float["*batch seq_dim emb_dim"]: @@ -258,7 +296,7 @@ def setup(self): ).conditional_norm_factory() @nn.compact - @kt.typechecked + @typechecked def __call__( self, x: Float["*batch seq_dim emb_dim"], cond: Float["*#batch cond_dim"] ) -> Float["*batch height width channels"]: diff --git a/hackable_diffusion/lib/architecture/dit_blocks_test.py b/hackable_diffusion/lib/architecture/dit_blocks_test.py index 7e67e59..d7a58d2 100644 --- a/hackable_diffusion/lib/architecture/dit_blocks_test.py +++ b/hackable_diffusion/lib/architecture/dit_blocks_test.py @@ -118,6 +118,100 @@ def test_zero_init_is_identity(self): output = module.apply(variables, x, cond, is_training=False) self.assertTrue(jnp.allclose(output, x, atol=1e-5)) + def test_output_shape_with_cross_cond(self): + input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jnp.ones(input_shape) + cond = jnp.ones(cond_shape) + cross_cond = jnp.ones(cross_cond_shape) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + 
output = module.apply( + variables, x, cond, is_training=False, cross_cond=cross_cond + ) + self.assertEqual(output.shape, input_shape) + + def test_variable_shapes_with_cross_cond(self): + input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jnp.ones(input_shape) + cond = jnp.ones(cond_shape) + cross_cond = jnp.ones(cross_cond_shape) + mlp_hidden = int(self.d * 4.0) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + variables_shapes = test_utils.get_pytree_shapes(variables) + + expected_variables_shapes = { + 'params': { + 'Dense_Gate_MSA': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'Dense_Gate_MLP': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'Dense_Gate_Cross': { + 'kernel': (self.c, self.d), + 'bias': (self.d,), + }, + 'ConditionalNorm': { + 'Dense_0': { + 'kernel': (self.c, self.d * 2), + 'bias': (self.d * 2,), + }, + }, + 'MLP': { + 'Dense_Hidden_0': { + 'kernel': (self.d, mlp_hidden), + 'bias': (mlp_hidden,), + }, + 'Dense_Output': { + 'kernel': (mlp_hidden, self.d), + 'bias': (self.d,), + }, + }, + 'attn': { + 'Dense_Q': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_K': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_V': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_Output': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'norm_qk_scale': (1, 1, 1, 1), + }, + 'cross_attn': { + 'Dense_Q': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_K': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_V': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'Dense_Output': {'kernel': (self.d, self.d), 'bias': (self.d,)}, + 'norm_qk_scale': (1, 1, 1, 1), + }, + } + } + self.assertDictEqual(expected_variables_shapes, variables_shapes) + + def test_zero_init_is_identity_with_cross_cond(self): + 
input_shape = (self.batch, self.n, self.d) + cond_shape = (self.batch, self.c) + cross_cond_shape = (self.batch, 8, self.d) + x = jax.random.normal(self.key, input_shape) + cond = jnp.zeros(cond_shape) + cross_cond = jax.random.normal(self.key, cross_cond_shape) + module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4) + variables = module.init( + self.key, x, cond, is_training=False, cross_cond=cross_cond + ) + output = module.apply( + variables, x, cond, is_training=False, cross_cond=cross_cond + ) + self.assertTrue(jnp.allclose(output, x, atol=1e-5)) + class PositionalEmbeddingTest(parameterized.TestCase):