Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions hackable_diffusion/lib/diffusion_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,120 @@ def __call__(
return {self.prediction_type: backbone_outputs}


################################################################################
# MARK: Self-Conditioning Diffusion Network
################################################################################


class SelfConditioningDiffusionNetwork(DiffusionNetwork):
  """DiffusionNetwork with self-conditioning on x₀ predictions.

  Implements self-conditioning on x₀ predictions from the discrete diffusion
  literature (Chen et al. "Analog Bits"; Strudel et al.; D3PM).

  During training, with probability ``self_cond_prob`` (default 0.5):
  1. Run the network once with zero x̂₀ input to get initial logits.
  2. ``stop_gradient`` on the initial logits.
  3. Concatenate the logits to the noisy input xₜ along the last axis.
  4. Run the network again and return the output.

  During inference (``is_training=False``), self-conditioning is always applied.

  The ``backbone_network`` is expected to accept the wider input (xₜ
  concatenated with x̂₀ logits on the last axis).

  Note: the self-conditioning coin flip is a single scalar draw per call, so
  every example in the batch shares the same decision (it is not a
  per-example mask). Both backbone passes are always executed; when the draw
  comes up "no self-conditioning" the first pass's logits are simply replaced
  by zeros before the second pass.

  Attributes:
    num_output_classes: Number of output classes (categories) for the
      prediction. Used to create zero-filled logits of the correct shape for
      the first forward pass. Must be set to a positive value; the default of
      -1 is an invalid sentinel that makes ``__call__`` raise.
    self_cond_prob: Probability of applying self-conditioning during training.
      During inference, self-conditioning is always applied.
    rng_collection: The PRNG collection name to use for drawing the
      self-conditioning mask. Defaults to 'dropout'.
  """

  num_output_classes: int = -1
  self_cond_prob: float = 0.5
  rng_collection: str = 'dropout'

  @nn.compact
  @kt.typechecked
  def __call__(
      self,
      time: TimeArray,
      xt: DataArray,
      conditioning: Conditioning | None,
      is_training: bool,
  ) -> TargetInfo:
    """Runs the two-pass, possibly self-conditioned forward computation.

    Args:
      time: Diffusion time values for the batch.
      xt: Noisy input at time t; a class-logit channel block is appended to
        its last axis before each backbone call.
      conditioning: Optional conditioning inputs forwarded to the
        conditioning encoder.
      is_training: Whether this is a training step. Controls the stochastic
        self-conditioning draw and is forwarded to all submodules.

    Returns:
      A mapping from ``self.prediction_type`` to the backbone outputs of the
      second pass.

    Raises:
      ValueError: If ``num_output_classes`` was left at its invalid default.
    """
    if self.num_output_classes <= 0:
      raise ValueError(
          '`num_output_classes` must be a positive integer, '
          f'got {self.num_output_classes}.'
      )

    # Rescale time and encode conditioning once for both passes.
    time_rescaled = (
        self.time_rescaler(time) if self.time_rescaler is not None else time
    )

    # NOTE(review): the fixed submodule names ('ConditioningEncoder',
    # 'Backbone') presumably mirror the base DiffusionNetwork's parameter
    # layout so checkpoints stay compatible — confirm before renaming.
    conditioning_embeddings = self.conditioning_encoder.copy(
        name='ConditioningEncoder'
    )(
        time=time_rescaled,
        conditioning=conditioning,
        is_training=is_training,
    )

    # Create zero logits with the same spatial shape as xt.
    # NOTE(review): dtype follows xt; assumes xt is floating point — confirm.
    zero_logits = jnp.zeros(
        xt.shape[:-1] + (self.num_output_classes,), dtype=xt.dtype
    )

    # First pass: run with zero logits to get initial x̂₀.
    xt_with_zeros = jnp.concatenate([xt, zero_logits], axis=-1)
    xt_with_zeros_rescaled = (
        self.input_rescaler(time, xt_with_zeros)
        if self.input_rescaler is not None
        else xt_with_zeros
    )

    # The same module instance is called for both passes, so the backbone
    # parameters are shared between them.
    backbone_module = self.backbone_network.copy(name='Backbone')
    first_output = backbone_module(
        x=xt_with_zeros_rescaled,
        conditioning_embeddings=conditioning_embeddings,
        is_training=is_training,
    )

    # Extract logits and detach gradients: gradients never flow through the
    # first pass, per the standard self-conditioning recipe.
    x0_hat_logits = jax.lax.stop_gradient(first_output)

    if is_training:
      # With probability self_cond_prob, run self-conditioning.
      # The draw is a scalar, so the whole batch shares one decision.
      do_self_cond = (
          jax.random.uniform(self.make_rng(self.rng_collection))
          < self.self_cond_prob
      )
      # Conditionally use the self-conditioned logits or zeros.
      x0_hat_logits = jnp.where(do_self_cond, x0_hat_logits, zero_logits)

    # Second pass: run with x̂₀ logits concatenated.
    xt_with_x0_hat = jnp.concatenate([xt, x0_hat_logits], axis=-1)
    xt_with_x0_hat_rescaled = (
        self.input_rescaler(time, xt_with_x0_hat)
        if self.input_rescaler is not None
        else xt_with_x0_hat
    )

    backbone_outputs = backbone_module(
        x=xt_with_x0_hat_rescaled,
        conditioning_embeddings=conditioning_embeddings,
        is_training=is_training,
    )

    return {self.prediction_type: backbone_outputs}


################################################################################
# MARK: Multi-modal Diffusion Network
################################################################################
Expand Down
202 changes: 202 additions & 0 deletions hackable_diffusion/lib/diffusion_network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Tests for diffusion_network and its components."""

from collections.abc import Mapping

import chex
from flax import linen as nn
from hackable_diffusion.lib import diffusion_network
Expand Down Expand Up @@ -375,5 +377,205 @@ def test_multimodal_diffusion_network(self, input_type: str):
chex.assert_trees_all_equal_structs(modified_t, output)


################################################################################
# MARK: SelfConditioningDiffusionNetwork Tests
################################################################################


class SelfConditioningBackbone(arch_typing.ConditionalBackbone):
  """Minimal backbone for the self-conditioning tests.

  Takes input of shape (B, ..., input_channels + num_classes) and emits
  output of shape (B, ..., num_classes). A single dense projection is enough
  to make the output depend on every input channel, so the two forward
  passes of self-conditioning can be told apart.
  """

  num_classes: int = 4

  @nn.compact
  def __call__(
      self,
      x: arch_typing.DataTree,
      conditioning_embeddings: Mapping[
          arch_typing.ConditioningMechanism, Float['batch ...']
      ],
      is_training: bool,
  ) -> arch_typing.DataTree:
    del conditioning_embeddings, is_training  # Unused by this toy backbone.
    projection = nn.Dense(features=self.num_classes)
    return projection(x)


class SelfConditioningDiffusionNetworkTest(parameterized.TestCase):
  """Tests for SelfConditioningDiffusionNetwork."""

  def setUp(self):
    super().setUp()
    self.key = jax.random.PRNGKey(0)
    self.batch_size = 2
    self.input_channels = 1
    self.num_output_classes = 4
    self.spatial_shape = (8, 8)
    self.xt_shape = (
        (self.batch_size,) + self.spatial_shape + (self.input_channels,)
    )
    self.t = jnp.ones((self.batch_size,))
    self.xt = jnp.ones(self.xt_shape)
    self.conditioning = {'label1': jnp.arange(self.batch_size)}

    # Shared encoders for every network built by the tests.
    self.time_encoder = conditioning_encoder.SinusoidalTimeEmbedder(
        activation='silu',
        embedding_dim=16,
        num_features=32,
    )
    label_embedder = conditioning_encoder.LabelEmbedder(
        conditioning_key='label1',
        num_classes=10,
        num_features=16,
    )
    self.cond_encoder = conditioning_encoder.ConditioningEncoder(
        time_embedder=self.time_encoder,
        conditioning_embedders={'label': label_embedder},
        embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT,
        conditioning_rules={
            'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
            'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
        },
    )
    self.backbone = SelfConditioningBackbone(
        num_classes=self.num_output_classes
    )

  def _make_network(
      self, self_cond_prob: float = 0.5
  ) -> diffusion_network.SelfConditioningDiffusionNetwork:
    """Builds a network under test wired to the shared encoders."""
    return diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        data_dtype=jnp.float32,
        num_output_classes=self.num_output_classes,
        self_cond_prob=self_cond_prob,
    )

  def _init_variables(self, network):
    """Initializes `network` in training mode with the fixture inputs."""
    return network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

  def test_output_shape(self):
    network = self._make_network()
    model_vars = self._init_variables(network)

    out = network.apply(
        model_vars,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    self.assertIsInstance(out, dict)
    self.assertIn('logits', out)
    self.assertEqual(
        out['logits'].shape,
        (self.batch_size,) + self.spatial_shape + (self.num_output_classes,),
    )

  def test_self_cond_prob_zero_skips_self_cond(self):
    network_never = self._make_network(self_cond_prob=0.0)
    model_vars = self._init_variables(network_never)

    logits_never = network_never.apply(
        model_vars,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )['logits']
    # With prob=1.0, self-conditioning is always applied, so the second pass
    # sees a different input and (with the same parameters) a new output.
    logits_always = self._make_network(self_cond_prob=1.0).apply(
        model_vars,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )['logits']

    self.assertFalse(
        jnp.allclose(logits_never, logits_always),
        msg='Outputs should differ since self-conditioning changes the input.',
    )

  def test_self_cond_prob_one_always_self_conditions(self):
    network = self._make_network(self_cond_prob=1.0)
    model_vars = self._init_variables(network)

    # With prob=1.0 the random draw can never flip the branch, so distinct
    # dropout keys must give identical outputs.
    outputs = [
        network.apply(
            model_vars,
            self.t, self.xt, self.conditioning, True,
            rngs={'dropout': jax.random.PRNGKey(seed)},
        )
        for seed in (42, 99)
    ]

    chex.assert_trees_all_close(outputs[0], outputs[1])

  def test_inference_always_self_conditions(self):
    # Even with self_cond_prob=0.0, inference should self-condition.
    network = self._make_network(self_cond_prob=0.0)
    model_vars = self._init_variables(network)

    eval_out = network.apply(
        model_vars, self.t, self.xt, self.conditioning, False,
    )
    # A training pass with prob=1.0 follows the same two-pass code path.
    train_out = self._make_network(self_cond_prob=1.0).apply(
        model_vars,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    chex.assert_trees_all_close(eval_out, train_out)

  def test_default_self_cond_prob(self):
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        num_output_classes=self.num_output_classes,
    )

    self.assertEqual(network.self_cond_prob, 0.5)

  def test_invalid_num_output_classes_raises(self):
    # num_output_classes keeps its invalid sentinel default of -1 here.
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
    )
    with self.assertRaisesRegex(ValueError, 'num_output_classes'):
      self._init_variables(network)


# Run the absl test harness when this module is executed as a script.
if __name__ == '__main__':
  absltest.main()
Loading