Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 36 +++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 675320b..3110c6d 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -118,6 +118,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default="embeddings_ti",
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
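
Note: the hunk above adds a new --embeddings_dir option next to the existing output/seed arguments. A minimal sketch of the option in isolation, runnable on its own (the value passed to parse_args below is hypothetical):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default="embeddings_ti",
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )

    # With no flag given, args.embeddings_dir falls back to "embeddings_ti"
    args = parser.parse_args(["--embeddings_dir", "my_embeddings"])
    print(args.embeddings_dir)  # my_embeddings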
@@ -521,7 +527,7 @@ class Checkpointer:
                 negative_prompt=nprompt,
                 height=self.sample_image_size,
                 width=self.sample_image_size,
-                latents_or_image=latents[:len(prompt)] if latents is not None else None,
+                image=latents[:len(prompt)] if latents is not None else None,
                 generator=generator if latents is not None else None,
                 guidance_scale=guidance_scale,
                 eta=eta,
@@ -567,6 +573,8 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
+    embeddings_dir = Path(args.embeddings_dir)
+
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -630,15 +638,25 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
         print(f"- {token_id} {token}")
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+        embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        if embedding_file.exists() and embedding_file.is_file():
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
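
Note: the hunk above is the core of the change. The token-embedding matrix is resized before the loop, and for each placeholder token the loop looks for a Textual Inversion file named <token>.bin in embeddings_dir; if one exists, its vector overwrites the freshly added row. A self-contained sketch of the same technique, assuming a CLIP text encoder from transformers and a .bin file that maps a token name to its embedding tensor (the model id and token below are placeholders):

    import torch
    from pathlib import Path
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    token = "<my-token>"  # hypothetical placeholder token
    tokenizer.add_tokens([token])
    token_id = tokenizer.convert_tokens_to_ids(token)

    # The new row only exists once the embedding matrix has been resized
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data

    embedding_file = Path("embeddings_ti") / f"{token}.bin"
    if embedding_file.is_file():
        embedding_data = torch.load(embedding_file, map_location="cpu")
        emb = next(iter(embedding_data.values()))  # file stores {token: tensor}
        if emb.dim() == 1:
            emb = emb.unsqueeze(0)  # normalise [dim] -> [1, dim]
        token_embeds[token_id] = emb

Resizing ahead of the loop is what makes the in-loop copy safe, which is why the diff moves resize_token_embeddings up from its old position after the print loop.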
@@ -959,8 +977,6 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                 if args.num_class_images != 0:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
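
Note: the del removed here freed latents, which the accuracy metric later in the step still reads, so dropping it keeps that tensor alive. The surrounding context also shows the prior-preservation split: when num_class_images != 0, each batch stacks instance and class samples, and the prediction is chunked back into halves before the two loss terms are computed. A toy illustration with arbitrary shapes:

    import torch

    # Combined batch: first half instance samples, second half class samples
    model_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)

    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    print(model_pred.shape)  # torch.Size([2, 4, 64, 64])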
@@ -977,6 +993,8 @@ def main():
                 else:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+                acc = (model_pred == latents).float().mean()
+
                 accelerator.backward(loss)
 
                 if not args.train_text_encoder:
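
Note: the acc added above measures the fraction of elements where the raw prediction and the latents agree exactly. A minimal sketch with random stand-in tensors:

    import torch

    model_pred = torch.randn(2, 4, 64, 64)
    latents = torch.randn_like(model_pred)

    # Fraction of elements matching bit-for-bit; near zero for continuous outputs
    acc = (model_pred == latents).float().mean()

Exact float equality is a very strict signal; a tolerance-based variant such as torch.isclose(model_pred, latents).float().mean() would report graded proximity instead.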
@@ -1004,8 +1022,6 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float().mean()
-
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1069,8 +1085,6 @@ def main():
                     else:
                         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                    del timesteps, noise, latents, noisy_latents, encoder_hidden_states
-
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     acc = (model_pred == latents).float().mean()