diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 36 |
1 file changed, 25 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py index 675320b..3110c6d 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -118,6 +118,12 @@ def parse_args(): | |||
| 118 | help="The output directory where the model predictions and checkpoints will be written.", | 118 | help="The output directory where the model predictions and checkpoints will be written.", |
| 119 | ) | 119 | ) |
| 120 | parser.add_argument( | 120 | parser.add_argument( |
| 121 | "--embeddings_dir", | ||
| 122 | type=str, | ||
| 123 | default="embeddings_ti", | ||
| 124 | help="The embeddings directory where Textual Inversion embeddings are stored.", | ||
| 125 | ) | ||
| 126 | parser.add_argument( | ||
| 121 | "--seed", | 127 | "--seed", |
| 122 | type=int, | 128 | type=int, |
| 123 | default=None, | 129 | default=None, |
| @@ -521,7 +527,7 @@ class Checkpointer: | |||
| 521 | negative_prompt=nprompt, | 527 | negative_prompt=nprompt, |
| 522 | height=self.sample_image_size, | 528 | height=self.sample_image_size, |
| 523 | width=self.sample_image_size, | 529 | width=self.sample_image_size, |
| 524 | latents_or_image=latents[:len(prompt)] if latents is not None else None, | 530 | image=latents[:len(prompt)] if latents is not None else None, |
| 525 | generator=generator if latents is not None else None, | 531 | generator=generator if latents is not None else None, |
| 526 | guidance_scale=guidance_scale, | 532 | guidance_scale=guidance_scale, |
| 527 | eta=eta, | 533 | eta=eta, |
| @@ -567,6 +573,8 @@ def main(): | |||
| 567 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 573 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
| 568 | basepath.mkdir(parents=True, exist_ok=True) | 574 | basepath.mkdir(parents=True, exist_ok=True) |
| 569 | 575 | ||
| 576 | embeddings_dir = Path(args.embeddings_dir) | ||
| 577 | |||
| 570 | accelerator = Accelerator( | 578 | accelerator = Accelerator( |
| 571 | log_with=LoggerType.TENSORBOARD, | 579 | log_with=LoggerType.TENSORBOARD, |
| 572 | logging_dir=f"{basepath}", | 580 | logging_dir=f"{basepath}", |
| @@ -630,15 +638,25 @@ def main(): | |||
| 630 | 638 | ||
| 631 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 639 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
| 632 | 640 | ||
| 641 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | ||
| 642 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
| 643 | |||
| 644 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
| 645 | |||
| 633 | print(f"Token ID mappings:") | 646 | print(f"Token ID mappings:") |
| 634 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
| 635 | print(f"- {token_id} {token}") | 648 | print(f"- {token_id} {token}") |
| 636 | 649 | ||
| 637 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 650 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") |
| 638 | text_encoder.resize_token_embeddings(len(tokenizer)) | 651 | if embedding_file.exists() and embedding_file.is_file(): |
| 652 | embedding_data = torch.load(embedding_file, map_location="cpu") | ||
| 653 | |||
| 654 | emb = next(iter(embedding_data.values())) | ||
| 655 | if len(emb.shape) == 1: | ||
| 656 | emb = emb.unsqueeze(0) | ||
| 657 | |||
| 658 | token_embeds[token_id] = emb | ||
| 639 | 659 | ||
| 640 | # Initialise the newly added placeholder token with the embeddings of the initializer token | ||
| 641 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
| 642 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 660 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
| 643 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 661 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
| 644 | 662 | ||
| @@ -959,8 +977,6 @@ def main(): | |||
| 959 | else: | 977 | else: |
| 960 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 978 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 961 | 979 | ||
| 962 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
| 963 | |||
| 964 | if args.num_class_images != 0: | 980 | if args.num_class_images != 0: |
| 965 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 981 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 966 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 982 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| @@ -977,6 +993,8 @@ def main(): | |||
| 977 | else: | 993 | else: |
| 978 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 994 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 979 | 995 | ||
| 996 | acc = (model_pred == latents).float().mean() | ||
| 997 | |||
| 980 | accelerator.backward(loss) | 998 | accelerator.backward(loss) |
| 981 | 999 | ||
| 982 | if not args.train_text_encoder: | 1000 | if not args.train_text_encoder: |
| @@ -1004,8 +1022,6 @@ def main(): | |||
| 1004 | ema_unet.step(unet) | 1022 | ema_unet.step(unet) |
| 1005 | optimizer.zero_grad(set_to_none=True) | 1023 | optimizer.zero_grad(set_to_none=True) |
| 1006 | 1024 | ||
| 1007 | acc = (model_pred == latents).float().mean() | ||
| 1008 | |||
| 1009 | avg_loss.update(loss.detach_(), bsz) | 1025 | avg_loss.update(loss.detach_(), bsz) |
| 1010 | avg_acc.update(acc.detach_(), bsz) | 1026 | avg_acc.update(acc.detach_(), bsz) |
| 1011 | 1027 | ||
| @@ -1069,8 +1085,6 @@ def main(): | |||
| 1069 | else: | 1085 | else: |
| 1070 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 1086 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 1071 | 1087 | ||
| 1072 | del timesteps, noise, latents, noisy_latents, encoder_hidden_states | ||
| 1073 | |||
| 1074 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | 1088 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 1075 | 1089 | ||
| 1076 | acc = (model_pred == latents).float().mean() | 1090 | acc = (model_pred == latents).float().mean() |
