summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-20 22:07:06 +0100
committerVolpeon <git@volpeon.ink>2022-12-20 22:07:06 +0100
commit1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch)
tree55ae4c0a660f5218c072d33f2896024b47c05c6b /train_dreambooth.py
parentDependency cleanup/upgrades (diff)
downloadtextual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.gz
textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.bz2
textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.zip
Fix training
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3eecf9c..0f8fece 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -23,7 +23,6 @@ from slugify import slugify
23 23
24from common import load_text_embeddings 24from common import load_text_embeddings
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from pipelines.util import set_use_memory_efficient_attention_xformers
27from data.csv import CSVDataModule 26from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor 28from models.clip.prompt import PromptProcessor
@@ -614,8 +613,8 @@ def main():
614 args.pretrained_model_name_or_path, subfolder='scheduler') 613 args.pretrained_model_name_or_path, subfolder='scheduler')
615 614
616 vae.enable_slicing() 615 vae.enable_slicing()
617 set_use_memory_efficient_attention_xformers(unet, True) 616 vae.set_use_memory_efficient_attention_xformers(True)
618 set_use_memory_efficient_attention_xformers(vae, True) 617 unet.set_use_memory_efficient_attention_xformers(True)
619 618
620 if args.gradient_checkpointing: 619 if args.gradient_checkpointing:
621 unet.enable_gradient_checkpointing() 620 unet.enable_gradient_checkpointing()
@@ -964,6 +963,7 @@ def main():
964 # Add noise to the latents according to the noise magnitude at each timestep 963 # Add noise to the latents according to the noise magnitude at each timestep
965 # (this is the forward diffusion process) 964 # (this is the forward diffusion process)
966 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 965 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
966 noisy_latents = noisy_latents.to(dtype=unet.dtype)
967 967
968 # Get the text embedding for conditioning 968 # Get the text embedding for conditioning
969 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) 969 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
@@ -1065,6 +1065,7 @@ def main():
1065 timesteps = timesteps.long() 1065 timesteps = timesteps.long()
1066 1066
1067 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1067 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1068 noisy_latents = noisy_latents.to(dtype=unet.dtype)
1068 1069
1069 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) 1070 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
1070 1071