summaryrefslogtreecommitdiffstats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--dreambooth_plus.py69
1 files changed, 41 insertions, 28 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
56 help="A CSV file containing the training data." 56 help="A CSV file containing the training data."
57 ) 57 )
58 parser.add_argument( 58 parser.add_argument(
59 "--placeholder_token", 59 "--instance_identifier",
60 type=str, 60 type=str,
61 default="<*>", 61 default=None,
62 help="A token to use as a placeholder for the concept.", 62 help="A token to use as a placeholder for the concept.",
63 ) 63 )
64 parser.add_argument( 64 parser.add_argument(
@@ -68,6 +68,12 @@ def parse_args():
68 help="A token to use as a placeholder for the concept.", 68 help="A token to use as a placeholder for the concept.",
69 ) 69 )
70 parser.add_argument( 70 parser.add_argument(
71 "--placeholder_token",
72 type=str,
73 default="<*>",
74 help="A token to use as a placeholder for the concept.",
75 )
76 parser.add_argument(
71 "--initializer_token", 77 "--initializer_token",
72 type=str, 78 type=str,
73 default=None, 79 default=None,
@@ -118,7 +124,7 @@ def parse_args():
118 parser.add_argument( 124 parser.add_argument(
119 "--max_train_steps", 125 "--max_train_steps",
120 type=int, 126 type=int,
121 default=1200, 127 default=1500,
122 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 128 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
123 ) 129 )
124 parser.add_argument( 130 parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
135 parser.add_argument( 141 parser.add_argument(
136 "--learning_rate_unet", 142 "--learning_rate_unet",
137 type=float, 143 type=float,
138 default=5e-5, 144 default=5e-6,
139 help="Initial learning rate (after the potential warmup period) to use.", 145 help="Initial learning rate (after the potential warmup period) to use.",
140 ) 146 )
141 parser.add_argument( 147 parser.add_argument(
142 "--learning_rate_text", 148 "--learning_rate_text",
143 type=float, 149 type=float,
144 default=1e-6, 150 default=5e-6,
145 help="Initial learning rate (after the potential warmup period) to use.", 151 help="Initial learning rate (after the potential warmup period) to use.",
146 ) 152 )
147 parser.add_argument( 153 parser.add_argument(
@@ -353,6 +359,7 @@ class Checkpointer:
353 ema_unet, 359 ema_unet,
354 tokenizer, 360 tokenizer,
355 text_encoder, 361 text_encoder,
362 instance_identifier,
356 placeholder_token, 363 placeholder_token,
357 placeholder_token_id, 364 placeholder_token_id,
358 output_dir: Path, 365 output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
368 self.ema_unet = ema_unet 375 self.ema_unet = ema_unet
369 self.tokenizer = tokenizer 376 self.tokenizer = tokenizer
370 self.text_encoder = text_encoder 377 self.text_encoder = text_encoder
378 self.instance_identifier = instance_identifier
371 self.placeholder_token = placeholder_token 379 self.placeholder_token = placeholder_token
372 self.placeholder_token_id = placeholder_token_id 380 self.placeholder_token_id = placeholder_token_id
373 self.output_dir = output_dir 381 self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:
461 469
462 for i in range(self.sample_batches): 470 for i in range(self.sample_batches):
463 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 471 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
464 prompt = [prompt.format(self.placeholder_token) 472 prompt = [prompt.format(self.instance_identifier)
465 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 473 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
466 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 474 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
467 475
@@ -476,7 +484,7 @@ class Checkpointer:
476 eta=eta, 484 eta=eta,
477 num_inference_steps=num_inference_steps, 485 num_inference_steps=num_inference_steps,
478 output_type='pil' 486 output_type='pil'
479 )["sample"] 487 ).images
480 488
481 all_samples += samples 489 all_samples += samples
482 490
@@ -522,28 +530,26 @@ def main():
522 if args.seed is not None: 530 if args.seed is not None:
523 set_seed(args.seed) 531 set_seed(args.seed)
524 532
533 args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
534
525 # Load the tokenizer and add the placeholder token as a additional special token 535 # Load the tokenizer and add the placeholder token as a additional special token
526 if args.tokenizer_name: 536 if args.tokenizer_name:
527 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 537 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
528 elif args.pretrained_model_name_or_path: 538 elif args.pretrained_model_name_or_path:
529 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') 539 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
530 540
541 # Convert the initializer_token, placeholder_token to ids
542 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
543 print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
544 initializer_token_ids = torch.tensor(initializer_token_ids[:1])
545
531 # Add the placeholder token in tokenizer 546 # Add the placeholder token in tokenizer
532 num_added_tokens = tokenizer.add_tokens(args.placeholder_token) 547 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
533 if num_added_tokens == 0: 548 if num_added_tokens == 0:
534 raise ValueError( 549 print(f"Re-using existing token {args.placeholder_token}.")
535 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 550 else:
536 " `placeholder_token` that is not already in the tokenizer." 551 print(f"Training new token {args.placeholder_token}.")
537 )
538
539 # Convert the initializer_token, placeholder_token to ids
540 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
541 # Check if initializer_token is a single token or a sequence of tokens
542 if len(initializer_token_ids) > 1:
543 raise ValueError(
544 f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
545 552
546 initializer_token_ids = torch.tensor(initializer_token_ids)
547 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 553 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
548 554
549 # Load models and create wrapper for stable diffusion 555 # Load models and create wrapper for stable diffusion
@@ -630,6 +636,12 @@ def main():
630 num_train_timesteps=args.noise_timesteps 636 num_train_timesteps=args.noise_timesteps
631 ) 637 )
632 638
639 weight_dtype = torch.float32
640 if args.mixed_precision == "fp16":
641 weight_dtype = torch.float16
642 elif args.mixed_precision == "bf16":
643 weight_dtype = torch.bfloat16
644
633 def collate_fn(examples): 645 def collate_fn(examples):
634 prompts = [example["prompts"] for example in examples] 646 prompts = [example["prompts"] for example in examples]
635 nprompts = [example["nprompts"] for example in examples] 647 nprompts = [example["nprompts"] for example in examples]
@@ -642,7 +654,7 @@ def main():
642 pixel_values += [example["class_images"] for example in examples] 654 pixel_values += [example["class_images"] for example in examples]
643 655
644 pixel_values = torch.stack(pixel_values) 656 pixel_values = torch.stack(pixel_values)
645 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) 657 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
646 658
647 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 659 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
648 660
@@ -658,7 +670,7 @@ def main():
658 data_file=args.train_data_file, 670 data_file=args.train_data_file,
659 batch_size=args.train_batch_size, 671 batch_size=args.train_batch_size,
660 tokenizer=tokenizer, 672 tokenizer=tokenizer,
661 instance_identifier=args.placeholder_token, 673 instance_identifier=args.instance_identifier,
662 class_identifier=args.class_identifier, 674 class_identifier=args.class_identifier,
663 class_subdir="cls", 675 class_subdir="cls",
664 num_class_images=args.num_class_images, 676 num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
744 ) 756 )
745 757
746 # Move vae and unet to device 758 # Move vae and unet to device
747 vae.to(accelerator.device) 759 vae.to(accelerator.device, dtype=weight_dtype)
748 760
749 # Keep vae and unet in eval mode as we don't train these 761 # Keep vae and unet in eval mode as we don't train these
750 vae.eval() 762 vae.eval()
@@ -785,6 +797,7 @@ def main():
785 ema_unet=ema_unet, 797 ema_unet=ema_unet,
786 tokenizer=tokenizer, 798 tokenizer=tokenizer,
787 text_encoder=text_encoder, 799 text_encoder=text_encoder,
800 instance_identifier=args.instance_identifier,
788 placeholder_token=args.placeholder_token, 801 placeholder_token=args.placeholder_token,
789 placeholder_token_id=placeholder_token_id, 802 placeholder_token_id=placeholder_token_id,
790 output_dir=basepath, 803 output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
830 latents = latents * 0.18215 843 latents = latents * 0.18215
831 844
832 # Sample noise that we'll add to the latents 845 # Sample noise that we'll add to the latents
833 noise = torch.randn(latents.shape).to(latents.device) 846 noise = torch.randn_like(latents)
834 bsz = latents.shape[0] 847 bsz = latents.shape[0]
835 # Sample a random timestep for each image 848 # Sample a random timestep for each image
836 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 849 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -853,15 +866,15 @@ def main():
853 noise, noise_prior = torch.chunk(noise, 2, dim=0) 866 noise, noise_prior = torch.chunk(noise, 2, dim=0)
854 867
855 # Compute instance loss 868 # Compute instance loss
856 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 869 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
857 870
858 # Compute prior loss 871 # Compute prior loss
859 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() 872 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
860 873
861 # Add the prior loss to the instance loss. 874 # Add the prior loss to the instance loss.
862 loss = loss + args.prior_loss_weight * prior_loss 875 loss = loss + args.prior_loss_weight * prior_loss
863 else: 876 else:
864 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 877 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
865 878
866 accelerator.backward(loss) 879 accelerator.backward(loss)
867 880
@@ -933,7 +946,7 @@ def main():
933 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 946 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
934 latents = latents * 0.18215 947 latents = latents * 0.18215
935 948
936 noise = torch.randn(latents.shape).to(latents.device) 949 noise = torch.randn_like(latents)
937 bsz = latents.shape[0] 950 bsz = latents.shape[0]
938 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 951 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
939 (bsz,), device=latents.device) 952 (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():
947 960
948 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 961 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
949 962
950 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 963 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
951 964
952 loss = loss.detach().item() 965 loss = loss.detach().item()
953 val_loss += loss 966 val_loss += loss