From 2825ba2f2030b2fd3e841aad416a4fd28d67615a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 4 Dec 2022 09:24:37 +0100 Subject: Update --- dreambooth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index f3f722e..b87763e 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from pipelines.util import set_use_memory_efficient_attention_xformers from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule from models.clip.prompt import PromptProcessor @@ -594,7 +595,8 @@ def main(): args.pretrained_model_name_or_path, subfolder='scheduler') vae.enable_slicing() - unet.set_use_memory_efficient_attention_xformers(True) + set_use_memory_efficient_attention_xformers(unet, True) + set_use_memory_efficient_attention_xformers(vae, True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() -- cgit v1.2.3-54-g00ecf