Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hackable_diffusion/lib/architecture/arch_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ConditioningMechanism(enum.StrEnum):

ADAPTIVE_NORM = "adaptive_norm"
CROSS_ATTENTION = "cross_attention"
CROSS_ATTENTION_MASK = "cross_attention_mask"
CONCATENATE = "concatenate"
SUM = "sum"
CUSTOM = "custom"
Expand Down
19 changes: 18 additions & 1 deletion hackable_diffusion/lib/architecture/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ def __call__(
if adaptive_norm_emb is None:
raise ValueError("adaptive_norm_emb must be provided.")

# Extract conditioning embeddings to use with cross attention.
cross_attention_emb = conditioning_embeddings.get(
ConditioningMechanism.CROSS_ATTENTION
)
if cross_attention_emb is not None and cross_attention_emb.ndim == 2:
cross_attention_emb = cross_attention_emb[:, jnp.newaxis, :]

# Extract cross-attention mask.
cross_attention_mask = conditioning_embeddings.get(
ConditioningMechanism.CROSS_ATTENTION_MASK
)

# TODO(agalashov): This assumes that x is already tokenized, which is not
# true for images.
if self.use_padding_mask:
Expand All @@ -126,7 +138,12 @@ def __call__(
cond = adaptive_norm_emb
for i in range(1, self.num_blocks + 1):
tokens_emb = self.block.copy(name=f"Block_{i}")(
tokens_emb, cond, is_training=is_training, mask=padding_mask
tokens_emb,
cond,
is_training=is_training,
mask=padding_mask,
cross_cond=cross_attention_emb,
cross_mask=cross_attention_mask,
)

tokens_emb = self.conditional_norm(tokens_emb, c=nn.silu(cond))
Expand Down
52 changes: 45 additions & 7 deletions hackable_diffusion/lib/architecture/dit_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from hackable_diffusion.lib.architecture import attention
from hackable_diffusion.lib.architecture import mlp_blocks
from hackable_diffusion.lib.architecture import normalization
from hackable_diffusion.lib.hd_typing import typechecked # pylint: disable=g-multiple-import,g-importing-member
import jax.numpy as jnp
import kauldron.ktyping as kt

################################################################################
# MARK: Type Aliases
Expand Down Expand Up @@ -50,7 +50,7 @@ class PositionalEmbedding(nn.Module):
init_stddev: float = 0.02

@nn.compact
@kt.typechecked
@typechecked
def __call__(self, x: Num["batch *data_shape"]) -> Float["batch *data_shape"]:
pos_embed = self.param(
"PositionalEmbeddingTensor",
Expand Down Expand Up @@ -134,25 +134,49 @@ def setup(self):
use_scale=False,
).conditional_norm_factory()

@kt.typechecked
# Cross-Attention Module
self.cross_attn = attention.MultiHeadAttention(
num_heads=self.num_heads,
head_dim=self.head_dim,
# Note: We typically do NOT use RoPE for cross-attention because
# the spatial relationship between text and SMILES is not 1-to-1 linear.
use_rope=False,
zero_init_output=False,
dtype=self.dtype,
normalize_qk=True,
)
self.gate_cross = nn.Dense(
self.hidden_size,
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
name="Dense_Gate_Cross",
)

@typechecked
@nn.compact
def __call__(
self,
x: Float["*batch seq_dim emb_dim"],
cond: Float["*#batch cond_dim"],
cond: Float["*batch cond_dim"],
*,
is_training: bool,
mask: Bool["batch seq_dim"] | None = None,
cross_cond: Float["*batch seq_cross cross_cond_dim"] | None = None,
cross_mask: Bool["batch seq_cross"] | None = None,
) -> Float["*batch seq_dim emb_dim"]:
"""Calls the DiT block.

Args:
x: The input tensor.
cond: The conditioning tensor.
cond: The conditioning tensor for adaptive normalization.
is_training: Whether the block is in training mode.
mask: The self-attention padding mask. If the mask is provided, it is
assumed that the input sequence contains padding tokens that should be
masked out when computing the self-attention.
cross_cond: The cross attention conditioning tensor.
cross_mask: The cross-attention padding mask. If the mask is provided, it
is assumed that the input sequence contains padding tokens that should
be masked out when computing the cross-attention.

Returns:
The output tensor.
Expand All @@ -170,6 +194,20 @@ def __call__(
# Add a sequence dimension [...,None,:] to broadcast to [*batch,seq,dim].
x = x + gate_msa[..., None, :] * attn_out

# 2. Cross-Attention Branch (NEW)
if cross_cond is not None:
x_cross_modulated = self.conditional_norm(x, c=nn.silu(cond))
# Pass cross_cond as `c`, and cross_mask as `mask` (which applies to the keys/cross_cond)
cross_out = self.cross_attn(
x_cross_modulated, c=cross_cond, mask=cross_mask
)
if self.dropout_rate > 0.0:
cross_out = nn.Dropout(rate=self.dropout_rate)(
cross_out, deterministic=not is_training
)
gate_cross = self.gate_cross(nn.silu(cond))
x = x + gate_cross[..., None, :] * cross_out

# MLP Branch
x_mlp_modulated = self.conditional_norm(x, c=nn.silu(cond))
mlp_out = self.mlp(x_mlp_modulated, is_training=is_training)
Expand Down Expand Up @@ -204,7 +242,7 @@ class Patchify(nn.Module):
embedding_dim: int

@nn.compact
@kt.typechecked
@typechecked
def __call__(
self, x: Float["*batch height width channels"]
) -> Float["*batch seq_dim emb_dim"]:
Expand Down Expand Up @@ -258,7 +296,7 @@ def setup(self):
).conditional_norm_factory()

@nn.compact
@kt.typechecked
@typechecked
def __call__(
self, x: Float["*batch seq_dim emb_dim"], cond: Float["*#batch cond_dim"]
) -> Float["*batch height width channels"]:
Expand Down
94 changes: 94 additions & 0 deletions hackable_diffusion/lib/architecture/dit_blocks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,100 @@ def test_zero_init_is_identity(self):
output = module.apply(variables, x, cond, is_training=False)
self.assertTrue(jnp.allclose(output, x, atol=1e-5))

def test_output_shape_with_cross_cond(self):
    """The block must preserve the input shape when cross_cond is supplied."""
    tokens_shape = (self.batch, self.n, self.d)
    tokens = jnp.ones(tokens_shape)
    ada_cond = jnp.ones((self.batch, self.c))
    # Cross-attention context: 8 conditioning tokens per batch element.
    context = jnp.ones((self.batch, 8, self.d))
    block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
    params = block.init(
        self.key, tokens, ada_cond, is_training=False, cross_cond=context
    )
    out = block.apply(
        params, tokens, ada_cond, is_training=False, cross_cond=context
    )
    self.assertEqual(out.shape, tokens_shape)

def test_variable_shapes_with_cross_cond(self):
    """Initializing with cross_cond must produce the expected parameter tree."""
    tokens = jnp.ones((self.batch, self.n, self.d))
    ada_cond = jnp.ones((self.batch, self.c))
    context = jnp.ones((self.batch, 8, self.d))
    mlp_hidden = int(self.d * 4.0)
    block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
    variables = block.init(
        self.key, tokens, ada_cond, is_training=False, cross_cond=context
    )
    actual_shapes = test_utils.get_pytree_shapes(variables)

    # All three gating projections (MSA, MLP, cross) share one shape layout.
    gate_shapes = {'kernel': (self.c, self.d), 'bias': (self.d,)}
    # Self- and cross-attention carry identical parameter structures.
    attn_shapes = {
        'Dense_Q': {'kernel': (self.d, self.d), 'bias': (self.d,)},
        'Dense_K': {'kernel': (self.d, self.d), 'bias': (self.d,)},
        'Dense_V': {'kernel': (self.d, self.d), 'bias': (self.d,)},
        'Dense_Output': {'kernel': (self.d, self.d), 'bias': (self.d,)},
        'norm_qk_scale': (1, 1, 1, 1),
    }
    expected_shapes = {
        'params': {
            'Dense_Gate_MSA': dict(gate_shapes),
            'Dense_Gate_MLP': dict(gate_shapes),
            'Dense_Gate_Cross': dict(gate_shapes),
            'ConditionalNorm': {
                'Dense_0': {
                    'kernel': (self.c, self.d * 2),
                    'bias': (self.d * 2,),
                },
            },
            'MLP': {
                'Dense_Hidden_0': {
                    'kernel': (self.d, mlp_hidden),
                    'bias': (mlp_hidden,),
                },
                'Dense_Output': {
                    'kernel': (mlp_hidden, self.d),
                    'bias': (self.d,),
                },
            },
            'attn': dict(attn_shapes),
            'cross_attn': dict(attn_shapes),
        }
    }
    self.assertDictEqual(expected_shapes, actual_shapes)

def test_zero_init_is_identity_with_cross_cond(self):
    """With zero conditioning, the block must be the identity even when a
    cross-attention context is provided (all gates are zero-initialized)."""
    tokens = jax.random.normal(self.key, (self.batch, self.n, self.d))
    # Zero conditioning drives every zero-initialized gate output to zero.
    ada_cond = jnp.zeros((self.batch, self.c))
    context = jax.random.normal(self.key, (self.batch, 8, self.d))
    block = dit_blocks.DiTBlockAdaLNZero(hidden_size=self.d, num_heads=4)
    params = block.init(
        self.key, tokens, ada_cond, is_training=False, cross_cond=context
    )
    out = block.apply(
        params, tokens, ada_cond, is_training=False, cross_cond=context
    )
    self.assertTrue(jnp.allclose(out, tokens, atol=1e-5))


class PositionalEmbeddingTest(parameterized.TestCase):

Expand Down
Loading