Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 37 additions & 43 deletions tests/models/autoencoders/test_models_autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
# limitations under the License.

import gc
import unittest

import pytest
import torch
from parameterized import parameterized

from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
Expand All @@ -35,22 +35,30 @@
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
    """Return the model class exercised by this tester config (`AutoencoderKL`)."""
    return AutoencoderKL

@property
def output_shape(self):
    """Return the expected output shape, the fixed tuple ``(3, 32, 32)``."""
    # NOTE(review): presumably (channels, height, width) without the batch
    # dimension -- confirm against the mixins that consume this property.
    return (3, 32, 32)

@property
def generator(self):
    """Return a new CPU ``torch.Generator`` seeded with 0.

    A fresh generator is constructed on every access, so each caller starts
    from the same random state regardless of earlier draws.
    """
    return torch.Generator("cpu").manual_seed(0)

def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
init_dict = {
return {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
Expand All @@ -59,42 +67,27 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
}
return init_dict

def get_dummy_inputs(self):
    """Build the keyword arguments for a forward pass of the model under test.

    Returns:
        A dict with a single key, ``"sample"``, holding a random tensor of
        shape ``(4, 3, 32, 32)`` created with ``randn_tensor`` on
        ``torch_device``.
    """
    batch_size = 4
    num_channels = 3
    sizes = (32, 32)

    # `self.generator` is read once per call; seeding is whatever that
    # property provides.
    image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
    return {"sample": image}

@property
def input_shape(self):
return (3, 32, 32)

@property
def output_shape(self):
return (3, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict

class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
    """Run the inherited gradient-checkpointing check with this model's expected set.

    Delegates to the parent implementation, passing the names
    ``Decoder``, ``Encoder`` and ``UNetMidBlock2D`` as ``expected_set``.
    """
    expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
    super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0

model.to(torch_device)
image = model(**self.dummy_input)
image = model(**self.get_dummy_inputs())

assert image is not None, "Make sure output is not None"

Expand Down Expand Up @@ -168,17 +161,24 @@ def test_output_pretrained(self):
]
)

self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)


class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for AutoencoderKL.

    Adds no tests of its own: it combines the AutoencoderKL tester config
    with the tests inherited from ``MemoryTesterMixin``.
    """


class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderKL.

    Adds no tests of its own: it combines the AutoencoderKL tester config
    with the tests inherited from ``NewAutoencoderTesterMixin``.
    """


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
class AutoencoderKLIntegrationTests:
def get_file_format(self, seed, shape):
    """Return the ``.npy`` file name for a Gaussian-noise fixture.

    Args:
        seed: Value interpolated after ``gaussian_noise_s=``.
        shape: Iterable of dimensions; each is converted with ``str`` and the
            results are joined with underscores.

    Returns:
        A name of the form ``gaussian_noise_s={seed}_shape={d0}_{d1}_....npy``.
    """
    shape_tag = "_".join(str(dim) for dim in shape)
    return f"gaussian_noise_s={seed}_shape={shape_tag}.npy"

def teardown_method(self):
    """Release accelerator memory after each test (pytest xunit-style hook).

    No ``super()`` call is made: the enclosing class is declared without
    base classes, so there is no parent teardown to chain to.
    """
    # NOTE(review): the enclosing class no longer derives from
    # unittest.TestCase and its name does not start with "Test"; confirm that
    # the project's pytest `python_classes` setting still collects it.

    # clean up the VRAM after each test
    gc.collect()
    backend_empty_cache(torch_device)

Expand Down Expand Up @@ -341,10 +341,7 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):

@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
Expand All @@ -362,10 +359,7 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):

@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
Expand Down