path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>  2022-10-16 14:39:39 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-16 14:39:39 +0200
commit     dee4c7135754543f1eb7ea616ee3847d34a85b51 (patch)
tree       4064b44bb79e499cf6a8f1ec38a83a4889f067a7 /textual_inversion.py
parent     Update (diff)
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  28
1 file changed, 17 insertions, 11 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 2109d13..61c96b7 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -155,10 +155,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=15,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
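
Note: a minimal, self-contained sketch of the two flags this hunk touches, confirming the new defaults; the bare parser and the parse_args([]) call are illustrative only and not part of the script.

import argparse

# Illustrative parser containing only the flags changed/added in this hunk
parser = argparse.ArgumentParser()
parser.add_argument(
    "--lr_warmup_steps",
    type=int,
    default=300,
    help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
    "--lr_cycles",
    type=int,
    default=15,
    help="Number of restart cycles in the lr scheduler."
)

# With no CLI overrides the new defaults apply
args = parser.parse_args([])
assert args.lr_warmup_steps == 300 and args.lr_cycles == 15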
@@ -515,13 +521,13 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
             args.placeholder_token]
     else:
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         token_embeds[placeholder_token_id] = initializer_token_embeddings
 
     # Freeze vae and unet
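
Note: a hedged sketch of what the new original_token_embeds snapshot is for: a detached copy of the whole embedding matrix is kept so that every row except the placeholder token can be restored later (see the last hunk). The toy embedding table and token ids below are made up; only the clone/initialise pattern mirrors the hunk.

import torch
import torch.nn as nn

# Toy stand-in for text_encoder.get_input_embeddings(): 10 tokens, 4 dims
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
token_embeds = embedding.weight.data

# Snapshot taken before training, as in the hunk (device transfer omitted)
original_token_embeds = token_embeds.detach().clone()

# Initialise the placeholder row from an initializer token's embedding
placeholder_token_id = 9   # hypothetical ids for the sketch
initializer_token_id = 3
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]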
@@ -662,11 +668,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
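
Note: a minimal sketch of the corrected scheduler call: the stray positional args.lr_scheduler argument is gone and num_cycles now comes from --lr_cycles instead of num_update_steps_per_epoch. The import path and the dummy optimizer are assumptions (the same helper is also exported by diffusers.optimization), and the literals stand in for the corresponding args values.

import torch
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

# Dummy parameter and optimizer purely for illustration
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,     # args.lr_warmup_steps * args.gradient_accumulation_steps
    num_training_steps=5000,  # args.max_train_steps * args.gradient_accumulation_steps
    num_cycles=15,            # args.lr_cycles
)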
@@ -803,15 +808,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
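
Note: a hedged, standalone sketch of the behaviour change in this hunk: instead of zeroing the gradients of the frozen rows, the embedding matrix is overwritten with the saved original_token_embeds for every row except the placeholder on each iteration. The toy tensors and ids are illustrative; the restore here runs after the optimizer step so the assert can show its effect, whereas the hunk performs it between backward() and step().

import torch

# Toy setup: 10 tokens, 4-dim embeddings, the last row is the placeholder
token_embeds = torch.nn.Parameter(torch.randn(10, 4))
original_token_embeds = token_embeds.detach().clone()
placeholder_token_id = 9
optimizer = torch.optim.AdamW([token_embeds], lr=1e-3)

# One fake training step that perturbs every row
loss = token_embeds.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Restore every row except the placeholder, as in the hunk
index_fixed_tokens = torch.arange(token_embeds.shape[0]) != placeholder_token_id
token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

assert torch.equal(token_embeds.data[:9], original_token_embeds[:9])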