diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
27 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
@@ -594,7 +595,8 @@ def main(): | |||
594 | args.pretrained_model_name_or_path, subfolder='scheduler') | 595 | args.pretrained_model_name_or_path, subfolder='scheduler') |
595 | 596 | ||
596 | vae.enable_slicing() | 597 | vae.enable_slicing() |
597 | unet.set_use_memory_efficient_attention_xformers(True) | 598 | set_use_memory_efficient_attention_xformers(unet, True) |
599 | set_use_memory_efficient_attention_xformers(vae, True) | ||
598 | 600 | ||
599 | if args.gradient_checkpointing: | 601 | if args.gradient_checkpointing: |
600 | unet.enable_gradient_checkpointing() | 602 | unet.enable_gradient_checkpointing() |