From dee4c7135754543f1eb7ea616ee3847d34a85b51 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 14:39:39 +0200
Subject: Update

---
 dreambooth_plus.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

(limited to 'dreambooth_plus.py')

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index eeee424..42994af 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1300,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -141,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -153,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
         ),
     )
     parser.add_argument(
@@ -162,9 +162,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -179,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -565,6 +571,7 @@ def main():

     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     token_embeds[placeholder_token_id] = initializer_token_embeddings

@@ -717,11 +724,10 @@ def main():

     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -857,15 +863,16 @@ def main():

                 accelerator.backward(loss)

-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
--
cgit v1.2.3-54-g00ecf
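Note on the scheduler hunk: the patch switches the default to `cosine_with_restarts`, drops the stray `args.lr_scheduler` string that was being passed positionally into the optimizer slot of `get_cosine_with_hard_restarts_schedule_with_warmup` (from `transformers`), and makes the restart count a new `--lr_cycles` argument instead of `num_update_steps_per_epoch`. A minimal standalone sketch of that scheduler setup, assuming a toy model and optimizer in place of the script's real UNet/text-encoder parameters:

```python
# Sketch: construct the restart scheduler the way the patched script does.
# The model/optimizer here are stand-ins; defaults mirror the new values in the patch.
import torch
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

lr_warmup_steps = 300               # new --lr_warmup_steps default
max_train_steps = 1200              # new --max_train_steps default
gradient_accumulation_steps = 1
lr_cycles = 2                       # new --lr_cycles default

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
    num_cycles=lr_cycles,           # restart count is now user-controlled
)

for _ in range(max_train_steps):
    optimizer.step()
    lr_scheduler.step()
```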
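Note on the training-loop hunk: instead of zeroing gradients for all non-placeholder rows of the token embedding matrix, the patch snapshots the embeddings once (`original_token_embeds`) and, after each backward pass, copies the snapshot back over every row except the placeholder token, so optimizer state such as Adam momentum can no longer drift the frozen rows. A minimal sketch of that idea in isolation, with hypothetical vocabulary size, embedding width, and placeholder id (the real script takes these from the tokenizer and text encoder):

```python
# Sketch: freeze all token embeddings except one newly added placeholder token
# by restoring them from a snapshot after each optimizer step.
import torch

vocab_size, embed_dim = 49409, 768        # hypothetical sizes
placeholder_token_id = vocab_size - 1     # hypothetical: the newly added token

embedding = torch.nn.Embedding(vocab_size, embed_dim)
original_token_embeds = embedding.weight.data.detach().clone()

optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-6)

input_ids = torch.randint(0, vocab_size, (1, 8))
loss = embedding(input_ids).pow(2).mean() # stand-in loss
loss.backward()
optimizer.step()

# Restore every row except the placeholder token, mirroring what the patch
# does right after accelerator.backward(loss).
index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
```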