Diffstat (limited to 'dreambooth_plus.py')
 dreambooth_plus.py | 69 +++++++++++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 41 insertions(+), 28 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -68,6 +68,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
@@ -118,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=1500,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
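Note: both learning rates now default to 5e-6, where the UNet previously trained faster (5e-5) than the text encoder (1e-6). As a rough sketch of how two such rates are commonly attached to a single optimizer via parameter groups (stand-in modules, not the script's actual wiring):

    import torch
    import torch.nn as nn

    unet = nn.Linear(4, 4)          # stand-in for the UNet
    text_encoder = nn.Linear(4, 4)  # stand-in for the text encoder

    # One optimizer, one learning rate per parameter group.
    optimizer = torch.optim.AdamW([
        {"params": unet.parameters(), "lr": 5e-6},          # --learning_rate_unet
        {"params": text_encoder.parameters(), "lr": 5e-6},  # --learning_rate_text
    ])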
@@ -353,6 +359,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:

         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.placeholder_token)
+            prompt = [prompt.format(self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -476,7 +484,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images

                 all_samples += samples

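The ["sample"] indexing matched older diffusers releases, where pipelines returned a plain dict; newer releases return a StableDiffusionPipelineOutput whose images attribute holds the generated PIL images. A minimal sketch of the new interface (the model path is an example, not from this script):

    from diffusers import StableDiffusionPipeline

    # Example checkpoint; substitute whatever model the script trains from.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    out = pipe("a photo of a cat", num_inference_steps=20, output_type="pil")
    images = out.images       # newer diffusers: attribute on the pipeline output
    # images = out["sample"]  # older diffusers returned a dict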
@@ -522,28 +530,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as an additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
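Two behavioral changes land in this hunk: a placeholder token that already exists in the tokenizer is now reused instead of raising, and a multi-token initializer is silently truncated to its first token id rather than rejected. A standalone sketch of the truncation (real CLIP tokenizer, example phrase):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # A phrase like "shiba inu" encodes to several ids; only the first survives.
    ids = tokenizer.encode("shiba inu", add_special_tokens=False)
    print(f"Initializer token maps to {len(ids)} embeddings.")
    initializer_token_ids = torch.tensor(ids[:1])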
@@ -630,6 +636,12 @@ def main():
         num_train_timesteps=args.noise_timesteps
     )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
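weight_dtype mirrors the accelerate-style mixed_precision setting and is reused below to cast the pixel values and the frozen VAE. The same mapping written as a tiny helper, purely for illustration (not code from the script):

    import torch

    DTYPES = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    def weight_dtype_for(mixed_precision: str) -> torch.dtype:
        # Anything unrecognized falls back to full precision.
        return DTYPES.get(mixed_precision, torch.float32)

    assert weight_dtype_for("fp16") is torch.float16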
@@ -642,7 +654,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -658,7 +670,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
     )

     # Move vae and unet to device
-    vae.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -785,6 +797,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
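torch.randn_like allocates the noise directly with the latents' shape, dtype, and device; the replaced expression first built a float32 tensor on the CPU and then copied it over, which also left a dtype mismatch when the latents are half precision. A quick sketch:

    import torch

    latents = torch.zeros(2, 4, 8, 8, dtype=torch.bfloat16)

    # Old form: float32 on CPU, then a device copy; dtype still float32.
    noise_old = torch.randn(latents.shape).to(latents.device)

    # New form: one allocation matching shape, dtype, and device.
    noise = torch.randn_like(latents)
    assert noise.shape == latents.shape and noise.dtype == latents.dtype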
@@ -853,15 +866,15 @@ def main():
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)

                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                 accelerator.backward(loss)

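With half-precision activations, the losses are now reduced in float32 by upcasting predictions and targets first, which keeps the mean numerically stable. reduction="mean" gives the same value as the previous per-sample mean followed by a batch mean whenever all samples share one shape, as they do here. A quick check:

    import torch
    import torch.nn.functional as F

    noise_pred = torch.rand(4, 4, 8, 8, dtype=torch.bfloat16)
    noise = torch.rand(4, 4, 8, 8, dtype=torch.bfloat16)

    # Upcast before reducing so the mean accumulates in float32.
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    # Equivalent for equally-shaped samples: mean over pixels, then over batch.
    per_sample = F.mse_loss(noise_pred.float(), noise.float(),
                            reduction="none").mean([1, 2, 3]).mean()
    assert torch.allclose(loss, per_sample)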
@@ -933,7 +946,7 @@ def main():
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215

-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                               (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():

                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     loss = loss.detach().item()
                     val_loss += loss
