diff options
-rw-r--r-- | dreambooth.py | 4 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 | ||||
-rw-r--r-- | pipelines/util.py | 9 | ||||
-rw-r--r-- | textual_inversion.py | 4 |
4 files changed, 15 insertions, 18 deletions
diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
27 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
@@ -594,7 +595,8 @@ def main(): | |||
594 | args.pretrained_model_name_or_path, subfolder='scheduler') | 595 | args.pretrained_model_name_or_path, subfolder='scheduler') |
595 | 596 | ||
596 | vae.enable_slicing() | 597 | vae.enable_slicing() |
597 | unet.set_use_memory_efficient_attention_xformers(True) | 598 | set_use_memory_efficient_attention_xformers(unet, True) |
599 | set_use_memory_efficient_attention_xformers(vae, True) | ||
598 | 600 | ||
599 | if args.gradient_checkpointing: | 601 | if args.gradient_checkpointing: |
600 | unet.enable_gradient_checkpointing() | 602 | unet.enable_gradient_checkpointing() |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 9b51763..f80e951 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
78 | scheduler=scheduler, | 78 | scheduler=scheduler, |
79 | ) | 79 | ) |
80 | 80 | ||
81 | def enable_xformers_memory_efficient_attention(self): | ||
82 | r""" | ||
83 | Enable memory efficient attention as implemented in xformers. | ||
84 | When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference | ||
85 | time. Speed up at training time is not guaranteed. | ||
86 | Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention | ||
87 | is used. | ||
88 | """ | ||
89 | self.unet.set_use_memory_efficient_attention_xformers(True) | ||
90 | |||
91 | def disable_xformers_memory_efficient_attention(self): | ||
92 | r""" | ||
93 | Disable memory efficient attention as implemented in xformers. | ||
94 | """ | ||
95 | self.unet.set_use_memory_efficient_attention_xformers(False) | ||
96 | |||
97 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 81 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
98 | r""" | 82 | r""" |
99 | Enable sliced attention computation. | 83 | Enable sliced attention computation. |
diff --git a/pipelines/util.py b/pipelines/util.py new file mode 100644 index 0000000..661dbee --- /dev/null +++ b/pipelines/util.py | |||
@@ -0,0 +1,9 @@ | |||
1 | import torch | ||
2 | |||
3 | |||
4 | def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None: | ||
5 | if hasattr(module, "set_use_memory_efficient_attention_xformers"): | ||
6 | module.set_use_memory_efficient_attention_xformers(valid) | ||
7 | |||
8 | for child in module.children(): | ||
9 | set_use_memory_efficient_attention_xformers(child, valid) | ||
diff --git a/textual_inversion.py b/textual_inversion.py index cd2d22b..1a5a8d0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
23 | from slugify import slugify | 23 | from slugify import slugify |
24 | 24 | ||
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
26 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
27 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
@@ -546,7 +547,8 @@ def main(): | |||
546 | args.pretrained_model_name_or_path, subfolder='scheduler') | 547 | args.pretrained_model_name_or_path, subfolder='scheduler') |
547 | 548 | ||
548 | vae.enable_slicing() | 549 | vae.enable_slicing() |
549 | unet.set_use_memory_efficient_attention_xformers(True) | 550 | set_use_memory_efficient_attention_xformers(unet, True) |
551 | set_use_memory_efficient_attention_xformers(vae, True) | ||
550 | 552 | ||
551 | if args.gradient_checkpointing: | 553 | if args.gradient_checkpointing: |
552 | text_encoder.gradient_checkpointing_enable() | 554 | text_encoder.gradient_checkpointing_enable() |