diff options
| -rw-r--r-- | dreambooth.py | 17 | ||||
| -rw-r--r-- | dreambooth_plus.py | 6 | ||||
| -rw-r--r-- | textual_inversion.py | 17 |
3 files changed, 23 insertions, 17 deletions
diff --git a/dreambooth.py b/dreambooth.py index 770ad38..9786e0f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -25,6 +25,7 @@ from slugify import slugify | |||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 28 | from models.clip.prompt import PromptProcessor | ||
| 28 | 29 | ||
| 29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
| 30 | 31 | ||
| @@ -141,7 +142,7 @@ def parse_args(): | |||
| 141 | parser.add_argument( | 142 | parser.add_argument( |
| 142 | "--lr_scheduler", | 143 | "--lr_scheduler", |
| 143 | type=str, | 144 | type=str, |
| 144 | default="cosine", | 145 | default="cosine_with_restarts", |
| 145 | help=( | 146 | help=( |
| 146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 147 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 147 | ' "constant", "constant_with_warmup"]' | 148 | ' "constant", "constant_with_warmup"]' |
| @@ -494,6 +495,8 @@ def main(): | |||
| 494 | device=accelerator.device | 495 | device=accelerator.device |
| 495 | ) if args.use_ema else None | 496 | ) if args.use_ema else None |
| 496 | 497 | ||
| 498 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 499 | |||
| 497 | if args.gradient_checkpointing: | 500 | if args.gradient_checkpointing: |
| 498 | unet.enable_gradient_checkpointing() | 501 | unet.enable_gradient_checkpointing() |
| 499 | 502 | ||
| @@ -557,7 +560,7 @@ def main(): | |||
| 557 | pixel_values = torch.stack(pixel_values) | 560 | pixel_values = torch.stack(pixel_values) |
| 558 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 561 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
| 559 | 562 | ||
| 560 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 563 | input_ids = prompt_processor.unify_input_ids(input_ids) |
| 561 | 564 | ||
| 562 | batch = { | 565 | batch = { |
| 563 | "prompts": prompts, | 566 | "prompts": prompts, |
| @@ -570,7 +573,7 @@ def main(): | |||
| 570 | datamodule = CSVDataModule( | 573 | datamodule = CSVDataModule( |
| 571 | data_file=args.train_data_file, | 574 | data_file=args.train_data_file, |
| 572 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
| 573 | tokenizer=tokenizer, | 576 | prompt_processor=prompt_processor, |
| 574 | instance_identifier=args.instance_identifier, | 577 | instance_identifier=args.instance_identifier, |
| 575 | class_identifier=args.class_identifier, | 578 | class_identifier=args.class_identifier, |
| 576 | class_subdir="cls", | 579 | class_subdir="cls", |
| @@ -641,8 +644,8 @@ def main(): | |||
| 641 | optimizer=optimizer, | 644 | optimizer=optimizer, |
| 642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 645 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| 643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 646 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 644 | num_cycles=args.lr_cycles or math.ceil( | 647 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 645 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 648 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
| 646 | ) | 649 | ) |
| 647 | else: | 650 | else: |
| 648 | lr_scheduler = get_scheduler( | 651 | lr_scheduler = get_scheduler( |
| @@ -756,7 +759,7 @@ def main(): | |||
| 756 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 759 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 757 | 760 | ||
| 758 | # Get the text embedding for conditioning | 761 | # Get the text embedding for conditioning |
| 759 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 762 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 760 | 763 | ||
| 761 | # Predict the noise residual | 764 | # Predict the noise residual |
| 762 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 765 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| @@ -832,7 +835,7 @@ def main(): | |||
| 832 | 835 | ||
| 833 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 836 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 834 | 837 | ||
| 835 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 838 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 836 | 839 | ||
| 837 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 840 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 838 | 841 | ||
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index fa3a22b..06ff45b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
| @@ -125,7 +125,7 @@ def parse_args(): | |||
| 125 | parser.add_argument( | 125 | parser.add_argument( |
| 126 | "--max_train_steps", | 126 | "--max_train_steps", |
| 127 | type=int, | 127 | type=int, |
| 128 | default=1400, | 128 | default=2400, |
| 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 130 | ) | 130 | ) |
| 131 | parser.add_argument( | 131 | parser.add_argument( |
| @@ -752,8 +752,8 @@ def main(): | |||
| 752 | optimizer=optimizer, | 752 | optimizer=optimizer, |
| 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 755 | num_cycles=args.lr_cycles or math.ceil( | 755 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
| 757 | ) | 757 | ) |
| 758 | else: | 758 | else: |
| 759 | lr_scheduler = get_scheduler( | 759 | lr_scheduler = get_scheduler( |
diff --git a/textual_inversion.py b/textual_inversion.py index 69d9c7f..8f266e0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -25,6 +25,7 @@ from slugify import slugify | |||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 28 | from models.clip.prompt import PromptProcessor | ||
| 28 | 29 | ||
| 29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
| 30 | 31 | ||
| @@ -152,7 +153,7 @@ def parse_args(): | |||
| 152 | parser.add_argument( | 153 | parser.add_argument( |
| 153 | "--lr_scheduler", | 154 | "--lr_scheduler", |
| 154 | type=str, | 155 | type=str, |
| 155 | default="cosine", | 156 | default="cosine_with_restarts", |
| 156 | help=( | 157 | help=( |
| 157 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 158 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 158 | ' "constant", "constant_with_warmup"]' | 159 | ' "constant", "constant_with_warmup"]' |
| @@ -516,6 +517,8 @@ def main(): | |||
| 516 | unet = UNet2DConditionModel.from_pretrained( | 517 | unet = UNet2DConditionModel.from_pretrained( |
| 517 | args.pretrained_model_name_or_path, subfolder='unet') | 518 | args.pretrained_model_name_or_path, subfolder='unet') |
| 518 | 519 | ||
| 520 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 521 | |||
| 519 | if args.gradient_checkpointing: | 522 | if args.gradient_checkpointing: |
| 520 | unet.enable_gradient_checkpointing() | 523 | unet.enable_gradient_checkpointing() |
| 521 | 524 | ||
| @@ -594,7 +597,7 @@ def main(): | |||
| 594 | pixel_values = torch.stack(pixel_values) | 597 | pixel_values = torch.stack(pixel_values) |
| 595 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) | 598 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) |
| 596 | 599 | ||
| 597 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 600 | input_ids = prompt_processor.unify_input_ids(input_ids) |
| 598 | 601 | ||
| 599 | batch = { | 602 | batch = { |
| 600 | "prompts": prompts, | 603 | "prompts": prompts, |
| @@ -607,7 +610,7 @@ def main(): | |||
| 607 | datamodule = CSVDataModule( | 610 | datamodule = CSVDataModule( |
| 608 | data_file=args.train_data_file, | 611 | data_file=args.train_data_file, |
| 609 | batch_size=args.train_batch_size, | 612 | batch_size=args.train_batch_size, |
| 610 | tokenizer=tokenizer, | 613 | prompt_processor=prompt_processor, |
| 611 | instance_identifier=args.instance_identifier, | 614 | instance_identifier=args.instance_identifier, |
| 612 | class_identifier=args.class_identifier, | 615 | class_identifier=args.class_identifier, |
| 613 | class_subdir="cls", | 616 | class_subdir="cls", |
| @@ -678,8 +681,8 @@ def main(): | |||
| 678 | optimizer=optimizer, | 681 | optimizer=optimizer, |
| 679 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 682 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| 680 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 683 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 681 | num_cycles=args.lr_cycles or math.ceil( | 684 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 682 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 685 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
| 683 | ) | 686 | ) |
| 684 | else: | 687 | else: |
| 685 | lr_scheduler = get_scheduler( | 688 | lr_scheduler = get_scheduler( |
| @@ -794,7 +797,7 @@ def main(): | |||
| 794 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 797 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 795 | 798 | ||
| 796 | # Get the text embedding for conditioning | 799 | # Get the text embedding for conditioning |
| 797 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 800 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 798 | 801 | ||
| 799 | # Predict the noise residual | 802 | # Predict the noise residual |
| 800 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 803 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| @@ -885,7 +888,7 @@ def main(): | |||
| 885 | 888 | ||
| 886 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 889 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 887 | 890 | ||
| 888 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 891 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 889 | 892 | ||
| 890 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 893 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 891 | 894 | ||
