-rw-r--r--  dreambooth.py                                         | 26
-rw-r--r--  dreambooth_plus.py                                    | 69
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   |  2
-rw-r--r--  schedulers/scheduling_euler_a.py                      |  3
-rw-r--r--  textual_inversion.py                                  | 41
5 files changed, 82 insertions(+), 59 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images

                 all_samples += samples

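
Note on the `["sample"]` → `.images` change (here and in the other two scripts below): newer diffusers pipelines return an output object that exposes the generated PIL images as `.images`, where older releases returned a dict indexed with `"sample"`. A minimal sketch of the same convention with the stock `StableDiffusionPipeline`; this repo's `VlpnStableDiffusion` is assumed to follow it, and the checkpoint id is only an example:

    # Illustration of the dict -> output-object convention in diffusers;
    # not this repository's training code. Requires a CUDA device and a
    # downloaded Stable Diffusion checkpoint.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",   # example checkpoint id
        torch_dtype=torch.float16,
    ).to("cuda")

    result = pipe("a photo of an astronaut riding a horse", num_inference_steps=30)
    image = result.images[0]                # previously: result["sample"][0]
    image.save("sample.png")
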
@@ -537,6 +537,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
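
The new `weight_dtype` mirrors the `--mixed_precision` setting the script is launched with, so the frozen models (VAE, text encoder) and the pixel batches can be cast to half precision while the trainable UNet stays under accelerate's own autocast handling. A minimal standalone sketch of that selection, assuming an argparse-style flag with the values accelerate accepts:

    # Sketch only: reproduce the dtype selection added above outside the
    # training script. Flag name and choices mirror accelerate's
    # "no" / "fp16" / "bf16" convention.
    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--mixed_precision", choices=["no", "fp16", "bf16"], default="no")
    args = parser.parse_args([])            # use defaults for the sketch

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Frozen parts are then moved with this dtype, e.g.
    #   vae.to(device, dtype=weight_dtype)
    #   pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
    print(weight_dtype)
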
@@ -549,7 +555,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -651,8 +657,8 @@ def main():
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
@@ -738,7 +744,7 @@ def main():
                     latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     # Sample a random timestep for each image
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -761,15 +767,15 @@ def main():
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)

                         # Compute instance loss
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                         # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
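
Casting predictions and targets to `.float()` before the MSE keeps the loss in float32 even when the UNet runs in fp16/bf16, and `reduction="mean"` gives the same scalar as the previous `.mean([1, 2, 3]).mean()` because every sample in the batch has the same number of elements. A quick check of that equivalence with made-up latent-shaped tensors:

    # Sketch: for equally sized samples, a full mean equals the mean of
    # per-sample means, so reduction="mean" matches .mean([1, 2, 3]).mean().
    import torch
    import torch.nn.functional as F

    noise_pred = torch.randn(4, 4, 64, 64)     # made-up shapes, batch of 4
    noise = torch.randn(4, 4, 64, 64)

    loss_a = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
    loss_b = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    print(torch.allclose(loss_a, loss_b))       # True: same value, computed in float32
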
@@ -818,7 +824,7 @@ def main():
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215

-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                               (bsz,), device=latents.device)
@@ -832,7 +838,7 @@ def main():

                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     loss = loss.detach().item()
                     val_loss += loss
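
`torch.randn_like(latents)` replaces `torch.randn(latents.shape).to(latents.device)` throughout: the noise is allocated directly on the latents' device and inherits their dtype, which matters now that the VAE can emit fp16/bf16 latents. A small sketch of that inheritance (device fallback added so it runs without a GPU):

    # Sketch: randn_like inherits shape, dtype and device from its reference
    # tensor, so half-precision latents get matching half-precision noise.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    latents = torch.randn(2, 4, 64, 64, device=device, dtype=dtype)  # stand-in for VAE latents
    noise = torch.randn_like(latents)

    print(noise.shape == latents.shape)    # True
    print(noise.dtype == latents.dtype)    # True, no extra cast needed
    print(noise.device == latents.device)  # True, no .to(latents.device) round-trip
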
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -68,6 +68,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
@@ -118,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=1500,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -353,6 +359,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:

             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
+                prompt = [prompt.format(self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -476,7 +484,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images

                 all_samples += samples

@@ -522,28 +530,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
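
Two behavioural changes sit in this hunk (mirrored in textual_inversion.py below): an already-known `--placeholder_token` no longer raises but is re-used, and a multi-token `--initializer_token` is truncated to its first id instead of being rejected. A small sketch of both paths with a stock CLIP tokenizer standing in for the model's own; the token strings are invented:

    # Sketch of the new token handling; openai/clip-vit-large-patch14 is used
    # as a stand-in tokenizer, and the token strings are examples.
    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    placeholder_token = "<my-concept>"
    initializer_token = "fluffy dog"        # maps to more than one token id

    # Multi-token initializers are no longer an error; only the first id is kept.
    initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
    initializer_token_ids = torch.tensor(initializer_token_ids[:1])

    # add_tokens returns 0 when the token already exists; that case is now
    # allowed so a run can re-use a token the tokenizer already knows.
    num_added_tokens = tokenizer.add_tokens(placeholder_token)
    print("re-using existing token" if num_added_tokens == 0 else "training new token")

    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    print(placeholder_token_id)
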
@@ -630,6 +636,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -642,7 +654,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -658,7 +670,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
     )

     # Move vae and unet to device
-    vae.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -785,6 +797,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
                     latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     # Sample a random timestep for each image
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -853,15 +866,15 @@ def main():
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)

                         # Compute instance loss
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                         # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     accelerator.backward(loss)

@@ -933,7 +946,7 @@ def main():
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215

-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                               (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():

                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     loss = loss.detach().item()
                     val_loss += loss
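
With `--instance_identifier` split off from `--placeholder_token`, dataset and sample prompts are now formatted with the human-readable identifier, while the placeholder stays a tokenizer-level token; `args.instance_identifier.format(args.placeholder_token)` lets the identifier embed the placeholder where wanted. A sketch of that formatting, with invented strings:

    # Sketch only; the identifier and prompt strings are invented.
    placeholder_token = "<my-concept>"

    # An identifier containing "{}" pulls the placeholder token in...
    instance_identifier = "{} person".format(placeholder_token)
    print(instance_identifier)              # "<my-concept> person"

    # ...while one without "{}" is used verbatim, decoupled from the token.
    instance_identifier = "sks person".format(placeholder_token)
    print(instance_identifier)              # "sks person"

    # Dataset prompts are templates that get filled with the identifier:
    prompt_template = "a photo of {}"
    print(prompt_template.format(instance_identifier))
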
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2656b28..8b08a6f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
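
Since the VAE may now sit in fp16/bf16 while the denoising loop can hand back float32 latents, the decode call casts the latents to the VAE's own dtype first. A minimal sketch of the mismatch this guards against, using a standalone AutoencoderKL; the checkpoint id is an example and a CUDA device is assumed:

    # Sketch: decoding float32 latents with a half-precision VAE trips a dtype
    # mismatch unless the latents are cast, which is what the change does.
    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",   # example checkpoint id
        subfolder="vae",
        torch_dtype=torch.float16,
    ).to("cuda")

    latents = torch.randn(1, 4, 64, 64, device="cuda")          # float32 latents
    latents = 1 / 0.18215 * latents                              # undo SD latent scaling

    image = vae.decode(latents.to(dtype=vae.dtype)).sample       # cast as in the diff
    print(image.dtype, image.shape)
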
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 6abe971..c097a8a 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
-        tensor_format: str = "pt",
         num_inference_steps=None,
         device='cuda'
     ):
@@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

         self.device = device
-        self.tensor_format = tensor_format

         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         # get sigmas
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
         self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
-        self.set_format(tensor_format=tensor_format)

         # A# take number of steps as input
         # A# store 1) number of steps 2) timesteps 3) schedule
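
Dropping `tensor_format` and the `set_format(...)` call appears to track the scheduler API of more recent diffusers releases, where schedulers operate on PyTorch tensors only and `SchedulerMixin` no longer provides the numpy/pt format switch; the diff itself only shows the removals, so the motivating version is an assumption. After the change the scheduler is constructed without any format argument, for example:

    # Sketch: constructing the scheduler with only parameters visible in this
    # diff; no tensor_format is passed or stored anymore.
    from schedulers.scheduling_euler_a import EulerAScheduler

    scheduler = EulerAScheduler(
        beta_end=0.02,
        beta_schedule="linear",
        device="cuda",
    )
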
diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
@@ -333,6 +339,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:

             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
+                prompt = [prompt.format(self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -428,7 +436,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images

                 all_samples += samples

@@ -480,28 +488,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
@@ -602,7 +608,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
