diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
| commit | 2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch) | |
| tree | 180afc9ae42c7d29b7229423d4db53183e6ad49b /dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2 textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip | |
Update
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
| 24 | from slugify import slugify | 24 | from slugify import slugify |
| 25 | 25 | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
| 27 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
| 28 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
| 29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
| @@ -594,7 +595,8 @@ def main(): | |||
| 594 | args.pretrained_model_name_or_path, subfolder='scheduler') | 595 | args.pretrained_model_name_or_path, subfolder='scheduler') |
| 595 | 596 | ||
| 596 | vae.enable_slicing() | 597 | vae.enable_slicing() |
| 597 | unet.set_use_memory_efficient_attention_xformers(True) | 598 | set_use_memory_efficient_attention_xformers(unet, True) |
| 599 | set_use_memory_efficient_attention_xformers(vae, True) | ||
| 598 | 600 | ||
| 599 | if args.gradient_checkpointing: | 601 | if args.gradient_checkpointing: |
| 600 | unet.enable_gradient_checkpointing() | 602 | unet.enable_gradient_checkpointing() |
