Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions hackable_diffusion/lib/architecture/dit_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from hackable_diffusion.lib.architecture import attention
from hackable_diffusion.lib.architecture import mlp_blocks
from hackable_diffusion.lib.architecture import normalization
from hackable_diffusion.lib.hd_typing import typechecked # pylint: disable=g-multiple-import,g-importing-member
import jax.numpy as jnp
import kauldron.ktyping as kt

################################################################################
# MARK: Type Aliases
Expand Down Expand Up @@ -50,7 +50,7 @@ class PositionalEmbedding(nn.Module):
init_stddev: float = 0.02

@nn.compact
@kt.typechecked
@typechecked
def __call__(self, x: Num["batch *data_shape"]) -> Float["batch *data_shape"]:
pos_embed = self.param(
"PositionalEmbeddingTensor",
Expand Down Expand Up @@ -134,25 +134,49 @@ def setup(self):
use_scale=False,
).conditional_norm_factory()

@kt.typechecked
# Cross-Attention Module
self.cross_attn = attention.MultiHeadAttention(
num_heads=self.num_heads,
head_dim=self.head_dim,
# Note: We typically do NOT use RoPE for cross-attention because
# the spatial relationship between text and SMILES is not 1-to-1 linear.
use_rope=False,
zero_init_output=False,
dtype=self.dtype,
normalize_qk=True,
)
self.gate_cross = nn.Dense(
self.hidden_size,
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
name="Dense_Gate_Cross",
)

@typechecked
@nn.compact
def __call__(
self,
x: Float["*batch seq_dim emb_dim"],
cond: Float["*#batch cond_dim"],
cond: Float["*batch cond_dim"],
*,
is_training: bool,
mask: Bool["batch seq_dim"] | None = None,
cross_cond: Float["*batch seq_cross cross_cond_dim"] | None = None,
cross_mask: Bool["batch seq_cross"] | None = None,
) -> Float["*batch seq_dim emb_dim"]:
"""Calls the DiT block.

Args:
x: The input tensor.
cond: The conditioning tensor.
cond: The conditioning tensor for adaptive normalization.
is_training: Whether the block is in training mode.
mask: The self-attention padding mask. If the mask is provided, it is
assumed that the input sequence contains padding tokens that should be
masked out when computing the self-attention.
cross_cond: The cross attention conditioning tensor.
cross_mask: The cross-attention padding mask. If the mask is provided, it
is assumed that the input sequence contains padding tokens that should
be masked out when computing the cross-attention.

Returns:
The output tensor.
Expand All @@ -170,6 +194,20 @@ def __call__(
# Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
x = x + gate_msa[..., None, :] * attn_out

# 2. Cross-Attention Branch (NEW)
if cross_cond is not None:
x_cross_modulated = self.conditional_norm(x, c=nn.silu(cond))
# Pass cross_cond as `c`, and cross_mask as `mask` (which applies to the keys/cross_cond)
cross_out = self.cross_attn(
x_cross_modulated, c=cross_cond, mask=cross_mask
)
if self.dropout_rate > 0.0:
cross_out = nn.Dropout(rate=self.dropout_rate)(
cross_out, deterministic=not is_training
)
gate_cross = self.gate_cross(nn.silu(cond))
x = x + gate_cross[..., None, :] * cross_out

# MLP Branch
x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond))
mlp_out = self.mlp(x_mlp_modulated, is_training=is_training)
Expand Down Expand Up @@ -204,7 +242,7 @@ class Patchify(nn.Module):
embedding_dim: int

@nn.compact
@kt.typechecked
@typechecked
def __call__(
self, x: Float["*batch height width channels"]
) -> Float["*batch seq_dim emb_dim"]:
Expand Down Expand Up @@ -258,7 +296,7 @@ def setup(self):
).conditional_norm_factory()

@nn.compact
@kt.typechecked
@typechecked
def __call__(
self, x: Float["*batch seq_dim emb_dim"], cond: Float["*#batch cond_dim"]
) -> Float["*batch height width channels"]:
Expand Down
94 changes: 94 additions & 0 deletions hackable_diffusion/lib/architecture/dit_blocks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,100 @@ def test_zero_init_is_identity(self):
output = module.apply(variables, x, cond, is_training=False)
self.assertTrue(jnp.allclose(output, x, atol=1e-5))

def test_output_shape_with_cross_cond(self):
  """Supplying a cross-attention condition must not alter the output shape."""
  expected_shape = (self.batch, self.n, self.d)
  x = jnp.ones(expected_shape)
  cond = jnp.ones((self.batch, self.c))
  cross_cond = jnp.ones((self.batch, 8, self.d))
  block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
  # Initialize and apply with the optional cross_cond branch active.
  variables = block.init(
      self.key, x, cond, is_training=False, cross_cond=cross_cond
  )
  output = block.apply(
      variables, x, cond, is_training=False, cross_cond=cross_cond
  )
  self.assertEqual(output.shape, expected_shape)

def test_variable_shapes_with_cross_cond(self):
  """Initializing with cross_cond must create the cross-attention params."""
  x = jnp.ones((self.batch, self.n, self.d))
  cond = jnp.ones((self.batch, self.c))
  cross_cond = jnp.ones((self.batch, 8, self.d))
  module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
  variables = module.init(
      self.key, x, cond, is_training=False, cross_cond=cross_cond
  )
  variables_shapes = test_utils.get_pytree_shapes(variables)

  # Sub-shapes that repeat across the expectation table, named once.
  gate_dense = {'kernel': (self.c, self.d), 'bias': (self.d,)}
  square_dense = {'kernel': (self.d, self.d), 'bias': (self.d,)}
  # Self- and cross-attention share an identical parameter layout.
  attention_shapes = {
      'Dense_Q': square_dense,
      'Dense_K': square_dense,
      'Dense_V': square_dense,
      'Dense_Output': square_dense,
      'norm_qk_scale': (1, 1, 1, 1),
  }
  mlp_hidden = int(self.d * 4.0)
  expected_variables_shapes = {
      'params': {
          'Dense_Gate_MSA': gate_dense,
          'Dense_Gate_MLP': gate_dense,
          'Dense_Gate_Cross': gate_dense,
          'ConditionalNorm': {
              'Dense_0': {
                  'kernel': (self.c, self.d * 2),
                  'bias': (self.d * 2,),
              },
          },
          'MLP': {
              'Dense_Hidden_0': {
                  'kernel': (self.d, mlp_hidden),
                  'bias': (mlp_hidden,),
              },
              'Dense_Output': {
                  'kernel': (mlp_hidden, self.d),
                  'bias': (self.d,),
              },
          },
          'attn': attention_shapes,
          'cross_attn': attention_shapes,
      }
  }
  self.assertDictEqual(expected_variables_shapes, variables_shapes)

def test_zero_init_is_identity_with_cross_cond(self):
  """A zero condition must make the block an identity map, cross_cond too.

  All residual branches (self-attn, cross-attn, MLP) are gated by
  zero-initialized Dense layers driven by `cond`, so with cond == 0 the
  block should return its input unchanged even when cross_cond is given.
  """
  x = jax.random.normal(self.key, (self.batch, self.n, self.d))
  cond = jnp.zeros((self.batch, self.c))
  cross_cond = jax.random.normal(self.key, (self.batch, 8, self.d))
  module = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
  variables = module.init(
      self.key, x, cond, is_training=False, cross_cond=cross_cond
  )
  output = module.apply(
      variables, x, cond, is_training=False, cross_cond=cross_cond
  )
  self.assertTrue(jnp.allclose(output, x, atol=1e-5))


class PositionalEmbeddingTest(parameterized.TestCase):

Expand Down