diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f432928aa..319bfbc72 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 0e0552656..3134ed93d 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index bf29fa867..dfe300ddf 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_animate.yml b/src/maxdiffusion/configs/base_wan_animate.yml index 8f95c8558..7b3334c79 100644 --- a/src/maxdiffusion/configs/base_wan_animate.yml +++ b/src/maxdiffusion/configs/base_wan_animate.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index ca2d239ab..f722e04e2 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 90799524c..0aa533b40 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -41,6 +41,11 @@ revision: '' weights_dtype: 'bfloat16' # This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) activations_dtype: 'bfloat16' +# The dtype for text_encoder model during load/compile +text_encoder_dtype: 'float32' + +# Whether to compile the text_encoder with torch.compile +compile_text_encoder: False # Replicates vae across devices instead of using the model's sharding annotations for sharding. replicate_vae: False diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 5a5cfa293..c304ee423 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -270,13 +270,15 @@ def __init__( @classmethod def load_text_encoder(cls, config: HyperParameters): - torch_dtype = getattr(torch, str(config.weights_dtype), torch.float32) + text_encoder_dtype = getattr(config, "text_encoder_dtype", "float32") + torch_dtype = getattr(torch, str(text_encoder_dtype), torch.float32) text_encoder = UMT5EncoderModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype, ) - text_encoder = torch.compile(text_encoder) + if getattr(config, "compile_text_encoder", True): + text_encoder = torch.compile(text_encoder) return text_encoder @classmethod