diff options
author | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
commit | 1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch) | |
tree | 55ae4c0a660f5218c072d33f2896024b47c05c6b /train_dreambooth.py | |
parent | Dependency cleanup/upgrades (diff) | |
download | textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.gz textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.bz2 textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.zip |
Fix training
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 3eecf9c..0f8fece 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -23,7 +23,6 @@ from slugify import slugify | |||
23 | 23 | ||
24 | from common import load_text_embeddings | 24 | from common import load_text_embeddings |
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
27 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
28 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
@@ -614,8 +613,8 @@ def main(): | |||
614 | args.pretrained_model_name_or_path, subfolder='scheduler') | 613 | args.pretrained_model_name_or_path, subfolder='scheduler') |
615 | 614 | ||
616 | vae.enable_slicing() | 615 | vae.enable_slicing() |
617 | set_use_memory_efficient_attention_xformers(unet, True) | 616 | vae.set_use_memory_efficient_attention_xformers(True) |
618 | set_use_memory_efficient_attention_xformers(vae, True) | 617 | unet.set_use_memory_efficient_attention_xformers(True) |
619 | 618 | ||
620 | if args.gradient_checkpointing: | 619 | if args.gradient_checkpointing: |
621 | unet.enable_gradient_checkpointing() | 620 | unet.enable_gradient_checkpointing() |
@@ -964,6 +963,7 @@ def main(): | |||
964 | # Add noise to the latents according to the noise magnitude at each timestep | 963 | # Add noise to the latents according to the noise magnitude at each timestep |
965 | # (this is the forward diffusion process) | 964 | # (this is the forward diffusion process) |
966 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 965 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
966 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
967 | 967 | ||
968 | # Get the text embedding for conditioning | 968 | # Get the text embedding for conditioning |
969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
@@ -1065,6 +1065,7 @@ def main(): | |||
1065 | timesteps = timesteps.long() | 1065 | timesteps = timesteps.long() |
1066 | 1066 | ||
1067 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 1067 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
1068 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
1068 | 1069 | ||
1069 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 1070 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
1070 | 1071 | ||