From dee4c7135754543f1eb7ea616ee3847d34a85b51 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 14:39:39 +0200
Subject: Update

---
 textual_inversion.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 2109d13..61c96b7 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -155,9 +155,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=15,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -515,13 +521,13 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
             args.placeholder_token]
     else:
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         token_embeds[placeholder_token_id] = initializer_token_embeddings
 
     # Freeze vae and unet
@@ -662,11 +668,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -803,15 +808,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
-- 
cgit v1.2.3-54-g00ecf
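
Note on the scheduler change: the patch drops a stray positional args.lr_scheduler argument, switches num_cycles from num_update_steps_per_epoch to the new --lr_cycles flag, and lowers the warmup default to 300 steps. The sketch below shows how the resulting call fits together, assuming the helper is imported from diffusers.optimization (transformers ships the same function); the optimizer and the training-step count here are illustrative placeholders, not values taken from the patch.

import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

# Placeholder parameter and optimizer standing in for the text encoder embeddings.
params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,     # new --lr_warmup_steps default from the patch
    num_training_steps=6000,  # placeholder for max_train_steps * gradient_accumulation_steps
    num_cycles=15,            # new --lr_cycles default from the patch
)

for _ in range(6000):
    optimizer.step()
    lr_scheduler.step()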
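
Note on the embedding-freeze change: instead of zeroing the gradients of every token embedding except the placeholder's, the patch saves a copy of the original embedding matrix (original_token_embeds) and writes those rows back each iteration, so the fixed embeddings are re-anchored to their initial values rather than relying on zeroed gradients. A toy sketch of that pattern follows, using a small nn.Embedding as a stand-in for the text encoder's input embeddings; the vocabulary size, placeholder id, and dummy loss are made up for illustration.

import torch
import torch.nn as nn

# Small stand-ins for the tokenizer vocabulary and the embedding table.
vocab_size, dim = 10, 4
placeholder_token_id = 7

embedding = nn.Embedding(vocab_size, dim)
original_token_embeds = embedding.weight.data.detach().clone()
optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2)

for _ in range(3):
    optimizer.zero_grad()
    # Dummy loss touching the placeholder token and one ordinary token.
    loss = embedding(torch.tensor([placeholder_token_id, 2])).pow(2).sum()
    loss.backward()

    # As in the patch: before the optimizer step, reset every row except the
    # placeholder's to its saved original value, discarding any drift the
    # previous step introduced in the fixed embeddings.
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

    optimizer.step()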