From 2825ba2f2030b2fd3e841aad416a4fd28d67615a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 4 Dec 2022 09:24:37 +0100 Subject: Update --- dreambooth.py | 4 +++- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 ---------------- pipelines/util.py | 9 +++++++++ textual_inversion.py | 4 +++- 4 files changed, 15 insertions(+), 18 deletions(-) create mode 100644 pipelines/util.py diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from pipelines.util import set_use_memory_efficient_attention_xformers from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule from models.clip.prompt import PromptProcessor @@ -594,7 +595,8 @@ def main(): args.pretrained_model_name_or_path, subfolder='scheduler') vae.enable_slicing() - unet.set_use_memory_efficient_attention_xformers(True) + set_use_memory_efficient_attention_xformers(unet, True) + set_use_memory_efficient_attention_xformers(vae, True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 9b51763..f80e951 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler=scheduler, ) - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.unet.set_use_memory_efficient_attention_xformers(True) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.unet.set_use_memory_efficient_attention_xformers(False) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. diff --git a/pipelines/util.py b/pipelines/util.py new file mode 100644 index 0000000..661dbee --- /dev/null +++ b/pipelines/util.py @@ -0,0 +1,9 @@ +import torch + + +def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None: + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + set_use_memory_efficient_attention_xformers(child, valid) diff --git a/textual_inversion.py b/textual_inversion.py index cd2d22b..1a5a8d0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from pipelines.util import set_use_memory_efficient_attention_xformers from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule from models.clip.prompt import PromptProcessor @@ -546,7 +547,8 @@ def main(): args.pretrained_model_name_or_path, subfolder='scheduler') vae.enable_slicing() - unet.set_use_memory_efficient_attention_xformers(True) + set_use_memory_efficient_attention_xformers(unet, True) + set_use_memory_efficient_attention_xformers(vae, True) if args.gradient_checkpointing: text_encoder.gradient_checkpointing_enable() -- cgit v1.2.3-54-g00ecf