diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
| commit | 1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch) | |
| tree | 55ae4c0a660f5218c072d33f2896024b47c05c6b /train_dreambooth.py | |
| parent | Dependency cleanup/upgrades (diff) | |
| download | textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.gz textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.bz2 textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.zip | |
Fix training
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 3eecf9c..0f8fece 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -23,7 +23,6 @@ from slugify import slugify | |||
| 23 | 23 | ||
| 24 | from common import load_text_embeddings | 24 | from common import load_text_embeddings |
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 26 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
| 27 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
| 28 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
| 29 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
| @@ -614,8 +613,8 @@ def main(): | |||
| 614 | args.pretrained_model_name_or_path, subfolder='scheduler') | 613 | args.pretrained_model_name_or_path, subfolder='scheduler') |
| 615 | 614 | ||
| 616 | vae.enable_slicing() | 615 | vae.enable_slicing() |
| 617 | set_use_memory_efficient_attention_xformers(unet, True) | 616 | vae.set_use_memory_efficient_attention_xformers(True) |
| 618 | set_use_memory_efficient_attention_xformers(vae, True) | 617 | unet.set_use_memory_efficient_attention_xformers(True) |
| 619 | 618 | ||
| 620 | if args.gradient_checkpointing: | 619 | if args.gradient_checkpointing: |
| 621 | unet.enable_gradient_checkpointing() | 620 | unet.enable_gradient_checkpointing() |
| @@ -964,6 +963,7 @@ def main(): | |||
| 964 | # Add noise to the latents according to the noise magnitude at each timestep | 963 | # Add noise to the latents according to the noise magnitude at each timestep |
| 965 | # (this is the forward diffusion process) | 964 | # (this is the forward diffusion process) |
| 966 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 965 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 966 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 967 | 967 | ||
| 968 | # Get the text embedding for conditioning | 968 | # Get the text embedding for conditioning |
| 969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
| @@ -1065,6 +1065,7 @@ def main(): | |||
| 1065 | timesteps = timesteps.long() | 1065 | timesteps = timesteps.long() |
| 1066 | 1066 | ||
| 1067 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 1067 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 1068 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 1068 | 1069 | ||
| 1069 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 1070 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) |
| 1070 | 1071 | ||
