diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-18 18:08:32 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-18 18:08:32 +0200 | 
| commit | 2ddfbd65e482fa2361e8ba41b657656f825c9143 (patch) | |
| tree | 41cc82e23d82dd620c81f2715a50969b832e9bda /dreambooth.py | |
| parent | Improved prompt handling (diff) | |
| download | textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.gz textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.bz2 textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.zip | |
Adapted other scripts for new prompt processing
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 17 | 
1 files changed, 10 insertions, 7 deletions
| diff --git a/dreambooth.py b/dreambooth.py index 770ad38..9786e0f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -25,6 +25,7 @@ from slugify import slugify | |||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler | 
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule | 
| 28 | from models.clip.prompt import PromptProcessor | ||
| 28 | 29 | ||
| 29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) | 
| 30 | 31 | ||
| @@ -141,7 +142,7 @@ def parse_args(): | |||
| 141 | parser.add_argument( | 142 | parser.add_argument( | 
| 142 | "--lr_scheduler", | 143 | "--lr_scheduler", | 
| 143 | type=str, | 144 | type=str, | 
| 144 | default="cosine", | 145 | default="cosine_with_restarts", | 
| 145 | help=( | 146 | help=( | 
| 146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 147 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 
| 147 | ' "constant", "constant_with_warmup"]' | 148 | ' "constant", "constant_with_warmup"]' | 
| @@ -494,6 +495,8 @@ def main(): | |||
| 494 | device=accelerator.device | 495 | device=accelerator.device | 
| 495 | ) if args.use_ema else None | 496 | ) if args.use_ema else None | 
| 496 | 497 | ||
| 498 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 499 | |||
| 497 | if args.gradient_checkpointing: | 500 | if args.gradient_checkpointing: | 
| 498 | unet.enable_gradient_checkpointing() | 501 | unet.enable_gradient_checkpointing() | 
| 499 | 502 | ||
| @@ -557,7 +560,7 @@ def main(): | |||
| 557 | pixel_values = torch.stack(pixel_values) | 560 | pixel_values = torch.stack(pixel_values) | 
| 558 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 561 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 
| 559 | 562 | ||
| 560 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 563 | input_ids = prompt_processor.unify_input_ids(input_ids) | 
| 561 | 564 | ||
| 562 | batch = { | 565 | batch = { | 
| 563 | "prompts": prompts, | 566 | "prompts": prompts, | 
| @@ -570,7 +573,7 @@ def main(): | |||
| 570 | datamodule = CSVDataModule( | 573 | datamodule = CSVDataModule( | 
| 571 | data_file=args.train_data_file, | 574 | data_file=args.train_data_file, | 
| 572 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, | 
| 573 | tokenizer=tokenizer, | 576 | prompt_processor=prompt_processor, | 
| 574 | instance_identifier=args.instance_identifier, | 577 | instance_identifier=args.instance_identifier, | 
| 575 | class_identifier=args.class_identifier, | 578 | class_identifier=args.class_identifier, | 
| 576 | class_subdir="cls", | 579 | class_subdir="cls", | 
| @@ -641,8 +644,8 @@ def main(): | |||
| 641 | optimizer=optimizer, | 644 | optimizer=optimizer, | 
| 642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 645 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 
| 643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 646 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 
| 644 | num_cycles=args.lr_cycles or math.ceil( | 647 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | 
| 645 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 648 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), | 
| 646 | ) | 649 | ) | 
| 647 | else: | 650 | else: | 
| 648 | lr_scheduler = get_scheduler( | 651 | lr_scheduler = get_scheduler( | 
| @@ -756,7 +759,7 @@ def main(): | |||
| 756 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 759 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 
| 757 | 760 | ||
| 758 | # Get the text embedding for conditioning | 761 | # Get the text embedding for conditioning | 
| 759 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 762 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 
| 760 | 763 | ||
| 761 | # Predict the noise residual | 764 | # Predict the noise residual | 
| 762 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 765 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 
| @@ -832,7 +835,7 @@ def main(): | |||
| 832 | 835 | ||
| 833 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 836 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 
| 834 | 837 | ||
| 835 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 838 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) | 
| 836 | 839 | ||
| 837 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 840 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 
| 838 | 841 | ||
