From 1d5038280d44a36351cb3aa21aad7a8eff220c94 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 20 Dec 2022 22:07:06 +0100
Subject: Fix training

---
 train_dreambooth.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3eecf9c..0f8fece 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -23,7 +23,6 @@ from slugify import slugify
 
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -614,8 +613,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -964,6 +963,7 @@ def main():
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
@@ -1065,6 +1065,7 @@ def main():
             timesteps = timesteps.long()
 
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
             encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
-- 
cgit v1.2.3-54-g00ecf
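
Note on the patch: the first hunk pair drops a local pipelines.util helper in favor of the set_use_memory_efficient_attention_xformers method that diffusers exposes directly on its models, and the remaining hunks cast the noised latents to the UNet's dtype before the forward pass. The snippet below is a minimal, self-contained sketch of the dtype problem that cast addresses, assuming a diffusers DDPMScheduler and a UNet held in fp16 (as with mixed-precision training); the unet_dtype variable is a stand-in for the real unet.dtype, which is not shown in the patch.

    import torch
    from diffusers import DDPMScheduler

    # Assumption for illustration: the UNet runs in fp16 (mixed precision),
    # while the VAE produces fp32 latents.
    unet_dtype = torch.float16

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    latents = torch.randn(1, 4, 64, 64)          # fp32 latents from the VAE
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,)).long()

    # add_noise keeps the dtype of its sample inputs, so the result is fp32 ...
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    assert noisy_latents.dtype == torch.float32

    # ... which is why the patch casts to the UNet's dtype before calling
    # unet(...); without it, fp16 UNet layers would receive fp32 inputs and
    # typically fail with a dtype mismatch.
    noisy_latents = noisy_latents.to(dtype=unet_dtype)
    assert noisy_latents.dtype == torch.float16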