diff options
author | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
commit | 2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch) | |
tree | 180afc9ae42c7d29b7229423d4db53183e6ad49b /dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2 textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip |
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
27 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
@@ -594,7 +595,8 @@ def main(): | |||
594 | args.pretrained_model_name_or_path, subfolder='scheduler') | 595 | args.pretrained_model_name_or_path, subfolder='scheduler') |
595 | 596 | ||
596 | vae.enable_slicing() | 597 | vae.enable_slicing() |
597 | unet.set_use_memory_efficient_attention_xformers(True) | 598 | set_use_memory_efficient_attention_xformers(unet, True) |
599 | set_use_memory_efficient_attention_xformers(vae, True) | ||
598 | 600 | ||
599 | if args.gradient_checkpointing: | 601 | if args.gradient_checkpointing: |
600 | unet.enable_gradient_checkpointing() | 602 | unet.enable_gradient_checkpointing() |