diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 69 |
1 files changed, 41 insertions, 28 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 73225de..a98417f 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -56,9 +56,9 @@ def parse_args(): | |||
56 | help="A CSV file containing the training data." | 56 | help="A CSV file containing the training data." |
57 | ) | 57 | ) |
58 | parser.add_argument( | 58 | parser.add_argument( |
59 | "--placeholder_token", | 59 | "--instance_identifier", |
60 | type=str, | 60 | type=str, |
61 | default="<*>", | 61 | default=None, |
62 | help="A token to use as a placeholder for the concept.", | 62 | help="A token to use as a placeholder for the concept.", |
63 | ) | 63 | ) |
64 | parser.add_argument( | 64 | parser.add_argument( |
@@ -68,6 +68,12 @@ def parse_args(): | |||
68 | help="A token to use as a placeholder for the concept.", | 68 | help="A token to use as a placeholder for the concept.", |
69 | ) | 69 | ) |
70 | parser.add_argument( | 70 | parser.add_argument( |
71 | "--placeholder_token", | ||
72 | type=str, | ||
73 | default="<*>", | ||
74 | help="A token to use as a placeholder for the concept.", | ||
75 | ) | ||
76 | parser.add_argument( | ||
71 | "--initializer_token", | 77 | "--initializer_token", |
72 | type=str, | 78 | type=str, |
73 | default=None, | 79 | default=None, |
@@ -118,7 +124,7 @@ def parse_args(): | |||
118 | parser.add_argument( | 124 | parser.add_argument( |
119 | "--max_train_steps", | 125 | "--max_train_steps", |
120 | type=int, | 126 | type=int, |
121 | default=1200, | 127 | default=1500, |
122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 128 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
123 | ) | 129 | ) |
124 | parser.add_argument( | 130 | parser.add_argument( |
@@ -135,13 +141,13 @@ def parse_args(): | |||
135 | parser.add_argument( | 141 | parser.add_argument( |
136 | "--learning_rate_unet", | 142 | "--learning_rate_unet", |
137 | type=float, | 143 | type=float, |
138 | default=5e-5, | 144 | default=5e-6, |
139 | help="Initial learning rate (after the potential warmup period) to use.", | 145 | help="Initial learning rate (after the potential warmup period) to use.", |
140 | ) | 146 | ) |
141 | parser.add_argument( | 147 | parser.add_argument( |
142 | "--learning_rate_text", | 148 | "--learning_rate_text", |
143 | type=float, | 149 | type=float, |
144 | default=1e-6, | 150 | default=5e-6, |
145 | help="Initial learning rate (after the potential warmup period) to use.", | 151 | help="Initial learning rate (after the potential warmup period) to use.", |
146 | ) | 152 | ) |
147 | parser.add_argument( | 153 | parser.add_argument( |
@@ -353,6 +359,7 @@ class Checkpointer: | |||
353 | ema_unet, | 359 | ema_unet, |
354 | tokenizer, | 360 | tokenizer, |
355 | text_encoder, | 361 | text_encoder, |
362 | instance_identifier, | ||
356 | placeholder_token, | 363 | placeholder_token, |
357 | placeholder_token_id, | 364 | placeholder_token_id, |
358 | output_dir: Path, | 365 | output_dir: Path, |
@@ -368,6 +375,7 @@ class Checkpointer: | |||
368 | self.ema_unet = ema_unet | 375 | self.ema_unet = ema_unet |
369 | self.tokenizer = tokenizer | 376 | self.tokenizer = tokenizer |
370 | self.text_encoder = text_encoder | 377 | self.text_encoder = text_encoder |
378 | self.instance_identifier = instance_identifier | ||
371 | self.placeholder_token = placeholder_token | 379 | self.placeholder_token = placeholder_token |
372 | self.placeholder_token_id = placeholder_token_id | 380 | self.placeholder_token_id = placeholder_token_id |
373 | self.output_dir = output_dir | 381 | self.output_dir = output_dir |
@@ -461,7 +469,7 @@ class Checkpointer: | |||
461 | 469 | ||
462 | for i in range(self.sample_batches): | 470 | for i in range(self.sample_batches): |
463 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 471 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
464 | prompt = [prompt.format(self.placeholder_token) | 472 | prompt = [prompt.format(self.instance_identifier) |
465 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 473 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
466 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 474 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
467 | 475 | ||
@@ -476,7 +484,7 @@ class Checkpointer: | |||
476 | eta=eta, | 484 | eta=eta, |
477 | num_inference_steps=num_inference_steps, | 485 | num_inference_steps=num_inference_steps, |
478 | output_type='pil' | 486 | output_type='pil' |
479 | )["sample"] | 487 | ).images |
480 | 488 | ||
481 | all_samples += samples | 489 | all_samples += samples |
482 | 490 | ||
@@ -522,28 +530,26 @@ def main(): | |||
522 | if args.seed is not None: | 530 | if args.seed is not None: |
523 | set_seed(args.seed) | 531 | set_seed(args.seed) |
524 | 532 | ||
533 | args.instance_identifier = args.instance_identifier.format(args.placeholder_token) | ||
534 | |||
525 | # Load the tokenizer and add the placeholder token as a additional special token | 535 | # Load the tokenizer and add the placeholder token as a additional special token |
526 | if args.tokenizer_name: | 536 | if args.tokenizer_name: |
527 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | 537 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) |
528 | elif args.pretrained_model_name_or_path: | 538 | elif args.pretrained_model_name_or_path: |
529 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 539 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
530 | 540 | ||
541 | # Convert the initializer_token, placeholder_token to ids | ||
542 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
543 | print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.") | ||
544 | initializer_token_ids = torch.tensor(initializer_token_ids[:1]) | ||
545 | |||
531 | # Add the placeholder token in tokenizer | 546 | # Add the placeholder token in tokenizer |
532 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | 547 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) |
533 | if num_added_tokens == 0: | 548 | if num_added_tokens == 0: |
534 | raise ValueError( | 549 | print(f"Re-using existing token {args.placeholder_token}.") |
535 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | 550 | else: |
536 | " `placeholder_token` that is not already in the tokenizer." | 551 | print(f"Training new token {args.placeholder_token}.") |
537 | ) | ||
538 | |||
539 | # Convert the initializer_token, placeholder_token to ids | ||
540 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
541 | # Check if initializer_token is a single token or a sequence of tokens | ||
542 | if len(initializer_token_ids) > 1: | ||
543 | raise ValueError( | ||
544 | f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") | ||
545 | 552 | ||
546 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
547 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 553 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
548 | 554 | ||
549 | # Load models and create wrapper for stable diffusion | 555 | # Load models and create wrapper for stable diffusion |
@@ -630,6 +636,12 @@ def main(): | |||
630 | num_train_timesteps=args.noise_timesteps | 636 | num_train_timesteps=args.noise_timesteps |
631 | ) | 637 | ) |
632 | 638 | ||
639 | weight_dtype = torch.float32 | ||
640 | if args.mixed_precision == "fp16": | ||
641 | weight_dtype = torch.float16 | ||
642 | elif args.mixed_precision == "bf16": | ||
643 | weight_dtype = torch.bfloat16 | ||
644 | |||
633 | def collate_fn(examples): | 645 | def collate_fn(examples): |
634 | prompts = [example["prompts"] for example in examples] | 646 | prompts = [example["prompts"] for example in examples] |
635 | nprompts = [example["nprompts"] for example in examples] | 647 | nprompts = [example["nprompts"] for example in examples] |
@@ -642,7 +654,7 @@ def main(): | |||
642 | pixel_values += [example["class_images"] for example in examples] | 654 | pixel_values += [example["class_images"] for example in examples] |
643 | 655 | ||
644 | pixel_values = torch.stack(pixel_values) | 656 | pixel_values = torch.stack(pixel_values) |
645 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) | 657 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
646 | 658 | ||
647 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 659 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
648 | 660 | ||
@@ -658,7 +670,7 @@ def main(): | |||
658 | data_file=args.train_data_file, | 670 | data_file=args.train_data_file, |
659 | batch_size=args.train_batch_size, | 671 | batch_size=args.train_batch_size, |
660 | tokenizer=tokenizer, | 672 | tokenizer=tokenizer, |
661 | instance_identifier=args.placeholder_token, | 673 | instance_identifier=args.instance_identifier, |
662 | class_identifier=args.class_identifier, | 674 | class_identifier=args.class_identifier, |
663 | class_subdir="cls", | 675 | class_subdir="cls", |
664 | num_class_images=args.num_class_images, | 676 | num_class_images=args.num_class_images, |
@@ -744,7 +756,7 @@ def main(): | |||
744 | ) | 756 | ) |
745 | 757 | ||
746 | # Move vae and unet to device | 758 | # Move vae and unet to device |
747 | vae.to(accelerator.device) | 759 | vae.to(accelerator.device, dtype=weight_dtype) |
748 | 760 | ||
749 | # Keep vae and unet in eval mode as we don't train these | 761 | # Keep vae and unet in eval mode as we don't train these |
750 | vae.eval() | 762 | vae.eval() |
@@ -785,6 +797,7 @@ def main(): | |||
785 | ema_unet=ema_unet, | 797 | ema_unet=ema_unet, |
786 | tokenizer=tokenizer, | 798 | tokenizer=tokenizer, |
787 | text_encoder=text_encoder, | 799 | text_encoder=text_encoder, |
800 | instance_identifier=args.instance_identifier, | ||
788 | placeholder_token=args.placeholder_token, | 801 | placeholder_token=args.placeholder_token, |
789 | placeholder_token_id=placeholder_token_id, | 802 | placeholder_token_id=placeholder_token_id, |
790 | output_dir=basepath, | 803 | output_dir=basepath, |
@@ -830,7 +843,7 @@ def main(): | |||
830 | latents = latents * 0.18215 | 843 | latents = latents * 0.18215 |
831 | 844 | ||
832 | # Sample noise that we'll add to the latents | 845 | # Sample noise that we'll add to the latents |
833 | noise = torch.randn(latents.shape).to(latents.device) | 846 | noise = torch.randn_like(latents) |
834 | bsz = latents.shape[0] | 847 | bsz = latents.shape[0] |
835 | # Sample a random timestep for each image | 848 | # Sample a random timestep for each image |
836 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 849 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
@@ -853,15 +866,15 @@ def main(): | |||
853 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 866 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
854 | 867 | ||
855 | # Compute instance loss | 868 | # Compute instance loss |
856 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 869 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() |
857 | 870 | ||
858 | # Compute prior loss | 871 | # Compute prior loss |
859 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() | 872 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") |
860 | 873 | ||
861 | # Add the prior loss to the instance loss. | 874 | # Add the prior loss to the instance loss. |
862 | loss = loss + args.prior_loss_weight * prior_loss | 875 | loss = loss + args.prior_loss_weight * prior_loss |
863 | else: | 876 | else: |
864 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 877 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") |
865 | 878 | ||
866 | accelerator.backward(loss) | 879 | accelerator.backward(loss) |
867 | 880 | ||
@@ -933,7 +946,7 @@ def main(): | |||
933 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 946 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
934 | latents = latents * 0.18215 | 947 | latents = latents * 0.18215 |
935 | 948 | ||
936 | noise = torch.randn(latents.shape).to(latents.device) | 949 | noise = torch.randn_like(latents) |
937 | bsz = latents.shape[0] | 950 | bsz = latents.shape[0] |
938 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, | 951 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
939 | (bsz,), device=latents.device) | 952 | (bsz,), device=latents.device) |
@@ -947,7 +960,7 @@ def main(): | |||
947 | 960 | ||
948 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 961 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
949 | 962 | ||
950 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 963 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") |
951 | 964 | ||
952 | loss = loss.detach().item() | 965 | loss = loss.detach().item() |
953 | val_loss += loss | 966 | val_loss += loss |