From 4b9edcb3caad34287289762dc28bf30788333d18 Mon Sep 17 00:00:00 2001
From: Kaan Baris BAYRAK
Date: Fri, 16 Dec 2022 22:29:55 +0100
Subject: [PATCH 1/3] xformers added into modeling utils

---
 src/diffusers/modeling_utils.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 704ba00cad29..df394425b815 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -178,6 +178,38 @@ def disable_gradient_checkpointing(self):
         """
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)

     def save_pretrained(
         self,
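For illustration (not part of the diff): the setter added above simply walks the nn.Module tree and forwards the flag to every submodule that exposes a method of the same name. Below is a minimal, self-contained sketch of that propagation pattern; the ToyAttention and ToyModel classes are made-up stand-ins, not real diffusers blocks.

import torch


class ToyAttention(torch.nn.Module):
    # Illustrative stand-in for an attention block that understands the flag.
    def __init__(self):
        super().__init__()
        self.use_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self.use_xformers = valid


class ToyModel(torch.nn.Module):
    # Container module; the flag still reaches the nested attention blocks.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([ToyAttention(), ToyAttention()])

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        # Same recursive walk as the ModelMixin method in the patch above.
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)
            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            fn_recursive_set_mem_eff(module)


model = ToyModel()
model.set_use_memory_efficient_attention_xformers(True)
print([block.use_xformers for block in model.blocks])  # -> [True, True]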
From 5b30a65360dcd7a2941cef522acda878185faa92 Mon Sep 17 00:00:00 2001
From: Kaan Baris BAYRAK
Date: Fri, 16 Dec 2022 22:31:56 +0100
Subject: [PATCH 2/3] set_use added into pipeline utils

---
 src/diffusers/pipeline_utils.py | 34 +++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 7e759e987fb4..052d2643db9d 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -710,3 +710,37 @@ def progress_bar(self, iterable):

     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
+
+
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+        # Recursively walk through all the children.
+        # Any child that exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message.
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)

From 153791c85ad542a2af15ca53435d96702b0f126b Mon Sep 17 00:00:00 2001
From: Kaan Baris BAYRAK
Date: Fri, 16 Dec 2022 22:40:56 +0100
Subject: [PATCH 3/3] xformers added into train_dreambooth.py

---
 examples/dreambooth/train_dreambooth.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index d77fb5b2871f..458e52766c82 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -17,6 +17,7 @@
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
@@ -522,6 +523,15 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

+    if is_xformers_available():
+        try:
+            unet.enable_xformers_memory_efficient_attention()
+        except Exception as e:
+            logger.warning(
+                "Could not enable memory efficient attention. Make sure xformers is installed"
+                f" correctly and a GPU is available: {e}"
+            )
+
     vae.requires_grad_(False)
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
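For illustration (not part of the patch series): with all three patches applied, the same opt-in becomes available on any full pipeline, not just in the DreamBooth script. A rough usage sketch, assuming xformers and a CUDA GPU are available; the checkpoint id runwayml/stable-diffusion-v1-5 is only an example.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

# Example checkpoint; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

if is_xformers_available():
    try:
        # Walks every nn.Module component of the pipeline and flips the flag
        # on those that support it (e.g. the UNet).
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print(f"Could not enable memory efficient attention: {e}")

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")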