author     Volpeon <git@volpeon.ink>    2022-12-20 22:07:06 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-20 22:07:06 +0100
commit     1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch)
tree       55ae4c0a660f5218c072d33f2896024b47c05c6b /train_dreambooth.py
parent     Dependency cleanup/upgrades (diff)
Fix training
Diffstat (limited to 'train_dreambooth.py')

-rw-r--r--  train_dreambooth.py  7

1 file changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3eecf9c..0f8fece 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -23,7 +23,6 @@ from slugify import slugify
 
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -614,8 +613,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
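Note: this hunk drops the project-local helper from pipelines.util in favor of the method diffusers exposes directly on its models via ModelMixin, which is what made the import removal in the first hunk possible. A minimal sketch of the new call pattern, using randomly initialized models with default configs rather than this script's pretrained checkpoints:

```python
from diffusers import AutoencoderKL, UNet2DConditionModel

# Randomly initialized models, just to demonstrate the call pattern;
# train_dreambooth.py loads these from a pretrained checkpoint instead.
vae = AutoencoderKL()
unet = UNet2DConditionModel()

# Decode the latents one slice at a time to reduce peak VRAM usage.
vae.enable_slicing()

# ModelMixin provides this method on the models themselves; it raises
# if the xformers package is not installed.
vae.set_use_memory_efficient_attention_xformers(True)
unet.set_use_memory_efficient_attention_xformers(True)
```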
@@ -964,6 +963,7 @@ def main():
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
@@ -1065,6 +1065,7 @@ def main():
                     timesteps = timesteps.long()
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
