summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-04 09:24:37 +0100
committerVolpeon <git@volpeon.ink>2022-12-04 09:24:37 +0100
commit2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch)
tree180afc9ae42c7d29b7229423d4db53183e6ad49b /textual_inversion.py
parentUpdate (diff)
downloadtextual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz
textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2
textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index cd2d22b..1a5a8d0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24 24
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from pipelines.util import set_use_memory_efficient_attention_xformers
26from data.csv import CSVDataModule 27from data.csv import CSVDataModule
27from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
28from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
@@ -546,7 +547,8 @@ def main():
546 args.pretrained_model_name_or_path, subfolder='scheduler') 547 args.pretrained_model_name_or_path, subfolder='scheduler')
547 548
548 vae.enable_slicing() 549 vae.enable_slicing()
549 unet.set_use_memory_efficient_attention_xformers(True) 550 set_use_memory_efficient_attention_xformers(unet, True)
551 set_use_memory_efficient_attention_xformers(vae, True)
550 552
551 if args.gradient_checkpointing: 553 if args.gradient_checkpointing:
552 text_encoder.gradient_checkpointing_enable() 554 text_encoder.gradient_checkpointing_enable()