 dreambooth.py                                        |  4 +++-
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 16 ----------------
 pipelines/util.py                                    |  9 +++++++++
 textual_inversion.py                                 |  4 +++-
 4 files changed, 15 insertions(+), 18 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index f3f722e..b87763e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -594,7 +595,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    unet.set_use_memory_efficient_attention_xformers(True)
+    set_use_memory_efficient_attention_xformers(unet, True)
+    set_use_memory_efficient_attention_xformers(vae, True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
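The training script previously enabled xformers attention only through the method on the unet itself; the vae was left on the default implementation. Applying the new helper to both models walks their module trees and flips the switch on every submodule that exposes it. A minimal sketch of the resulting setup, assuming the models are loaded as in the script (the `model_path` value here is a placeholder for `args.pretrained_model_name_or_path`):

```python
# Sketch only: model_path stands in for args.pretrained_model_name_or_path.
from diffusers import AutoencoderKL, UNet2DConditionModel

from pipelines.util import set_use_memory_efficient_attention_xformers

model_path = "runwayml/stable-diffusion-v1-5"  # assumed example checkpoint
unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")

# Recursively enable xformers attention on every submodule of both models.
set_use_memory_efficient_attention_xformers(unet, True)
set_use_memory_efficient_attention_xformers(vae, True)
```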
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 9b51763..f80e951 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -78,22 +78,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def enable_xformers_memory_efficient_attention(self):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(True)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.unet.set_use_memory_efficient_attention_xformers(False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
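With the shared helper available, the pipeline's enable/disable convenience wrappers, which only ever toggled `self.unet`, become redundant. A hedged sketch of the caller-side replacement, assuming `pipeline` is an already-constructed `VlpnStableDiffusion` instance:

```python
# Assumed: `pipeline` is an already-constructed VlpnStableDiffusion instance.
from pipelines.util import set_use_memory_efficient_attention_xformers

# Replaces pipeline.enable_xformers_memory_efficient_attention(), and now
# covers the VAE as well as the unet; pass False to disable again.
set_use_memory_efficient_attention_xformers(pipeline.unet, True)
set_use_memory_efficient_attention_xformers(pipeline.vae, True)
```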
diff --git a/pipelines/util.py b/pipelines/util.py
new file mode 100644
index 0000000..661dbee
--- /dev/null
+++ b/pipelines/util.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
+    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+        module.set_use_memory_efficient_attention_xformers(valid)
+
+    for child in module.children():
+        set_use_memory_efficient_attention_xformers(child, valid)
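The helper walks the module tree via `Module.children()`, invoking the hook on every submodule that defines it, so attention blocks nested at any depth are toggled. A self-contained toy sketch (the `Block` and `Model` classes here are hypothetical, purely to illustrate the traversal):

```python
# Toy demonstration (hypothetical Block/Model classes) that the helper reaches
# nested submodules rather than only the object it is called on.
import torch

from pipelines.util import set_use_memory_efficient_attention_xformers


class Block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.uses_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self.uses_xformers = valid


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = torch.nn.Sequential(Block(), Block())


model = Model()
set_use_memory_efficient_attention_xformers(model, True)
print(all(block.uses_xformers for block in model.blocks))  # True
```

Running the sketch prints `True`: both nested blocks have their flag flipped even though neither the model nor the `Sequential` container defines the hook itself.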
diff --git a/textual_inversion.py b/textual_inversion.py
index cd2d22b..1a5a8d0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -546,7 +547,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    unet.set_use_memory_efficient_attention_xformers(True)
+    set_use_memory_efficient_attention_xformers(unet, True)
+    set_use_memory_efficient_attention_xformers(vae, True)
 
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
