Diffstat (limited to 'textual_inversion.py')
-rw-r--r--   textual_inversion.py   28
1 file changed, 17 insertions(+), 11 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 2109d13..61c96b7 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -155,10 +155,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=15,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -515,13 +521,13 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
             args.placeholder_token]
     else:
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         token_embeds[placeholder_token_id] = initializer_token_embeddings
 
     # Freeze vae and unet
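The detached clone taken here is the reference copy that the training loop (last hunk of this diff) writes back over every embedding row except the placeholder token, instead of zeroing gradients as before. A toy-scale sketch of that snapshot/write-back pattern, self-contained and with made-up names and sizes (the real script does this on the CLIP text encoder's token embedding table):

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)               # stand-in for the token embedding table
placeholder_token_id = 7                      # the newly added concept token
original = embedding.weight.detach().clone()  # snapshot taken before training starts

optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2)

for _ in range(3):
    loss = embedding(torch.tensor([1, 3, placeholder_token_id])).pow(2).sum()
    loss.backward()

    # Same ordering as the diff: the snapshot is written back between backward()
    # and step(), so each iteration undoes whatever drift the previous step left
    # on the rows that are meant to stay frozen.
    keep_fixed = torch.arange(embedding.num_embeddings) != placeholder_token_id
    embedding.weight.data[keep_fixed, :] = original[keep_fixed, :]

    optimizer.step()
    optimizer.zero_grad()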
@@ -662,11 +668,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
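The old call also passed args.lr_scheduler as a stray positional argument in front of optimizer=optimizer; with the usual (optimizer, num_warmup_steps, num_training_steps, num_cycles) signature that collides with the optimizer keyword, so dropping it is a fix rather than a cleanup, and num_cycles now comes from the new --lr_cycles flag instead of num_update_steps_per_epoch. A standalone sketch of the corrected construction, assuming the transformers implementation of get_cosine_with_hard_restarts_schedule_with_warmup (this script may import it from elsewhere); the toy values mirror the new defaults (warmup 300, 15 cycles) with a made-up total step count:

import torch
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=5e-4)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,     # args.lr_warmup_steps * args.gradient_accumulation_steps
    num_training_steps=6000,  # args.max_train_steps * args.gradient_accumulation_steps
    num_cycles=15,            # args.lr_cycles
)

# Stepping the schedule shows the warmup ramp followed by 15 cosine cycles.
for step in range(6000):
    optimizer.step()
    lr_scheduler.step()
    if step in (0, 299, 3000, 5999):
        print(step, lr_scheduler.get_last_lr())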
@@ -803,15 +808,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
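Two things change in this hunk: the grad-zeroing trick is replaced by writing the earlier snapshot back over all non-placeholder rows, and the per-process branching is kept only to reach the embedding weight (under multi-GPU training the text encoder is wrapped in DistributedDataParallel, so the underlying module lives at .module). As a side note, not part of this commit, accelerate's unwrap_model helper collapses that branching into one expression; an illustrative toy:

from accelerate import Accelerator
import torch.nn as nn

accelerator = Accelerator()
model = accelerator.prepare(nn.Embedding(10, 4))   # stand-in for the text encoder
# unwrap_model returns the underlying module whether or not it was wrapped
# for distributed training, so the weight is reachable without the branch.
token_embeds = accelerator.unwrap_model(model).weight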