summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py17
-rw-r--r--dreambooth_plus.py6
-rw-r--r--textual_inversion.py17
3 files changed, 23 insertions, 17 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 770ad38..9786e0f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -25,6 +25,7 @@ from slugify import slugify
25from schedulers.scheduling_euler_a import EulerAScheduler 25from schedulers.scheduling_euler_a import EulerAScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28from models.clip.prompt import PromptProcessor
28 29
29logger = get_logger(__name__) 30logger = get_logger(__name__)
30 31
@@ -141,7 +142,7 @@ def parse_args():
141 parser.add_argument( 142 parser.add_argument(
142 "--lr_scheduler", 143 "--lr_scheduler",
143 type=str, 144 type=str,
144 default="cosine", 145 default="cosine_with_restarts",
145 help=( 146 help=(
146 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 147 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
147 ' "constant", "constant_with_warmup"]' 148 ' "constant", "constant_with_warmup"]'
@@ -494,6 +495,8 @@ def main():
494 device=accelerator.device 495 device=accelerator.device
495 ) if args.use_ema else None 496 ) if args.use_ema else None
496 497
498 prompt_processor = PromptProcessor(tokenizer, text_encoder)
499
497 if args.gradient_checkpointing: 500 if args.gradient_checkpointing:
498 unet.enable_gradient_checkpointing() 501 unet.enable_gradient_checkpointing()
499 502
@@ -557,7 +560,7 @@ def main():
557 pixel_values = torch.stack(pixel_values) 560 pixel_values = torch.stack(pixel_values)
558 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 561 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
559 562
560 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 563 input_ids = prompt_processor.unify_input_ids(input_ids)
561 564
562 batch = { 565 batch = {
563 "prompts": prompts, 566 "prompts": prompts,
@@ -570,7 +573,7 @@ def main():
570 datamodule = CSVDataModule( 573 datamodule = CSVDataModule(
571 data_file=args.train_data_file, 574 data_file=args.train_data_file,
572 batch_size=args.train_batch_size, 575 batch_size=args.train_batch_size,
573 tokenizer=tokenizer, 576 prompt_processor=prompt_processor,
574 instance_identifier=args.instance_identifier, 577 instance_identifier=args.instance_identifier,
575 class_identifier=args.class_identifier, 578 class_identifier=args.class_identifier,
576 class_subdir="cls", 579 class_subdir="cls",
@@ -641,8 +644,8 @@ def main():
641 optimizer=optimizer, 644 optimizer=optimizer,
642 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 645 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
643 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 646 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
644 num_cycles=args.lr_cycles or math.ceil( 647 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
645 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), 648 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
646 ) 649 )
647 else: 650 else:
648 lr_scheduler = get_scheduler( 651 lr_scheduler = get_scheduler(
@@ -756,7 +759,7 @@ def main():
756 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 759 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
757 760
758 # Get the text embedding for conditioning 761 # Get the text embedding for conditioning
759 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 762 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
760 763
761 # Predict the noise residual 764 # Predict the noise residual
762 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 765 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -832,7 +835,7 @@ def main():
832 835
833 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 836 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
834 837
835 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 838 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
836 839
837 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 840 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
838 841
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index fa3a22b..06ff45b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
125 parser.add_argument( 125 parser.add_argument(
126 "--max_train_steps", 126 "--max_train_steps",
127 type=int, 127 type=int,
128 default=1400, 128 default=2400,
129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 ) 130 )
131 parser.add_argument( 131 parser.add_argument(
@@ -752,8 +752,8 @@ def main():
752 optimizer=optimizer, 752 optimizer=optimizer,
753 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 753 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
754 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 754 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
755 num_cycles=args.lr_cycles or math.ceil( 755 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
756 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), 756 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
757 ) 757 )
758 else: 758 else:
759 lr_scheduler = get_scheduler( 759 lr_scheduler = get_scheduler(
diff --git a/textual_inversion.py b/textual_inversion.py
index 69d9c7f..8f266e0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
25from schedulers.scheduling_euler_a import EulerAScheduler 25from schedulers.scheduling_euler_a import EulerAScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28from models.clip.prompt import PromptProcessor
28 29
29logger = get_logger(__name__) 30logger = get_logger(__name__)
30 31
@@ -152,7 +153,7 @@ def parse_args():
152 parser.add_argument( 153 parser.add_argument(
153 "--lr_scheduler", 154 "--lr_scheduler",
154 type=str, 155 type=str,
155 default="cosine", 156 default="cosine_with_restarts",
156 help=( 157 help=(
157 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 158 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
158 ' "constant", "constant_with_warmup"]' 159 ' "constant", "constant_with_warmup"]'
@@ -516,6 +517,8 @@ def main():
516 unet = UNet2DConditionModel.from_pretrained( 517 unet = UNet2DConditionModel.from_pretrained(
517 args.pretrained_model_name_or_path, subfolder='unet') 518 args.pretrained_model_name_or_path, subfolder='unet')
518 519
520 prompt_processor = PromptProcessor(tokenizer, text_encoder)
521
519 if args.gradient_checkpointing: 522 if args.gradient_checkpointing:
520 unet.enable_gradient_checkpointing() 523 unet.enable_gradient_checkpointing()
521 524
@@ -594,7 +597,7 @@ def main():
594 pixel_values = torch.stack(pixel_values) 597 pixel_values = torch.stack(pixel_values)
595 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) 598 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
596 599
597 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 600 input_ids = prompt_processor.unify_input_ids(input_ids)
598 601
599 batch = { 602 batch = {
600 "prompts": prompts, 603 "prompts": prompts,
@@ -607,7 +610,7 @@ def main():
607 datamodule = CSVDataModule( 610 datamodule = CSVDataModule(
608 data_file=args.train_data_file, 611 data_file=args.train_data_file,
609 batch_size=args.train_batch_size, 612 batch_size=args.train_batch_size,
610 tokenizer=tokenizer, 613 prompt_processor=prompt_processor,
611 instance_identifier=args.instance_identifier, 614 instance_identifier=args.instance_identifier,
612 class_identifier=args.class_identifier, 615 class_identifier=args.class_identifier,
613 class_subdir="cls", 616 class_subdir="cls",
@@ -678,8 +681,8 @@ def main():
678 optimizer=optimizer, 681 optimizer=optimizer,
679 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 682 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
680 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 683 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
681 num_cycles=args.lr_cycles or math.ceil( 684 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
682 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), 685 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
683 ) 686 )
684 else: 687 else:
685 lr_scheduler = get_scheduler( 688 lr_scheduler = get_scheduler(
@@ -794,7 +797,7 @@ def main():
794 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 797 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
795 798
796 # Get the text embedding for conditioning 799 # Get the text embedding for conditioning
797 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 800 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
798 801
799 # Predict the noise residual 802 # Predict the noise residual
800 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 803 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -885,7 +888,7 @@ def main():
885 888
886 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 889 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
887 890
888 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 891 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
889 892
890 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 893 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
891 894