
UNet2DModel dtype property fails under nn.DataParallel with UnboundLocalError in get_parameter_dtype #13789

@ragibarnab


### Describe the bug

`UNet2DModel` fails under `torch.nn.DataParallel` during the forward pass, when `UNet2DModel.forward` accesses `self.dtype`.

The error appears to come from `diffusers.models.modeling_utils.get_parameter_dtype`, where the return annotation of the nested function `find_tensor_attributes` uses the builtin `tuple`:

```python
def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
```
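An `UnboundLocalError` (rather than the annotation simply resolving to the builtin) suggests that `get_parameter_dtype` also assigns a local variable named `tuple` somewhere in its body, for example as a `for tuple in ...` loop variable. That part of the function is not shown in the traceback, so this is an assumption. The pure-Python sketch below reproduces the suspected mechanism with hypothetical function names (`buggy_first_value`, `fixed_first_value`, `reproduces_bug`); it is an illustration of Python's scoping rules, not the diffusers code or its actual fix.

```python
import sys


def buggy_first_value(mapping: dict) -> object:
    """Hypothetical stand-in for get_parameter_dtype (not the diffusers code)."""

    # On Python <= 3.13 (without `from __future__ import annotations`), this
    # return annotation is evaluated in the enclosing function's scope when
    # the `def` statement runs.
    def find_pairs(d: dict) -> list[tuple[str, object]]:
        return [(k, v) for k, v in d.items()]

    # Assumption: a later assignment to `tuple` (here a loop variable) makes
    # `tuple` a local name for the whole function body, so the annotation
    # above reads an unbound local instead of the builtin type.
    for tuple in find_pairs(mapping):
        return tuple[1]
    return None


def fixed_first_value(mapping: dict) -> object:
    """Same logic with the loop variable renamed so `tuple` stays the builtin."""

    def find_pairs(d: dict) -> list[tuple[str, object]]:
        return [(k, v) for k, v in d.items()]

    for pair in find_pairs(mapping):
        return pair[1]
    return None


def reproduces_bug() -> bool:
    """Return True if calling the buggy variant raises UnboundLocalError."""
    try:
        buggy_first_value({"weight": 1.0})
    except UnboundLocalError:
        return True
    return False


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    print("buggy variant raises UnboundLocalError:", reproduces_bug())
    print("fixed variant returns:", fixed_first_value({"weight": 1.0}))
```

Renaming the shadowing variable is one way to avoid the error; deferring annotation evaluation with `from __future__ import annotations` would also hide it, but shadowing a builtin stays fragile either way. On Python 3.14, annotations are evaluated lazily (PEP 649), so the buggy variant may not raise there at all.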

### Reproduction

```python
import torch
import diffusers
from diffusers import UNet2DModel

print("torch:", torch.__version__)
print("diffusers:", diffusers.__version__)

model = UNet2DModel(
    sample_size=128,
    in_channels=10,
    out_channels=5,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    norm_num_groups=8,
)

# Wrap the model so each forward pass splits the batch across GPUs 0 and 1
# (requires at least two visible CUDA devices).
model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])
model.eval()

# A batch of 8 samples with 10 channels, plus one timestep per sample.
x = torch.randn(8, 10, 128, 128, device="cuda")
t = torch.rand(8, device="cuda") * 1000

with torch.no_grad():
    # Fails inside UNet2DModel.forward when it reads self.dtype.
    y = model(x, t, return_dict=False)[0]

print(y.shape)
```

### Logs

```shell
File ".../diffusers/models/unets/unet_2d.py", line ..., in forward
    t_emb = t_emb.to(dtype=self.dtype)

File ".../diffusers/models/modeling_utils.py", line ..., in dtype
    return get_parameter_dtype(self)

File ".../diffusers/models/modeling_utils.py", line ..., in get_parameter_dtype
    def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:

UnboundLocalError: cannot access local variable 'tuple' where it is not associated with a value
```

### System Info

Python 3.11.13, torch 2.9.1+cu128, diffusers 0.38.0, CUDA Version: 12.9

### Who can help?

No response
