Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 44
1 file changed, 32 insertions(+), 12 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1500,
+        default=1400,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -469,9 +469,16 @@ class Checkpointer:

             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = [
+                    [p.format(self.instance_identifier) for p in prompt]
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ][:self.sample_batch_size]
+                nprompt = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ][:self.sample_batch_size]

                 samples = pipeline(
                     prompt=prompt,
@@ -666,6 +673,17 @@ def main():
         }
         return batch

+    def encode_input_ids(input_ids):
+        text_embeddings = []
+
+        for ids in input_ids:
+            embeddings = text_encoder(ids)[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -688,8 +706,10 @@ def main():
     missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

     if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]

         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -706,9 +726,9 @@ def main():

         with torch.inference_mode():
             for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                nprompt = [item.nprompt for item in batch]

                 images = pipeline(
                     prompt=prompt,
@@ -855,7 +875,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -954,7 +974,7 @@ def main():

                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
