diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 17 |
1 file changed, 10 insertions, 7 deletions
diff --git a/dreambooth.py b/dreambooth.py index 770ad38..9786e0f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -25,6 +25,7 @@ from slugify import slugify | |||
25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
28 | from models.clip.prompt import PromptProcessor | ||
28 | 29 | ||
29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
30 | 31 | ||
@@ -141,7 +142,7 @@ def parse_args(): | |||
141 | parser.add_argument( | 142 | parser.add_argument( |
142 | "--lr_scheduler", | 143 | "--lr_scheduler", |
143 | type=str, | 144 | type=str, |
144 | default="cosine", | 145 | default="cosine_with_restarts", |
145 | help=( | 146 | help=( |
146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 147 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
147 | ' "constant", "constant_with_warmup"]' | 148 | ' "constant", "constant_with_warmup"]' |
@@ -494,6 +495,8 @@ def main(): | |||
494 | device=accelerator.device | 495 | device=accelerator.device |
495 | ) if args.use_ema else None | 496 | ) if args.use_ema else None |
496 | 497 | ||
498 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
499 | |||
497 | if args.gradient_checkpointing: | 500 | if args.gradient_checkpointing: |
498 | unet.enable_gradient_checkpointing() | 501 | unet.enable_gradient_checkpointing() |
499 | 502 | ||
@@ -557,7 +560,7 @@ def main(): | |||
557 | pixel_values = torch.stack(pixel_values) | 560 | pixel_values = torch.stack(pixel_values) |
558 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 561 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
559 | 562 | ||
560 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 563 | input_ids = prompt_processor.unify_input_ids(input_ids) |
561 | 564 | ||
562 | batch = { | 565 | batch = { |
563 | "prompts": prompts, | 566 | "prompts": prompts, |
@@ -570,7 +573,7 @@ def main(): | |||
570 | datamodule = CSVDataModule( | 573 | datamodule = CSVDataModule( |
571 | data_file=args.train_data_file, | 574 | data_file=args.train_data_file, |
572 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
573 | tokenizer=tokenizer, | 576 | prompt_processor=prompt_processor, |
574 | instance_identifier=args.instance_identifier, | 577 | instance_identifier=args.instance_identifier, |
575 | class_identifier=args.class_identifier, | 578 | class_identifier=args.class_identifier, |
576 | class_subdir="cls", | 579 | class_subdir="cls", |
@@ -641,8 +644,8 @@ def main(): | |||
641 | optimizer=optimizer, | 644 | optimizer=optimizer, |
642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 645 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 646 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
644 | num_cycles=args.lr_cycles or math.ceil( | 647 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
645 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 648 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
646 | ) | 649 | ) |
647 | else: | 650 | else: |
648 | lr_scheduler = get_scheduler( | 651 | lr_scheduler = get_scheduler( |
@@ -756,7 +759,7 @@ def main(): | |||
756 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 759 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
757 | 760 | ||
758 | # Get the text embedding for conditioning | 761 | # Get the text embedding for conditioning |
759 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 762 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
760 | 763 | ||
761 | # Predict the noise residual | 764 | # Predict the noise residual |
762 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 765 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -832,7 +835,7 @@ def main(): | |||
832 | 835 | ||
833 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 836 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
834 | 837 | ||
835 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 838 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
836 | 839 | ||
837 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 840 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
838 | 841 | ||