 dreambooth.py                                        | 26
 dreambooth_plus.py                                   | 69
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  2
 schedulers/scheduling_euler_a.py                     |  3
 textual_inversion.py                                 | 41
 5 files changed, 82 insertions(+), 59 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 42d3980..770ad38 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -430,7 +430,7 @@ class Checkpointer:
     eta=eta,
     num_inference_steps=num_inference_steps,
     output_type='pil'
-)["sample"]
+).images

 all_samples += samples

@@ -537,6 +537,12 @@ def main():
     num_train_timesteps=args.noise_timesteps
 )

+weight_dtype = torch.float32
+if args.mixed_precision == "fp16":
+    weight_dtype = torch.float16
+elif args.mixed_precision == "bf16":
+    weight_dtype = torch.bfloat16
+
 def collate_fn(examples):
     prompts = [example["prompts"] for example in examples]
     nprompts = [example["nprompts"] for example in examples]
@@ -549,7 +555,7 @@ def main():
 pixel_values += [example["class_images"] for example in examples]

 pixel_values = torch.stack(pixel_values)
-pixel_values = pixel_values.to(memory_format=torch.contiguous_format)
+pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -651,8 +657,8 @@ def main():
 )

 # Move text_encoder and vae to device
-text_encoder.to(accelerator.device)
-vae.to(accelerator.device)
+text_encoder.to(accelerator.device, dtype=weight_dtype)
+vae.to(accelerator.device, dtype=weight_dtype)

 # Keep text_encoder and vae in eval mode as we don't train these
 text_encoder.eval()
@@ -738,7 +744,7 @@ def main():
 latents = latents * 0.18215

 # Sample noise that we'll add to the latents
-noise = torch.randn(latents.shape).to(latents.device)
+noise = torch.randn_like(latents)
 bsz = latents.shape[0]
 # Sample a random timestep for each image
 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -761,15 +767,15 @@ def main():
     noise, noise_prior = torch.chunk(noise, 2, dim=0)

     # Compute instance loss
-    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

     # Compute prior loss
-    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

     # Add the prior loss to the instance loss.
     loss = loss + args.prior_loss_weight * prior_loss
 else:
-    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

 accelerator.backward(loss)
 if accelerator.sync_gradients:
@@ -818,7 +824,7 @@ def main():
 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
 latents = latents * 0.18215

-noise = torch.randn(latents.shape).to(latents.device)
+noise = torch.randn_like(latents)
 bsz = latents.shape[0]
 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                           (bsz,), device=latents.device)
@@ -832,7 +838,7 @@ def main():

 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

 loss = loss.detach().item()
 val_loss += loss
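The dreambooth.py changes above all serve mixed-precision training: pick a weight dtype from --mixed_precision, cast the frozen text encoder and VAE (and the batch pixels) to it, sample noise with torch.randn_like, and compute the MSE loss in float32 so fp16/bf16 predictions do not lose precision. A minimal, self-contained sketch of that pattern; the helper names select_weight_dtype and diffusion_loss are made up for illustration and are not part of the script:

    import torch
    import torch.nn.functional as F


    def select_weight_dtype(mixed_precision: str) -> torch.dtype:
        """Map the --mixed_precision flag to the dtype used for the frozen models."""
        if mixed_precision == "fp16":
            return torch.float16
        if mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32


    def diffusion_loss(noise_pred, noise, prior_loss_weight=1.0, with_prior=False):
        """MSE between predicted and true noise, always computed in float32."""
        if with_prior:
            # First half of the batch: instance images; second half: class images.
            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
            noise, noise_prior = torch.chunk(noise, 2, dim=0)
            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
            return loss + prior_loss_weight * prior_loss
        return F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")


    # Toy usage with half-precision tensors standing in for UNet outputs.
    pred = torch.randn(4, 4, 64, 64, dtype=torch.float16)
    target = torch.randn_like(pred)  # same trick the diff uses to sample noise
    print(select_weight_dtype("fp16"))
    print(diffusion_loss(pred, target, with_prior=True).item())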
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
     help="A CSV file containing the training data."
 )
 parser.add_argument(
-    "--placeholder_token",
+    "--instance_identifier",
     type=str,
-    default="<*>",
+    default=None,
     help="A token to use as a placeholder for the concept.",
 )
 parser.add_argument(
@@ -68,6 +68,12 @@ def parse_args():
     help="A token to use as a placeholder for the concept.",
 )
 parser.add_argument(
+    "--placeholder_token",
+    type=str,
+    default="<*>",
+    help="A token to use as a placeholder for the concept.",
+)
+parser.add_argument(
     "--initializer_token",
     type=str,
     default=None,
@@ -118,7 +124,7 @@ def parse_args():
 parser.add_argument(
     "--max_train_steps",
     type=int,
-    default=1200,
+    default=1500,
     help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
 )
 parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
 parser.add_argument(
     "--learning_rate_unet",
     type=float,
-    default=5e-5,
+    default=5e-6,
     help="Initial learning rate (after the potential warmup period) to use.",
 )
 parser.add_argument(
     "--learning_rate_text",
     type=float,
-    default=1e-6,
+    default=5e-6,
     help="Initial learning rate (after the potential warmup period) to use.",
 )
 parser.add_argument(
@@ -353,6 +359,7 @@ class Checkpointer:
 ema_unet,
 tokenizer,
 text_encoder,
+instance_identifier,
 placeholder_token,
 placeholder_token_id,
 output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
 self.ema_unet = ema_unet
 self.tokenizer = tokenizer
 self.text_encoder = text_encoder
+self.instance_identifier = instance_identifier
 self.placeholder_token = placeholder_token
 self.placeholder_token_id = placeholder_token_id
 self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:

 for i in range(self.sample_batches):
     batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-    prompt = [prompt.format(self.placeholder_token)
+    prompt = [prompt.format(self.instance_identifier)
               for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
     nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -476,7 +484,7 @@ class Checkpointer:
     eta=eta,
     num_inference_steps=num_inference_steps,
     output_type='pil'
-)["sample"]
+).images

 all_samples += samples

@@ -522,28 +530,26 @@ def main():
 if args.seed is not None:
     set_seed(args.seed)

+args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
 # Load the tokenizer and add the placeholder token as a additional special token
 if args.tokenizer_name:
     tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
 elif args.pretrained_model_name_or_path:
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+# Convert the initializer_token, placeholder_token to ids
+initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
 # Add the placeholder token in tokenizer
 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
 if num_added_tokens == 0:
-    raise ValueError(
-        f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-        " `placeholder_token` that is not already in the tokenizer."
-    )
-
-# Convert the initializer_token, placeholder_token to ids
-initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-# Check if initializer_token is a single token or a sequence of tokens
-if len(initializer_token_ids) > 1:
-    raise ValueError(
-        f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+    print(f"Re-using existing token {args.placeholder_token}.")
+else:
+    print(f"Training new token {args.placeholder_token}.")

-initializer_token_ids = torch.tensor(initializer_token_ids)
 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

 # Load models and create wrapper for stable diffusion
@@ -630,6 +636,12 @@ def main():
     num_train_timesteps=args.noise_timesteps
 )

+weight_dtype = torch.float32
+if args.mixed_precision == "fp16":
+    weight_dtype = torch.float16
+elif args.mixed_precision == "bf16":
+    weight_dtype = torch.bfloat16
+
 def collate_fn(examples):
     prompts = [example["prompts"] for example in examples]
     nprompts = [example["nprompts"] for example in examples]
@@ -642,7 +654,7 @@ def main():
 pixel_values += [example["class_images"] for example in examples]

 pixel_values = torch.stack(pixel_values)
-pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -658,7 +670,7 @@ def main():
 data_file=args.train_data_file,
 batch_size=args.train_batch_size,
 tokenizer=tokenizer,
-instance_identifier=args.placeholder_token,
+instance_identifier=args.instance_identifier,
 class_identifier=args.class_identifier,
 class_subdir="cls",
 num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
 )

 # Move vae and unet to device
-vae.to(accelerator.device)
+vae.to(accelerator.device, dtype=weight_dtype)

 # Keep vae and unet in eval mode as we don't train these
 vae.eval()
@@ -785,6 +797,7 @@ def main():
 ema_unet=ema_unet,
 tokenizer=tokenizer,
 text_encoder=text_encoder,
+instance_identifier=args.instance_identifier,
 placeholder_token=args.placeholder_token,
 placeholder_token_id=placeholder_token_id,
 output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
 latents = latents * 0.18215

 # Sample noise that we'll add to the latents
-noise = torch.randn(latents.shape).to(latents.device)
+noise = torch.randn_like(latents)
 bsz = latents.shape[0]
 # Sample a random timestep for each image
 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -853,15 +866,15 @@ def main():
     noise, noise_prior = torch.chunk(noise, 2, dim=0)

     # Compute instance loss
-    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

     # Compute prior loss
-    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

     # Add the prior loss to the instance loss.
     loss = loss + args.prior_loss_weight * prior_loss
 else:
-    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

 accelerator.backward(loss)

@@ -933,7 +946,7 @@ def main():
 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
 latents = latents * 0.18215

-noise = torch.randn(latents.shape).to(latents.device)
+noise = torch.randn_like(latents)
 bsz = latents.shape[0]
 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                           (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():

 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

 loss = loss.detach().item()
 val_loss += loss
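Beyond the mixed-precision changes, dreambooth_plus.py splits the prompt identifier (--instance_identifier) from the trained token (--placeholder_token) and relaxes the token setup: a placeholder token already known to the tokenizer is re-used instead of raising, and a multi-token initializer is truncated to its first id. A minimal sketch of that token setup, assuming a transformers CLIPTokenizer; the helper name setup_placeholder_token is hypothetical:

    import torch
    from transformers import CLIPTokenizer


    def setup_placeholder_token(tokenizer: CLIPTokenizer, placeholder_token: str, initializer_token: str):
        """Add (or re-use) the placeholder token and keep one initializer embedding id."""
        initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
        print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
        initializer_token_ids = torch.tensor(initializer_token_ids[:1])  # truncate instead of raising

        num_added_tokens = tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            print(f"Re-using existing token {placeholder_token}.")
        else:
            print(f"Training new token {placeholder_token}.")

        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
        return placeholder_token_id, initializer_token_ids


    # Example usage (downloads a tokenizer; any CLIP checkpoint works):
    # tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # placeholder_token_id, initializer_token_ids = setup_placeholder_token(tokenizer, "<*>", "person")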
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2656b28..8b08a6f 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):

 # scale and decode the image latents with vae
 latents = 1 / 0.18215 * latents
-image = self.vae.decode(latents).sample
+image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample

 image = (image / 2 + 0.5).clamp(0, 1)
 image = image.cpu().permute(0, 2, 3, 1).float().numpy()
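This one-line pipeline fix casts the latents to the VAE's dtype before decoding, which matters once the VAE has been moved to fp16/bf16 while the sampling loop still works in float32. A toy illustration of the pattern with stand-in tensors (not the real VAE):

    import torch

    # Stand-ins: a "decoder" weight held in half precision, latents produced in float32.
    decoder_weight = torch.randn(4, 4, dtype=torch.bfloat16)
    latents = torch.randn(1, 4)  # float32, as the sampling loop produces

    # Match dtypes before the decode step, as the pipeline now does.
    decoded = latents.to(dtype=decoder_weight.dtype) @ decoder_weight
    image = (decoded / 2 + 0.5).clamp(0, 1)
    image = image.float().numpy()  # back to float32 for post-processing
    print(image.dtype, image.shape)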
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 6abe971..c097a8a 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 beta_end: float = 0.02,
 beta_schedule: str = "linear",
 trained_betas: Optional[np.ndarray] = None,
-tensor_format: str = "pt",
 num_inference_steps=None,
 device='cuda'
 ):
@@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

 self.device = device
-self.tensor_format = tensor_format

 self.alphas = 1.0 - self.betas
 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 # get sigmas
 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
-self.set_format(tensor_format=tensor_format)

 # A# take number of steps as input
 # A# store 1) number of steps 2) timesteps 3) schedule
diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
     help="A CSV file containing the training data."
 )
 parser.add_argument(
-    "--placeholder_token",
+    "--instance_identifier",
     type=str,
-    default="<*>",
+    default=None,
     help="A token to use as a placeholder for the concept.",
 )
 parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
     help="A token to use as a placeholder for the concept.",
 )
 parser.add_argument(
+    "--placeholder_token",
+    type=str,
+    default="<*>",
+    help="A token to use as a placeholder for the concept.",
+)
+parser.add_argument(
     "--initializer_token",
     type=str,
     default=None,
@@ -333,6 +339,7 @@ class Checkpointer:
 unet,
 tokenizer,
 text_encoder,
+instance_identifier,
 placeholder_token,
 placeholder_token_id,
 output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
 self.unet = unet
 self.tokenizer = tokenizer
 self.text_encoder = text_encoder
+self.instance_identifier = instance_identifier
 self.placeholder_token = placeholder_token
 self.placeholder_token_id = placeholder_token_id
 self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:

 for i in range(self.sample_batches):
     batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-    prompt = [prompt.format(self.placeholder_token)
+    prompt = [prompt.format(self.instance_identifier)
               for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
     nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -428,7 +436,7 @@ class Checkpointer:
     eta=eta,
     num_inference_steps=num_inference_steps,
     output_type='pil'
-)["sample"]
+).images

 all_samples += samples

@@ -480,28 +488,26 @@ def main():
 if args.seed is not None:
     set_seed(args.seed)

+args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
 # Load the tokenizer and add the placeholder token as a additional special token
 if args.tokenizer_name:
     tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
 elif args.pretrained_model_name_or_path:
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+# Convert the initializer_token, placeholder_token to ids
+initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
 # Add the placeholder token in tokenizer
 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
 if num_added_tokens == 0:
-    raise ValueError(
-        f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-        " `placeholder_token` that is not already in the tokenizer."
-    )
-
-# Convert the initializer_token, placeholder_token to ids
-initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-# Check if initializer_token is a single token or a sequence of tokens
-if len(initializer_token_ids) > 1:
-    raise ValueError(
-        f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+    print(f"Re-using existing token {args.placeholder_token}.")
+else:
+    print(f"Training new token {args.placeholder_token}.")

-initializer_token_ids = torch.tensor(initializer_token_ids)
 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

 # Load models and create wrapper for stable diffusion
@@ -602,7 +608,7 @@ def main():
 data_file=args.train_data_file,
 batch_size=args.train_batch_size,
 tokenizer=tokenizer,
-instance_identifier=args.placeholder_token,
+instance_identifier=args.instance_identifier,
 class_identifier=args.class_identifier,
 class_subdir="cls",
 num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
 unet=unet,
 tokenizer=tokenizer,
 text_encoder=text_encoder,
+instance_identifier=args.instance_identifier,
 placeholder_token=args.placeholder_token,
 placeholder_token_id=placeholder_token_id,
 output_dir=basepath,