summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py4
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py16
-rw-r--r--pipelines/util.py9
-rw-r--r--textual_inversion.py4
4 files changed, 15 insertions, 18 deletions
diff --git a/dreambooth.py b/dreambooth.py
index f3f722e..b87763e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from pipelines.util import set_use_memory_efficient_attention_xformers
27from data.csv import CSVDataModule 28from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule 29from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor 30from models.clip.prompt import PromptProcessor
@@ -594,7 +595,8 @@ def main():
594 args.pretrained_model_name_or_path, subfolder='scheduler') 595 args.pretrained_model_name_or_path, subfolder='scheduler')
595 596
596 vae.enable_slicing() 597 vae.enable_slicing()
597 unet.set_use_memory_efficient_attention_xformers(True) 598 set_use_memory_efficient_attention_xformers(unet, True)
599 set_use_memory_efficient_attention_xformers(vae, True)
598 600
599 if args.gradient_checkpointing: 601 if args.gradient_checkpointing:
600 unet.enable_gradient_checkpointing() 602 unet.enable_gradient_checkpointing()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 9b51763..f80e951 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
78 scheduler=scheduler, 78 scheduler=scheduler,
79 ) 79 )
80 80
81 def enable_xformers_memory_efficient_attention(self):
82 r"""
83 Enable memory efficient attention as implemented in xformers.
84 When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
85 time. Speed up at training time is not guaranteed.
86 Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
87 is used.
88 """
89 self.unet.set_use_memory_efficient_attention_xformers(True)
90
91 def disable_xformers_memory_efficient_attention(self):
92 r"""
93 Disable memory efficient attention as implemented in xformers.
94 """
95 self.unet.set_use_memory_efficient_attention_xformers(False)
96
97 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 81 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
98 r""" 82 r"""
99 Enable sliced attention computation. 83 Enable sliced attention computation.
diff --git a/pipelines/util.py b/pipelines/util.py
new file mode 100644
index 0000000..661dbee
--- /dev/null
+++ b/pipelines/util.py
@@ -0,0 +1,9 @@
1import torch
2
3
4def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
5 if hasattr(module, "set_use_memory_efficient_attention_xformers"):
6 module.set_use_memory_efficient_attention_xformers(valid)
7
8 for child in module.children():
9 set_use_memory_efficient_attention_xformers(child, valid)
diff --git a/textual_inversion.py b/textual_inversion.py
index cd2d22b..1a5a8d0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24 24
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from pipelines.util import set_use_memory_efficient_attention_xformers
26from data.csv import CSVDataModule 27from data.csv import CSVDataModule
27from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
28from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
@@ -546,7 +547,8 @@ def main():
546 args.pretrained_model_name_or_path, subfolder='scheduler') 547 args.pretrained_model_name_or_path, subfolder='scheduler')
547 548
548 vae.enable_slicing() 549 vae.enable_slicing()
549 unet.set_use_memory_efficient_attention_xformers(True) 550 set_use_memory_efficient_attention_xformers(unet, True)
551 set_use_memory_efficient_attention_xformers(vae, True)
550 552
551 if args.gradient_checkpointing: 553 if args.gradient_checkpointing:
552 text_encoder.gradient_checkpointing_enable() 554 text_encoder.gradient_checkpointing_enable()