summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-04 09:24:37 +0100
committerVolpeon <git@volpeon.ink>2022-12-04 09:24:37 +0100
commit2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch)
tree180afc9ae42c7d29b7229423d4db53183e6ad49b /dreambooth.py
parentUpdate (diff)
downloadtextual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz
textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2
textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/dreambooth.py b/dreambooth.py
index f3f722e..b87763e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from pipelines.util import set_use_memory_efficient_attention_xformers
27from data.csv import CSVDataModule 28from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule 29from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor 30from models.clip.prompt import PromptProcessor
@@ -594,7 +595,8 @@ def main():
594 args.pretrained_model_name_or_path, subfolder='scheduler') 595 args.pretrained_model_name_or_path, subfolder='scheduler')
595 596
596 vae.enable_slicing() 597 vae.enable_slicing()
597 unet.set_use_memory_efficient_attention_xformers(True) 598 set_use_memory_efficient_attention_xformers(unet, True)
599 set_use_memory_efficient_attention_xformers(vae, True)
598 600
599 if args.gradient_checkpointing: 601 if args.gradient_checkpointing:
600 unet.enable_gradient_checkpointing() 602 unet.enable_gradient_checkpointing()