diff options
author | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
commit | 2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch) | |
tree | 180afc9ae42c7d29b7229423d4db53183e6ad49b /textual_inversion.py | |
parent | Update (diff) | |
download | textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2 textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip |
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index cd2d22b..1a5a8d0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
23 | from slugify import slugify | 23 | from slugify import slugify |
24 | 24 | ||
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
26 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
27 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
@@ -546,7 +547,8 @@ def main(): | |||
546 | args.pretrained_model_name_or_path, subfolder='scheduler') | 547 | args.pretrained_model_name_or_path, subfolder='scheduler') |
547 | 548 | ||
548 | vae.enable_slicing() | 549 | vae.enable_slicing() |
549 | unet.set_use_memory_efficient_attention_xformers(True) | 550 | set_use_memory_efficient_attention_xformers(unet, True) |
551 | set_use_memory_efficient_attention_xformers(vae, True) | ||
550 | 552 | ||
551 | if args.gradient_checkpointing: | 553 | if args.gradient_checkpointing: |
552 | text_encoder.gradient_checkpointing_enable() | 554 | text_encoder.gradient_checkpointing_enable() |