summaryrefslogtreecommitdiffstats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--dreambooth_plus.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
124 parser.add_argument( 124 parser.add_argument(
125 "--max_train_steps", 125 "--max_train_steps",
126 type=int, 126 type=int,
127 default=1500, 127 default=1400,
128 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 128 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
129 ) 129 )
130 parser.add_argument( 130 parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
147 parser.add_argument( 147 parser.add_argument(
148 "--learning_rate_text", 148 "--learning_rate_text",
149 type=float, 149 type=float,
150 default=5e-6, 150 default=1e-6,
151 help="Initial learning rate (after the potential warmup period) to use.", 151 help="Initial learning rate (after the potential warmup period) to use.",
152 ) 152 )
153 parser.add_argument( 153 parser.add_argument(
@@ -469,9 +469,16 @@ class Checkpointer:
469 469
470 for i in range(self.sample_batches): 470 for i in range(self.sample_batches):
471 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 471 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
472 prompt = [prompt.format(self.instance_identifier) 472 prompt = [
473 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 473 [p.format(self.instance_identifier) for p in prompt]
474 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 474 for batch in batches
475 for prompt in batch["prompts"]
476 ][:self.sample_batch_size]
477 nprompt = [
478 prompt
479 for batch in batches
480 for prompt in batch["nprompts"]
481 ][:self.sample_batch_size]
475 482
476 samples = pipeline( 483 samples = pipeline(
477 prompt=prompt, 484 prompt=prompt,
@@ -666,6 +673,17 @@ def main():
666 } 673 }
667 return batch 674 return batch
668 675
676 def encode_input_ids(input_ids):
677 text_embeddings = []
678
679 for ids in input_ids:
680 embeddings = text_encoder(ids)[0]
681 embeddings = embeddings.reshape((1, -1, 768))
682 text_embeddings.append(embeddings)
683
684 text_embeddings = torch.cat(text_embeddings)
685 return text_embeddings
686
669 datamodule = CSVDataModule( 687 datamodule = CSVDataModule(
670 data_file=args.train_data_file, 688 data_file=args.train_data_file,
671 batch_size=args.train_batch_size, 689 batch_size=args.train_batch_size,
@@ -688,8 +706,10 @@ def main():
688 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] 706 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
689 707
690 if len(missing_data) != 0: 708 if len(missing_data) != 0:
691 batched_data = [missing_data[i:i+args.sample_batch_size] 709 batched_data = [
692 for i in range(0, len(missing_data), args.sample_batch_size)] 710 missing_data[i:i+args.sample_batch_size]
711 for i in range(0, len(missing_data), args.sample_batch_size)
712 ]
693 713
694 scheduler = EulerAScheduler( 714 scheduler = EulerAScheduler(
695 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 715 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -706,9 +726,9 @@ def main():
706 726
707 with torch.inference_mode(): 727 with torch.inference_mode():
708 for batch in batched_data: 728 for batch in batched_data:
709 image_name = [p.class_image_path for p in batch] 729 image_name = [item.class_image_path for item in batch]
710 prompt = [p.prompt.format(args.class_identifier) for p in batch] 730 prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
711 nprompt = [p.nprompt for p in batch] 731 nprompt = [item.nprompt for item in batch]
712 732
713 images = pipeline( 733 images = pipeline(
714 prompt=prompt, 734 prompt=prompt,
@@ -855,7 +875,7 @@ def main():
855 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 875 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
856 876
857 # Get the text embedding for conditioning 877 # Get the text embedding for conditioning
858 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 878 encoder_hidden_states = encode_input_ids(batch["input_ids"])
859 879
860 # Predict the noise residual 880 # Predict the noise residual
861 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 881 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -954,7 +974,7 @@ def main():
954 974
955 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 975 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
956 976
957 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 977 encoder_hidden_states = encode_input_ids(batch["input_ids"])
958 978
959 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 979 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
960 980