| author | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-17 22:08:58 +0200 |
| commit | 728dfcf57c30f40236b3a00d7380c4e0057cacb3 | |
| tree | 9aee7759b7f31752a87a1c9af4d9c4ea20f9a862 /dreambooth_plus.py | |
| parent | Upstream updates; better handling of textual embedding | |
Implemented extended prompt limit
Diffstat (limited to 'dreambooth_plus.py')

```
 dreambooth_plus.py | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)
```
```diff
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1500,
+        default=1400,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
```
```diff
@@ -469,9 +469,16 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.instance_identifier)
-                      for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-            nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+            prompt = [
+                [p.format(self.instance_identifier) for p in prompt]
+                for batch in batches
+                for prompt in batch["prompts"]
+            ][:self.sample_batch_size]
+            nprompt = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ][:self.sample_batch_size]
 
             samples = pipeline(
                 prompt=prompt,
```
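With the extended prompt limit, each entry of `batch["prompts"]` is a list of sub-prompts rather than a single string, so the instance identifier has to be formatted into every part. A minimal sketch of the new shape, with hypothetical prompt data (the real batches come from `CSVDataModule`):

```python
# Hypothetical batch: each training prompt is now a list of sub-prompts,
# one per encoder-sized chunk, instead of a single string.
batch = {
    "prompts": [
        ["a photo of {}", "detailed background, studio lighting"],
        ["a painting of {}"],
    ],
    "nprompts": ["blurry, low quality", ""],
}

instance_identifier = "sks"  # hypothetical placeholder value

# Mirrors the Checkpointer change: format the identifier into every sub-prompt.
prompt = [
    [p.format(instance_identifier) for p in prompt]
    for prompt in batch["prompts"]
]
print(prompt)
# [['a photo of sks', 'detailed background, studio lighting'],
#  ['a painting of sks']]
```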
```diff
@@ -666,6 +673,17 @@ def main():
         }
         return batch
 
+    def encode_input_ids(input_ids):
+        text_embeddings = []
+
+        for ids in input_ids:
+            embeddings = text_encoder(ids)[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
```
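The new `encode_input_ids` helper is where the extended prompt limit is actually implemented: each prompt arrives as a stack of encoder-sized chunks of token ids, every chunk is run through the text encoder separately, and the per-chunk hidden states are flattened into one long sequence of 768-dimensional vectors (768 being the hidden size of the CLIP text encoder Stable Diffusion uses). The chunking itself is not part of this diff and presumably happens in the data pipeline. A self-contained sketch of the idea, with a stand-in encoder since the real `text_encoder` is loaded elsewhere in `main()`:

```python
import torch

# Stand-in for the CLIP text encoder: maps (num_chunks, seq_len) token ids
# to (num_chunks, seq_len, 768) hidden states, returned in a tuple the way
# the transformers model returns its outputs. Purely illustrative.
def text_encoder(ids: torch.Tensor):
    return (torch.randn(ids.shape[0], ids.shape[1], 768),)

def encode_input_ids(input_ids):
    text_embeddings = []
    for ids in input_ids:                              # (num_chunks, 77) per prompt
        embeddings = text_encoder(ids)[0]              # (num_chunks, 77, 768)
        embeddings = embeddings.reshape((1, -1, 768))  # (1, num_chunks * 77, 768)
        text_embeddings.append(embeddings)
    return torch.cat(text_embeddings)                  # (batch, num_chunks * 77, 768)

# Two prompts, each pre-chunked into three 77-token pieces:
input_ids = torch.zeros(2, 3, 77, dtype=torch.long)
print(encode_input_ids(input_ids).shape)  # torch.Size([2, 231, 768])
```

Note that the final `torch.cat` only succeeds if every prompt in the batch is padded to the same number of chunks, which the collate function would have to guarantee.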
```diff
@@ -688,8 +706,10 @@ def main():
     missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
     if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]
 
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
```
```diff
@@ -706,9 +726,9 @@ def main():
 
         with torch.inference_mode():
             for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
                     prompt=prompt,
```
```diff
@@ -855,7 +875,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```
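Swapping `text_encoder(batch["input_ids"])[0]` for `encode_input_ids(batch["input_ids"])` needs no change on the UNet side, because the text embedding only enters the UNet's cross-attention layers as keys and values, and attention does not assume a fixed context length. A toy illustration using PyTorch 2's `F.scaled_dot_product_attention` (all tensors and sizes hypothetical):

```python
import torch
import torch.nn.functional as F

# Queries come from image latents; keys/values come from the text embedding.
# Nothing in the attention math constrains the text sequence length.
latents = torch.randn(2, 64, 768)    # (batch, image tokens, dim)
text_77 = torch.randn(2, 77, 768)    # standard single-chunk CLIP context
text_231 = torch.randn(2, 231, 768)  # three concatenated chunks

for context in (text_77, text_231):
    out = F.scaled_dot_product_attention(latents, context, context)
    print(out.shape)  # (2, 64, 768) either way
```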
```diff
@@ -954,7 +974,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
```
