diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 5f11c6cb0ab3..1547f1cd2b78 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -14,18 +14,18 @@ # limitations under the License. import gc -import unittest +import pytest import torch from parameterized import parameterized from diffusers import AutoencoderKL from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import ( backend_empty_cache, enable_full_determinism, - floats_tensor, load_hf_numpy, require_torch_accelerator, require_torch_accelerator_with_fp16, @@ -35,22 +35,30 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin -from .testing_utils import AutoencoderTesterMixin +from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin +from .testing_utils import NewAutoencoderTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): - model_class = AutoencoderKL - main_input_name = "sample" - base_precision = 1e-2 +class AutoencoderKLTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return AutoencoderKL + + @property + def output_shape(self): + return (3, 32, 32) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): + def get_init_dict(self, block_out_channels=None, norm_num_groups=None): block_out_channels = block_out_channels or [2, 4] norm_num_groups = norm_num_groups or 2 - init_dict = { + return { "block_out_channels": block_out_channels, "in_channels": 3, "out_channels": 3, @@ -59,42 +67,27 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non "latent_channels": 4, "norm_num_groups": norm_num_groups, } - return init_dict - @property - def dummy_input(self): + def get_dummy_inputs(self): batch_size = 4 num_channels = 3 sizes = (32, 32) - - image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - + image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device) return {"sample": image} - @property - def input_shape(self): - return (3, 32, 32) - - @property - def output_shape(self): - return (3, 32, 32) - - def prepare_init_args_and_inputs_for_common(self): - init_dict = self.get_autoencoder_kl_config() - inputs_dict = self.dummy_input - return init_dict, inputs_dict +class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin): def test_gradient_checkpointing_is_applied(self): expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_from_pretrained_hub(self): model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) - self.assertIsNotNone(model) - self.assertEqual(len(loading_info["missing_keys"]), 0) + assert model is not None + assert len(loading_info["missing_keys"]) == 0 model.to(torch_device) - image = model(**self.dummy_input) + image = model(**self.get_dummy_inputs()) assert image is not None, "Make sure output is not None" @@ -168,17 +161,24 @@ def test_output_pretrained(self): ] ) - self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2) + + +class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin): + """Memory optimization tests for AutoencoderKL.""" + + +class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin): + """Slicing and tiling tests for AutoencoderKL.""" @slow -class AutoencoderKLIntegrationTests(unittest.TestCase): +class AutoencoderKLIntegrationTests: def get_file_format(self, seed, shape): return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - def tearDown(self): + def teardown_method(self): # clean up the VRAM after each test - super().tearDown() gc.collect() backend_empty_cache(torch_device) @@ -341,10 +341,7 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice): @parameterized.expand([(13,), (16,), (27,)]) @require_torch_gpu - @unittest.skipIf( - not is_xformers_available(), - reason="xformers is not required when using PyTorch 2.0.", - ) + @pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): model = self.get_sd_vae_model(fp16=True) encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True) @@ -362,10 +359,7 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): @parameterized.expand([(13,), (16,), (37,)]) @require_torch_gpu - @unittest.skipIf( - not is_xformers_available(), - reason="xformers is not required when using PyTorch 2.0.", - ) + @pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0(self, seed): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))