# Patch: Implement SelfConditioningDiffusionNetwork.
# Touches two files:
#   hackable_diffusion/lib/diffusion_network.py       (new network class)
#   hackable_diffusion/lib/diffusion_network_test.py  (tests)
# Surrounding diff context (the preceding DiffusionNetwork.__call__ tail and
# hunk headers) is elided here; only the added code is shown, cleaned up.


################################################################################
# MARK: Self-Conditioning Diffusion Network
################################################################################


class SelfConditioningDiffusionNetwork(DiffusionNetwork):
  """DiffusionNetwork with self-conditioning on x̂₀ predictions.

  Implements self-conditioning on x₀ predictions from the discrete diffusion
  literature (Chen et al. "Analog Bits"; Strudel et al.; D3PM).

  During training, with probability ``self_cond_prob`` (default 0.5):
    1. Run the network once with zero x̂₀ input to get initial logits.
    2. ``stop_gradient`` on the initial logits.
    3. Concatenate the logits to the noisy input xₜ along the last axis.
    4. Run the network again and return the output.

  During inference (``is_training=False``), self-conditioning is always
  applied.

  The ``backbone_network`` is expected to accept the wider input (xₜ
  concatenated with x̂₀ logits on the last axis).

  Attributes:
    num_output_classes: Number of output classes (categories) for the
      prediction. Used to create zero-filled logits of the correct shape for
      the first forward pass. Must be overridden with a positive integer.
    self_cond_prob: Probability of applying self-conditioning during training.
      During inference, self-conditioning is always applied. Must lie in
      [0, 1].
    rng_collection: The PRNG collection name to use for drawing the
      self-conditioning mask. Defaults to 'dropout'.
  """

  num_output_classes: int = -1
  self_cond_prob: float = 0.5
  rng_collection: str = 'dropout'

  @nn.compact
  @kt.typechecked
  def __call__(
      self,
      time: TimeArray,
      xt: DataArray,
      conditioning: Conditioning | None,
      is_training: bool,
  ) -> TargetInfo:
    """Runs the (possibly self-conditioned) two-pass forward computation.

    Args:
      time: Diffusion time, one scalar per batch element.
      xt: Noisy input at time t; the last axis is the channel axis that the
        x̂₀ logits are concatenated onto.
      conditioning: Optional conditioning inputs for the encoder.
      is_training: Whether this is a training step. Controls both the random
        self-conditioning draw and the flags passed to the sub-modules.

    Returns:
      A TargetInfo mapping ``self.prediction_type`` to the backbone output.

    Raises:
      ValueError: If ``num_output_classes`` was not set to a positive integer,
        or if ``self_cond_prob`` is outside [0, 1].
    """
    if self.num_output_classes <= 0:
      raise ValueError(
          '`num_output_classes` must be a positive integer, '
          f'got {self.num_output_classes}.'
      )
    # Validate the probability as well: a value outside [0, 1] would silently
    # degenerate to "always" or "never" self-conditioning.
    if not 0.0 <= self.self_cond_prob <= 1.0:
      raise ValueError(
          '`self_cond_prob` must be in [0, 1], '
          f'got {self.self_cond_prob}.'
      )

    # Rescale time and encode conditioning once for both passes.
    time_rescaled = (
        self.time_rescaler(time) if self.time_rescaler is not None else time
    )

    conditioning_embeddings = self.conditioning_encoder.copy(
        name='ConditioningEncoder'
    )(
        time=time_rescaled,
        conditioning=conditioning,
        is_training=is_training,
    )

    # Create zero logits with the same spatial shape as xt.
    zero_logits = jnp.zeros(
        xt.shape[:-1] + (self.num_output_classes,), dtype=xt.dtype
    )

    # Single backbone module instance: calling it twice below shares its
    # parameters between the two passes.
    backbone_module = self.backbone_network.copy(name='Backbone')

    def run_backbone(x0_logits):
      # Widen the input with the current x̂₀ estimate, rescale if configured,
      # and run the backbone. Shared by both forward passes.
      x = jnp.concatenate([xt, x0_logits], axis=-1)
      if self.input_rescaler is not None:
        x = self.input_rescaler(time, x)
      return backbone_module(
          x=x,
          conditioning_embeddings=conditioning_embeddings,
          is_training=is_training,
      )

    # First pass: run with zero logits to get initial x̂₀, then detach
    # gradients so the second pass does not backprop through the first.
    first_output = run_backbone(zero_logits)
    x0_hat_logits = jax.lax.stop_gradient(first_output)

    if is_training:
      # With probability self_cond_prob, keep the self-conditioned logits;
      # otherwise fall back to zeros (scalar bool broadcasts over the array).
      do_self_cond = (
          jax.random.uniform(self.make_rng(self.rng_collection))
          < self.self_cond_prob
      )
      x0_hat_logits = jnp.where(do_self_cond, x0_hat_logits, zero_logits)

    # Second pass: run with x̂₀ logits concatenated.
    backbone_outputs = run_backbone(x0_hat_logits)

    return {self.prediction_type: backbone_outputs}


################################################################################
# MARK: SelfConditioningDiffusionNetwork Tests
# (from hackable_diffusion/lib/diffusion_network_test.py; that file also adds
# `from collections.abc import Mapping` at its top.)
################################################################################


class SelfConditioningBackbone(arch_typing.ConditionalBackbone):
  """Backbone for self-conditioning tests.

  Accepts input of shape (B, ..., input_channels + num_classes) and returns
  output of shape (B, ..., num_classes). The backbone simply applies a dense
  layer so the output depends on the input content.
  """

  num_classes: int = 4

  @nn.compact
  def __call__(
      self,
      x: arch_typing.DataTree,
      conditioning_embeddings: Mapping[
          arch_typing.ConditioningMechanism, Float['batch ...']
      ],
      is_training: bool,
  ) -> arch_typing.DataTree:
    return nn.Dense(features=self.num_classes)(x)


class SelfConditioningDiffusionNetworkTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.key = jax.random.PRNGKey(0)
    self.batch_size = 2
    self.input_channels = 1
    self.num_output_classes = 4
    self.spatial_shape = (8, 8)
    self.xt_shape = (
        self.batch_size, *self.spatial_shape, self.input_channels
    )
    self.t = jnp.ones((self.batch_size,))
    self.xt = jnp.ones(self.xt_shape)
    self.conditioning = {
        'label1': jnp.arange(self.batch_size),
    }

    self.time_encoder = conditioning_encoder.SinusoidalTimeEmbedder(
        activation='silu',
        embedding_dim=16,
        num_features=32,
    )
    self.cond_encoder = conditioning_encoder.ConditioningEncoder(
        time_embedder=self.time_encoder,
        conditioning_embedders={
            'label': conditioning_encoder.LabelEmbedder(
                conditioning_key='label1',
                num_classes=10,
                num_features=16,
            ),
        },
        embedding_merging_method=arch_typing.EmbeddingMergeMethod.CONCAT,
        conditioning_rules={
            'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
            'label': arch_typing.ConditioningMechanism.ADAPTIVE_NORM,
        },
    )
    self.backbone = SelfConditioningBackbone(
        num_classes=self.num_output_classes
    )

  def _make_network(
      self, self_cond_prob: float = 0.5
  ) -> diffusion_network.SelfConditioningDiffusionNetwork:
    """Builds a network under test with the given self-conditioning prob."""
    return diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        data_dtype=jnp.float32,
        num_output_classes=self.num_output_classes,
        self_cond_prob=self_cond_prob,
    )

  def test_output_shape(self):
    network = self._make_network()
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

    output = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    self.assertIsInstance(output, dict)
    self.assertIn('logits', output)

    expected_shape = (
        self.batch_size, *self.spatial_shape, self.num_output_classes
    )

    self.assertEqual(output['logits'].shape, expected_shape)

  def test_self_cond_prob_zero_skips_self_cond(self):
    network = self._make_network(self_cond_prob=0.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )

    output_no_sc = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )
    # With prob=1.0, self-conditioning is always applied (different output).
    network_always = self._make_network(self_cond_prob=1.0)
    output_always = network_always.apply(
        variables,
        self.t,
        self.xt,
        self.conditioning,
        True,
        rngs={'dropout': self.key},
    )
    self.assertFalse(
        jnp.allclose(output_no_sc['logits'], output_always['logits']),
        msg='Outputs should differ since self-conditioning changes the input.',
    )

  def test_self_cond_prob_one_always_self_conditions(self):
    network = self._make_network(self_cond_prob=1.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )
    # Run twice with different RNG — should give the same result since
    # self_cond_prob=1.0 means the random draw has no effect.
    output_a = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': jax.random.PRNGKey(42)},
    )
    output_b = network.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': jax.random.PRNGKey(99)},
    )

    chex.assert_trees_all_close(output_a, output_b)

  def test_inference_always_self_conditions(self):
    # Even with self_cond_prob=0.0, inference should self-condition.
    network = self._make_network(self_cond_prob=0.0)
    variables = network.init(
        {'params': self.key, 'dropout': self.key},
        self.t, self.xt, self.conditioning, True,
    )
    # Inference output (is_training=False).
    output_infer = network.apply(
        variables,
        self.t, self.xt, self.conditioning, False,
    )
    # Training with self_cond_prob=1.0 should match inference.
    network_always = self._make_network(self_cond_prob=1.0)

    output_train_sc = network_always.apply(
        variables,
        self.t, self.xt, self.conditioning, True,
        rngs={'dropout': self.key},
    )

    chex.assert_trees_all_close(output_infer, output_train_sc)

  def test_default_self_cond_prob(self):
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
        num_output_classes=self.num_output_classes,
    )

    self.assertEqual(network.self_cond_prob, 0.5)

  def test_invalid_num_output_classes_raises(self):
    network = diffusion_network.SelfConditioningDiffusionNetwork(
        backbone_network=self.backbone,
        conditioning_encoder=self.cond_encoder,
        prediction_type='logits',
    )
    with self.assertRaisesRegex(ValueError, 'num_output_classes'):
      network.init(
          {'params': self.key, 'dropout': self.key},
          self.t,
          self.xt,
          self.conditioning,
          True,
      )


if __name__ == '__main__':
  absltest.main()