2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -519,6 +519,8 @@
        title: DeepFloyd IF
      - local: api/pipelines/dit
        title: DiT
      - local: api/pipelines/dreamlite
        title: DreamLite
      - local: api/pipelines/easyanimate
        title: EasyAnimate
      - local: api/pipelines/ernie_image
160 changes: 160 additions & 0 deletions docs/source/en/api/pipelines/dreamlite.md
@@ -0,0 +1,160 @@
<!--Copyright 2026 The ByteDance Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DreamLite

DreamLite is a text-to-image and image-editing model from ByteDance. It pairs a custom 2D U-Net
(`DreamLiteUNetModel`) with the `Qwen3-VL` multimodal model, which serves as the encoder for prompts and
image instructions, and uses an `AutoencoderTiny` (TAESD-style) VAE for fast latent encoding and decoding.

Two pipelines are exposed:

| Pipeline | Modes | CFG | Use case |
|---|---|---|---|
| [`DreamLitePipeline`] | text-to-image **and** image-editing (auto-selected by whether `image` is `None`) | 3-branch dual CFG (`guidance_scale` on text branch, `image_guidance_scale` on image branch, à la InstructPix2Pix) | Highest quality |
| [`DreamLiteMobilePipeline`] | text-to-image **and** image-editing (auto-selected by whether `image` is `None`) | None (distilled; a single U-Net forward pass per step) | On-device / low-latency |

Official checkpoints:

* Base model: [carlofkl/DreamLite-base](https://huggingface.co/carlofkl/DreamLite-base)
* Distilled mobile model: [carlofkl/DreamLite-mobile](https://huggingface.co/carlofkl/DreamLite-mobile)

> [!TIP]
> Both pipelines auto-detect text-to-image vs. image-editing mode from whether the `image` argument is
> provided. There is no separate `Img2Img` class.

> [!TIP]
> When loading an input image for editing, prefer `diffusers.utils.load_image(...)` over raw `PIL.Image.open(...)`.
> `load_image` enforces an RGB conversion and applies EXIF orientation, both of which the pipeline assumes.
> A plain `Image.open` of an RGBA / palette / EXIF-rotated source will silently produce a different latent
> conditioning and degrade output quality.
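
The normalization described in the tip above can be sketched with Pillow alone. This is a minimal illustration, not the pipeline's own preprocessing: `to_pipeline_input` is a hypothetical helper name, and it only mirrors the two steps the tip attributes to `load_image` (EXIF transpose, then RGB conversion).

```python
from PIL import Image, ImageOps


def to_pipeline_input(img: Image.Image) -> Image.Image:
    """Hypothetical helper: apply EXIF orientation, then force RGB."""
    # Rotate/flip according to the EXIF orientation tag, if one is present.
    img = ImageOps.exif_transpose(img)
    # Drop alpha channels and expand palette images to three channels.
    return img.convert("RGB")


# In-memory stand-ins for an RGBA source and a palette source.
rgba = Image.new("RGBA", (64, 32), (255, 0, 0, 128))
palette = Image.new("P", (64, 32))

for source in (rgba, palette):
    prepared = to_pipeline_input(source)
    print(source.mode, "->", prepared.mode, prepared.size)
```

If an image has already been opened with `Image.open`, passing it through a helper like this before calling the pipeline should give the same kind of input that `load_image` produces.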

## Text-to-image (Base)

```python
import torch
from diffusers import DreamLitePipeline

pipe = DreamLitePipeline.from_pretrained("carlofkl/DreamLite-base", revision="diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

image = pipe(
    prompt="a dog running on the grass",
    negative_prompt="",
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator("cpu").manual_seed(42),
).images[0]
image.save("dreamlite_t2i.png")
```

## Image editing (Base)

Pass an `image` to enter edit mode. Both `guidance_scale` (text branch) and `image_guidance_scale`
(image branch) are active here.

```python
import torch
from diffusers import DreamLitePipeline
from diffusers.utils import load_image

pipe = DreamLitePipeline.from_pretrained("carlofkl/DreamLite-base", revision="diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

source = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cat.png")

image = pipe(
    prompt="turn the cat into a corgi",
    image=source,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    image_guidance_scale=1.5,
    generator=torch.Generator("cpu").manual_seed(42),
).images[0]
image.save("dreamlite_edit.png")
```
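
How the two scales interact can be sketched with the InstructPix2Pix guidance formula, which the table above cites as the model for the 3-branch scheme. Whether `DreamLitePipeline` combines its branches in exactly this way is an assumption; `dual_cfg` and its argument names are hypothetical, and plain Python lists stand in for noise-prediction tensors.

```python
def dual_cfg(eps_uncond, eps_image, eps_full, guidance_scale, image_guidance_scale):
    """Combine three noise predictions, InstructPix2Pix style (assumed form).

    eps_uncond: prediction with neither text nor image conditioning
    eps_image:  prediction with image conditioning only
    eps_full:   prediction with both text and image conditioning
    """
    return [
        u
        + image_guidance_scale * (i - u)  # push toward the source image
        + guidance_scale * (f - i)        # push toward the text instruction
        for u, i, f in zip(eps_uncond, eps_image, eps_full)
    ]


# Toy one-element "tensors" chosen so the arithmetic is exact.
uncond, image_only, full = [0.0], [1.0], [3.0]

# With both scales at 1.0 the result collapses to the fully conditioned branch.
print(dual_cfg(uncond, image_only, full, 1.0, 1.0))

# The scales used in the example above amplify both directions.
print(dual_cfg(uncond, image_only, full, 3.5, 1.5))
```

Under this formulation, raising `image_guidance_scale` keeps the output closer to the source image, while raising `guidance_scale` makes the edit follow the prompt more strongly.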

## Text-to-image (Mobile)

The mobile pipeline is distilled and skips CFG entirely: each step is a single U-Net forward pass. It accepts
the same `prompt`, `height`, `width`, and `num_inference_steps` arguments, but **ignores** `guidance_scale` and
`image_guidance_scale` if they are passed (a warning is logged).

```python
import torch
from diffusers import DreamLiteMobilePipeline

pipe = DreamLiteMobilePipeline.from_pretrained("carlofkl/DreamLite-mobile", revision="diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

image = pipe(
    prompt="a dog running on the grass",
    height=1024,
    width=1024,
    num_inference_steps=4,
    generator=torch.Generator("cpu").manual_seed(42),
).images[0]
image.save("dreamlite_mobile_t2i.png")
```

## Image editing (Mobile)

```python
import torch
from diffusers import DreamLiteMobilePipeline
from diffusers.utils import load_image

pipe = DreamLiteMobilePipeline.from_pretrained("carlofkl/DreamLite-mobile", revision="diffusers", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

source = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cat.png")

image = pipe(
    prompt="turn the cat into a corgi",
    image=source,
    height=1024,
    width=1024,
    num_inference_steps=4,
    generator=torch.Generator("cpu").manual_seed(42),
).images[0]
image.save("dreamlite_mobile_edit.png")
```
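
The four examples above differ only in which arguments are passed. The sketch below gathers that pattern in one place. `build_call_kwargs` is a hypothetical helper, not part of the library, and the numeric values are simply the ones used in the examples on this page, not documented defaults.

```python
from typing import Any, Optional


def build_call_kwargs(prompt: str, image: Optional[Any] = None, mobile: bool = False) -> dict:
    """Hypothetical helper: assemble keyword arguments for a pipeline call."""
    kwargs = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        # The distilled mobile examples use far fewer steps than the base ones.
        "num_inference_steps": 4 if mobile else 28,
    }
    if image is not None:
        # Passing `image` is what switches either pipeline into edit mode.
        kwargs["image"] = image
    if not mobile:
        # Only the base pipeline uses CFG; the mobile pipeline ignores these.
        kwargs["guidance_scale"] = 3.5
        if image is not None:
            kwargs["image_guidance_scale"] = 1.5
    return kwargs


print(sorted(build_call_kwargs("a dog running on the grass")))
print(sorted(build_call_kwargs("turn the cat into a corgi", image="<PIL image>", mobile=True)))
```

The result would be unpacked into the call, for example `pipe(**build_call_kwargs(...), generator=...)`, with `pipe` being whichever pipeline matches the `mobile` flag.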

## Notes and limitations

* Both pipelines force `batch_size = 1` internally; `num_images_per_prompt` controls how many samples
are drawn from the same prompt rather than parallel batching.
* The prompt encoder is `Qwen3-VL`, which is a multimodal model. Loading the full pipeline therefore
requires sufficient GPU memory for both the U-Net and the Qwen3-VL text encoder (~4 GB + ~0.7 GB
in bf16 for the base release).
* The VAE is `AutoencoderTiny` and exposes `encoder_block_out_channels`; `vae_scale_factor` is derived
from it at pipeline init time.
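
The last note can be made concrete with a short sketch. It assumes the convention commonly used elsewhere in diffusers, where each encoder block after the first halves the spatial resolution; the exact expression used by the DreamLite pipelines is not shown on this page, so both function names and the formula are assumptions. The channel tuple is the `AutoencoderTiny` default, which may differ from the released checkpoint's config.

```python
def vae_scale_factor_from_blocks(encoder_block_out_channels):
    """Assumed convention: one 2x downsample between consecutive encoder blocks."""
    return 2 ** (len(encoder_block_out_channels) - 1)


def latent_size(height, width, scale_factor):
    """Spatial size of the latent grid for a given output resolution."""
    if height % scale_factor or width % scale_factor:
        raise ValueError("height and width must be multiples of the VAE scale factor")
    return height // scale_factor, width // scale_factor


# AutoencoderTiny's default config; the released checkpoint may use other values.
encoder_block_out_channels = (64, 64, 64, 64)

scale = vae_scale_factor_from_blocks(encoder_block_out_channels)
print(scale)                           # 8
print(latent_size(1024, 1024, scale))  # (128, 128)
```

Under these assumptions, a 1024x1024 request corresponds to a 128x128 latent grid.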

## DreamLitePipeline

[[autodoc]] DreamLitePipeline
  - all
  - __call__

## DreamLiteMobilePipeline

[[autodoc]] DreamLiteMobilePipeline
  - all
  - __call__

## DreamLitePipelineOutput

[[autodoc]] pipelines.dreamlite.pipeline_output.DreamLitePipelineOutput
10 changes: 10 additions & 0 deletions src/diffusers/__init__.py
@@ -238,6 +238,8 @@
"CosmosControlNetModel",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"DreamLiteTransformer2DModel",
"DreamLiteUNetModel",
"EasyAnimateTransformer3DModel",
"ErnieImageTransformer2DModel",
"Flux2Transformer2DModel",
@@ -546,6 +548,9 @@
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
"DreamLiteMobilePipeline",
"DreamLitePipeline",
"DreamLitePipelineOutput",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
@@ -1071,6 +1076,8 @@
CosmosControlNetModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
DreamLiteTransformer2DModel,
DreamLiteUNetModel,
EasyAnimateTransformer3DModel,
ErnieImageTransformer2DModel,
Flux2Transformer2DModel,
@@ -1354,6 +1361,9 @@
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
DreamLiteMobilePipeline,
DreamLitePipeline,
DreamLitePipelineOutput,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
4 changes: 4 additions & 0 deletions src/diffusers/models/__init__.py
@@ -94,6 +94,7 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_2d_dreamlite"] = ["DreamLiteTransformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_anyflow"] = ["AnyFlowTransformer3DModel"]
_import_structure["transformers.transformer_anyflow_far"] = ["AnyFlowFARTransformer3DModel"]
@@ -141,6 +142,7 @@
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unets.unet_dreamlite"] = ["DreamLiteUNetModel"]
_import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
@@ -229,6 +231,7 @@
ConsisIDTransformer3DModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
DreamLiteTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
ErnieImageTransformer2DModel,
@@ -274,6 +277,7 @@
ZImageTransformer2DModel,
)
from .unets import (
DreamLiteUNetModel,
I2VGenXLUNet,
Kandinsky3UNet,
MotionAdapter,
108 changes: 108 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2787,6 +2787,114 @@ def __call__(
return hidden_states


class DreamLiteAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention with Grouped Query Attention (GQA / MQA) support (enabled
    by default if you're using PyTorch 2.0).

    Identical to :class:`AttnProcessor2_0` except the key/value reshape branch correctly handles ``attn.kv_heads !=
    attn.heads`` by reshaping K/V to ``kv_heads`` and then ``repeat_interleave``-ing them up to ``attn.heads``. This is
    required by the DreamLite UNet, which combines GQA with ``qk_norm``, a combination the default
    :class:`AttnProcessor2_0` does not handle.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "DreamLiteAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten (batch, channel, height, width) into (batch, height * width, channel).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # --- GQA-aware reshape (the only real difference vs AttnProcessor2_0) ---
        # Infer the number of K/V heads from the projection width rather than assuming it equals attn.heads.
        head_dim = query.shape[-1] // attn.heads
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if kv_heads != attn.heads:
            # GQA / MQA: repeat K/V heads up to query heads for SDPA.
            heads_per_kv_head = attn.heads // kv_heads
            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
            value = torch.repeat_interleave(
                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
            )
        # ------------------------------------------------------------------------

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class XLAFlashAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.