author | Volpeon <git@volpeon.ink> | 2022-10-17 12:27:53 +0200
---|---|---
committer | Volpeon <git@volpeon.ink> | 2022-10-17 12:27:53 +0200
commit | 633d890e4964e070be9b0a5b299c2f2e51d4b055 (patch) |
tree | 235b33195b041e45bb7a6a24471ea55ad4bd7850 |
parent | Update (diff) |
Upstream updates; better handling of textual embedding
-rw-r--r-- | dreambooth.py | 26
-rw-r--r-- | dreambooth_plus.py | 69
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 2
-rw-r--r-- | schedulers/scheduling_euler_a.py | 3
-rw-r--r-- | textual_inversion.py | 41

5 files changed, 82 insertions, 59 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer:
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type='pil'
-            )["sample"]
+            ).images

             all_samples += samples

@@ -537,6 +537,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -549,7 +555,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -651,8 +657,8 @@ def main():
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
@@ -738,7 +744,7 @@ def main():
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
-               noise = torch.randn(latents.shape).to(latents.device)
+               noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -761,15 +767,15 @@ def main():
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
-                   loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
-                   prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                   prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
-                   loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
@@ -818,7 +824,7 @@ def main():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                latents = latents * 0.18215

-               noise = torch.randn(latents.shape).to(latents.device)
+               noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                          (bsz,), device=latents.device)
@@ -832,7 +838,7 @@ def main():

                noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-               loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+               loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                loss = loss.detach().item()
                val_loss += loss
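The recurring change above is to derive a weight dtype from `--mixed_precision`, cast the frozen VAE/text encoder and the pixel values to it, and compute the MSE losses in float32. A minimal sketch of that pattern, not the training script itself (the helper names are illustrative):

```python
import torch
import torch.nn.functional as F

def weight_dtype_from_flag(mixed_precision: str) -> torch.dtype:
    # Mirrors the selection added in main(); anything other than fp16/bf16 stays float32.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

def mse_in_fp32(noise_pred: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Casting both operands to float32 keeps the loss numerically stable even
    # when the UNet runs in fp16/bf16.
    return F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

weight_dtype = weight_dtype_from_flag("fp16")
# Frozen models and inputs are then cast once, e.g.:
#   vae.to(accelerator.device, dtype=weight_dtype)
#   text_encoder.to(accelerator.device, dtype=weight_dtype)
```

Likewise, `torch.randn_like(latents)` produces noise that already matches the latents' device and dtype, which the old `torch.randn(latents.shape).to(latents.device)` did not (it stayed float32).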
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -68,6 +68,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
@@ -118,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=1500,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -353,6 +359,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:

         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.placeholder_token)
+            prompt = [prompt.format(self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -476,7 +484,7 @@ class Checkpointer:
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type='pil'
-            )["sample"]
+            ).images

             all_samples += samples

@@ -522,28 +530,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
@@ -630,6 +636,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -642,7 +654,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -658,7 +670,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
     )

     # Move vae and unet to device
-    vae.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -785,6 +797,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
-               noise = torch.randn(latents.shape).to(latents.device)
+               noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -853,15 +866,15 @@ def main():
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
-                   loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
-                   prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                   prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
-                   loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)

@@ -933,7 +946,7 @@ def main():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                latents = latents * 0.18215

-               noise = torch.randn(latents.shape).to(latents.device)
+               noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                          (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():

                noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-               loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+               loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                loss = loss.detach().item()
                val_loss += loss
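dreambooth_plus.py (and textual_inversion.py below) now separates the prompt-level `--instance_identifier` from the tokenizer-level `--placeholder_token`, reuses an already-known placeholder token instead of raising, and keeps only the first embedding of a multi-token initializer. A rough sketch of the new token setup with hypothetical argument values (this is not the script itself):

```python
import torch
from transformers import CLIPTokenizer

# Hypothetical stand-ins for the parsed CLI arguments.
placeholder_token = "<*>"
instance_identifier = "{}"          # "--instance_identifier"; "{}" is filled with the placeholder token
initializer_token = "sks person"    # may encode to more than one token id

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Format the identifier once; sample prompts later use prompt.format(instance_identifier).
instance_identifier = instance_identifier.format(placeholder_token)

# The multi-token guard is gone: report the length and keep only the first id.
initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
initializer_token_ids = torch.tensor(initializer_token_ids[:1])

# Re-adding an existing token is no longer an error; its embedding is re-used.
if tokenizer.add_tokens(placeholder_token) == 0:
    print(f"Re-using existing token {placeholder_token}.")
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
```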
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2656b28..8b08a6f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
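The one-line pipeline change guards the decode step when the VAE has been cast to half precision: latents coming out of the sampler may still be float32, so they are cast to the VAE's parameter dtype first. A small sketch of the same idea outside the pipeline (model id and shapes are just examples):

```python
import torch
from diffusers import AutoencoderKL

# Example: a VAE loaded in fp16 from a Stable Diffusion checkpoint.
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

latents = torch.randn(1, 4, 64, 64, device="cuda")         # float32 latents from the sampler
latents = 1 / 0.18215 * latents                             # undo the SD latent scaling
with torch.no_grad():
    image = vae.decode(latents.to(dtype=vae.dtype)).sample  # cast to match the fp16 weights
```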
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 6abe971..c097a8a 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        tensor_format: str = "pt",
         num_inference_steps=None,
         device='cuda'
     ):
@@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

         self.device = device
-        self.tensor_format = tensor_format

         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         # get sigmas
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
         self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
-        self.set_format(tensor_format=tensor_format)

         # A# take number of steps as input
         # A# store 1) number of steps 2) timesteps 3) schedule
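The scheduler change simply drops the `tensor_format`/`set_format` machinery that upstream diffusers removed; the schedule is now built directly as torch tensors. For reference, a sketch of the quantities the constructor keeps (the beta values below are assumed; only `beta_end=0.02` appears in the diff):

```python
import torch

num_train_timesteps = 1000                                   # assumed default
betas = torch.linspace(0.0001, 0.02, num_train_timesteps)    # assumed linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Same formula as self.DSsigmas in the scheduler:
DSsigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
```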
diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
@@ -333,6 +339,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:

         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.placeholder_token)
+            prompt = [prompt.format(self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -428,7 +436,7 @@ class Checkpointer:
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type='pil'
-            )["sample"]
+            ).images

             all_samples += samples

@@ -480,28 +488,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
@@ -602,7 +608,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
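textual_inversion.py receives the same updates as dreambooth_plus.py: `.images` instead of `["sample"]` on the pipeline output, `torch.randn_like` for noise, and sample prompts formatted with the instance identifier. A short sketch of the prompt side, with illustrative values (not taken from the script):

```python
import torch

# Noise now inherits device and dtype from the latents in a single call.
latents = torch.randn(4, 4, 64, 64)
noise = torch.randn_like(latents)

# Sample prompts are formatted with the instance identifier rather than the raw
# placeholder token, so "<*>" can be embedded in a longer phrase.
instance_identifier = "a cute <*> plush"     # hypothetical, e.g. --instance_identifier "a cute {} plush"
prompts = ["a photo of {}", "a painting of {}"]
sample_prompts = [p.format(instance_identifier) for p in prompts]
print(sample_prompts)
```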