diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 36 |
1 file changed, 25 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py index 675320b..3110c6d 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -118,6 +118,12 @@ def parse_args(): | |||
118 | help="The output directory where the model predictions and checkpoints will be written.", | 118 | help="The output directory where the model predictions and checkpoints will be written.", |
119 | ) | 119 | ) |
120 | parser.add_argument( | 120 | parser.add_argument( |
121 | "--embeddings_dir", | ||
122 | type=str, | ||
123 | default="embeddings_ti", | ||
124 | help="The embeddings directory where Textual Inversion embeddings are stored.", | ||
125 | ) | ||
126 | parser.add_argument( | ||
121 | "--seed", | 127 | "--seed", |
122 | type=int, | 128 | type=int, |
123 | default=None, | 129 | default=None, |
@@ -521,7 +527,7 @@ class Checkpointer: | |||
521 | negative_prompt=nprompt, | 527 | negative_prompt=nprompt, |
522 | height=self.sample_image_size, | 528 | height=self.sample_image_size, |
523 | width=self.sample_image_size, | 529 | width=self.sample_image_size, |
524 | latents_or_image=latents[:len(prompt)] if latents is not None else None, | 530 | image=latents[:len(prompt)] if latents is not None else None, |
525 | generator=generator if latents is not None else None, | 531 | generator=generator if latents is not None else None, |
526 | guidance_scale=guidance_scale, | 532 | guidance_scale=guidance_scale, |
527 | eta=eta, | 533 | eta=eta, |
@@ -567,6 +573,8 @@ def main(): | |||
567 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 573 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
568 | basepath.mkdir(parents=True, exist_ok=True) | 574 | basepath.mkdir(parents=True, exist_ok=True) |
569 | 575 | ||
576 | embeddings_dir = Path(args.embeddings_dir) | ||
577 | |||
570 | accelerator = Accelerator( | 578 | accelerator = Accelerator( |
571 | log_with=LoggerType.TENSORBOARD, | 579 | log_with=LoggerType.TENSORBOARD, |
572 | logging_dir=f"{basepath}", | 580 | logging_dir=f"{basepath}", |
@@ -630,15 +638,25 @@ def main(): | |||
630 | 638 | ||
631 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 639 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
632 | 640 | ||
641 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | ||
642 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
643 | |||
644 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
645 | |||
633 | print(f"Token ID mappings:") | 646 | print(f"Token ID mappings:") |
634 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
635 | print(f"- {token_id} {token}") | 648 | print(f"- {token_id} {token}") |
636 | 649 | ||
637 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 650 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") |
638 | text_encoder.resize_token_embeddings(len(tokenizer)) | 651 | if embedding_file.exists() and embedding_file.is_file(): |
652 | embedding_data = torch.load(embedding_file, map_location="cpu") | ||
653 | |||
654 | emb = next(iter(embedding_data.values())) | ||
655 | if len(emb.shape) == 1: | ||
656 | emb = emb.unsqueeze(0) | ||
657 | |||
658 | token_embeds[token_id] = emb | ||
639 | 659 | ||
640 | # Initialise the newly added placeholder token with the embeddings of the initializer token | ||
641 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
642 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 660 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
643 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 661 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
644 | 662 | ||
@@ -959,8 +977,6 @@ def main(): | |||
959 | else: | 977 | else: |
960 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 978 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
961 | 979 | ||
962 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
963 | |||
964 | if args.num_class_images != 0: | 980 | if args.num_class_images != 0: |
965 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 981 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
966 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 982 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
@@ -977,6 +993,8 @@ def main(): | |||
977 | else: | 993 | else: |
978 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 994 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
979 | 995 | ||
996 | acc = (model_pred == latents).float().mean() | ||
997 | |||
980 | accelerator.backward(loss) | 998 | accelerator.backward(loss) |
981 | 999 | ||
982 | if not args.train_text_encoder: | 1000 | if not args.train_text_encoder: |
@@ -1004,8 +1022,6 @@ def main(): | |||
1004 | ema_unet.step(unet) | 1022 | ema_unet.step(unet) |
1005 | optimizer.zero_grad(set_to_none=True) | 1023 | optimizer.zero_grad(set_to_none=True) |
1006 | 1024 | ||
1007 | acc = (model_pred == latents).float().mean() | ||
1008 | |||
1009 | avg_loss.update(loss.detach_(), bsz) | 1025 | avg_loss.update(loss.detach_(), bsz) |
1010 | avg_acc.update(acc.detach_(), bsz) | 1026 | avg_acc.update(acc.detach_(), bsz) |
1011 | 1027 | ||
@@ -1069,8 +1085,6 @@ def main(): | |||
1069 | else: | 1085 | else: |
1070 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 1086 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
1071 | 1087 | ||
1072 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
1073 | |||
1074 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 1088 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
1075 | 1089 | ||
1076 | acc = (model_pred == latents).float().mean() | 1090 | acc = (model_pred == latents).float().mean() |