From b938d300fda21a824d8951b869ad787196aa396c Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Fri, 17 Oct 2025 10:01:22 +0800
Subject: [PATCH 1/4] adjust unit tests for wan pipeline

Signed-off-by: Liu, Kaixuan
---
 tests/pipelines/test_pipelines_common.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index db8209835be4..022262a8eefe 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1438,19 +1438,19 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir)
             pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
-            for component in pipe_loaded.components.values():
+            for name, component in pipe_loaded.components.items():
                 if hasattr(component, "set_default_attn_processor"):
                     component.set_default_attn_processor()
-        pipe_loaded.to(torch_device)
+                if hasattr(component, "dtype"):
+                    self.assertTrue(
+                        component.dtype == torch.float16,
+                        f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+                    )
+                if hasattr(component, "half"):
+                    # Although all components for pipe_loaded should be float16 now, some submodules still use fp32, like in https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/t5/modeling_t5.py#L783, so we need to do the conversion again manually to align with the datatype we use in pipe exactly
+                    component = component.to(torch_device).half()
         pipe_loaded.set_progress_bar_config(disable=None)
 
-        for name, component in pipe_loaded.components.items():
-            if hasattr(component, "dtype"):
-                self.assertTrue(
-                    component.dtype == torch.float16,
-                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
-                )
-
         inputs = self.get_dummy_inputs(torch_device)
         output_loaded = pipe_loaded(**inputs)[0]
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()

From 2244e237336b720236af955d17811d83953624e9 Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Thu, 30 Oct 2025 07:26:47 +0000
Subject: [PATCH 2/4] update code

Signed-off-by: Liu, Kaixuan
---
 tests/pipelines/test_pipelines_common.py | 23 ++++++++++-------------
 tests/pipelines/wan/test_wan_22.py       | 18 +++++++++---------
 2 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 022262a8eefe..fa19d583ae3f 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1420,10 +1420,7 @@ def test_float16_inference(self, expected_max_diff=5e-2):
     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
     @require_accelerator
     def test_save_load_float16(self, expected_max_diff=1e-2):
-        components = self.get_dummy_components()
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.to(torch_device).half()
+        components = self.get_dummy_components(dtype=torch.float16)
 
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -1438,19 +1435,19 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir)
             pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
-            for name, component in pipe_loaded.components.items():
+            for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() - if hasattr(component, "dtype"): - self.assertTrue( - component.dtype == torch.float16, - f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", - ) - if hasattr(component, "half"): - # Although all components for pipe_loaded should be float16 now, some submodules still use fp32, like in https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/t5/modeling_t5.py#L783, so we need to do the conversion again manally to align with the datatype we use in pipe exactly - component = component.to(torch_device).half() + pipe_loaded.to(torch_device) pipe_loaded.set_progress_bar_config(disable=None) + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + inputs = self.get_dummy_inputs(torch_device) output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 56ef5ceb97ed..6fa17c168c6c 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -51,7 +51,7 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False supports_dduf = False - def get_dummy_components(self): + def get_dummy_components(self, dtype=torch.float32): torch.manual_seed(0) vae = AutoencoderKLWan( base_dim=3, @@ -59,11 +59,11 @@ def get_dummy_components(self): dim_mult=[1, 1, 1, 1], num_res_blocks=1, temperal_downsample=[False, True, True], - ) + ).to(dtype=dtype) torch.manual_seed(0) scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) - text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") torch.manual_seed(0) @@ -80,7 +80,7 @@ def get_dummy_components(self): cross_attn_norm=True, qk_norm="rms_norm_across_heads", rope_max_seq_len=32, - ) + ).to(dtype=dtype) torch.manual_seed(0) transformer_2 = WanTransformer3DModel( @@ -96,7 +96,7 @@ def get_dummy_components(self): cross_attn_norm=True, qk_norm="rms_norm_across_heads", rope_max_seq_len=32, - ) + ).to(dtype=dtype) components = { "transformer": transformer, @@ -215,7 +215,7 @@ class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_xformers_attention = False supports_dduf = False - def get_dummy_components(self): + def get_dummy_components(self, dtype=torch.float32): torch.manual_seed(0) vae = AutoencoderKLWan( base_dim=3, @@ -231,11 +231,11 @@ def get_dummy_components(self): scale_factor_spatial=16, scale_factor_temporal=4, temperal_downsample=[False, True, True], - ) + ).to(dtype=dtype) torch.manual_seed(0) scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0) - text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") torch.manual_seed(0) @@ -252,7 +252,7 @@ def get_dummy_components(self): cross_attn_norm=True, 
qk_norm="rms_norm_across_heads", rope_max_seq_len=32, - ) + ).to(dtype=dtype) components = { "transformer": transformer, From 5305169e7476ec6f8b069645edeb17cf060080de Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 30 Oct 2025 08:19:06 +0000 Subject: [PATCH 3/4] avoid adjusting common `get_dummy_components` API Signed-off-by: Liu, Kaixuan --- tests/pipelines/test_pipelines_common.py | 5 ++- tests/pipelines/wan/test_wan_22.py | 45 +++++++++++++++++++++--- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index fa19d583ae3f..db8209835be4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1420,7 +1420,10 @@ def test_float16_inference(self, expected_max_diff=5e-2): @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator def test_save_load_float16(self, expected_max_diff=1e-2): - components = self.get_dummy_components(dtype=torch.float16) + components = self.get_dummy_components() + for name, module in components.items(): + if hasattr(module, "half"): + components[name] = module.to(torch_device).half() pipe = self.pipeline_class(**components) for component in pipe.components.values(): diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py index 6fa17c168c6c..81337e753dbc 100644 --- a/tests/pipelines/wan/test_wan_22.py +++ b/tests/pipelines/wan/test_wan_22.py @@ -155,6 +155,43 @@ def test_inference(self): def test_attention_slicing_forward_pass(self): pass + @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + def test_save_load_float16(self, expected_max_diff=1e-2): + # Use get_dummy_components with dtype parameter instead of converting components + components = self.get_dummy_components(dtype=torch.float16) + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for name, component in pipe_loaded.components.items(): + if hasattr(component, "dtype"): + self.assertTrue( + component.dtype == torch.float16, + f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess( + max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." 
+        )
+
     def test_save_load_optional_components(self, expected_max_difference=1e-4):
         optional_component = "transformer"
 
@@ -215,7 +252,7 @@ class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_xformers_attention = False
     supports_dduf = False
 
-    def get_dummy_components(self, dtype=torch.float32):
+    def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
             base_dim=3,
@@ -231,11 +268,11 @@ def get_dummy_components(self):
             scale_factor_spatial=16,
             scale_factor_temporal=4,
             temperal_downsample=[False, True, True],
-        ).to(dtype=dtype)
+        )
 
         torch.manual_seed(0)
         scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
 
         torch.manual_seed(0)
@@ -252,7 +289,7 @@ def get_dummy_components(self):
             cross_attn_norm=True,
             qk_norm="rms_norm_across_heads",
             rope_max_seq_len=32,
-        ).to(dtype=dtype)
+        )
 
         components = {
             "transformer": transformer,

From ecd4c8b621b5a423278b3441b68243da4e277ab5 Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Thu, 30 Oct 2025 11:07:45 +0000
Subject: [PATCH 4/4] use `from_pretrained` for `transformer` and `transformer_2`

Signed-off-by: Liu, Kaixuan
---
 tests/pipelines/wan/test_wan_22.py | 42 +++++++++---------------------
 1 file changed, 13 insertions(+), 29 deletions(-)

diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py
index 81337e753dbc..adcc3f752531 100644
--- a/tests/pipelines/wan/test_wan_22.py
+++ b/tests/pipelines/wan/test_wan_22.py
@@ -22,6 +22,7 @@ from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel
 
 from ...testing_utils import (
+    require_accelerator,
     enable_full_determinism,
     torch_device,
 )
@@ -63,40 +64,22 @@ def get_dummy_components(self, dtype=torch.float32):
 
         torch.manual_seed(0)
         scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", torch_dtype=dtype)
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
 
         torch.manual_seed(0)
-        transformer = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
-        ).to(dtype=dtype)
+        # Use from_pretrained with a tiny model to ensure proper dtype handling
+        # This ensures _keep_in_fp32_modules and _skip_layerwise_casting_patterns are respected
+        transformer = WanTransformer3DModel.from_pretrained(
+            "Kaixuanliu/tiny-random-wan-transformer",
+            torch_dtype=dtype
+        )
 
         torch.manual_seed(0)
-        transformer_2 = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
-        ).to(dtype=dtype)
+        transformer_2 = WanTransformer3DModel.from_pretrained(
"Kaixuanliu/tiny-random-wan-transformer", + torch_dtype=dtype + ) components = { "transformer": transformer, @@ -156,6 +139,7 @@ def test_attention_slicing_forward_pass(self): pass @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") + @require_accelerator def test_save_load_float16(self, expected_max_diff=1e-2): # Use get_dummy_components with dtype parameter instead of converting components components = self.get_dummy_components(dtype=torch.float16)