From dee4c7135754543f1eb7ea616ee3847d34a85b51 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 14:39:39 +0200
Subject: Update

---
 dreambooth.py        | 41 ++++++++++++++++++++++++++++++++---------
 dreambooth_plus.py   | 33 ++++++++++++++++++++-------------
 textual_inversion.py | 28 +++++++++++++++++-----------
 3 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 1ba8dc0..9e2645b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -150,9 +150,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -167,7 +173,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -296,6 +302,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -455,6 +468,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -614,12 +629,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index eeee424..42994af 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1300,
+        default=1200,
         help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -141,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -153,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -162,9 +162,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -179,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -565,6 +571,7 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     token_embeds[placeholder_token_id] = initializer_token_embeddings
 
@@ -717,11 +724,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -857,15 +863,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
diff --git a/textual_inversion.py b/textual_inversion.py
index 2109d13..61c96b7 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -155,9 +155,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=15,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -515,13 +521,13 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
             args.placeholder_token]
     else:
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         token_embeds[placeholder_token_id] = initializer_token_embeddings
 
     # Freeze vae and unet
@@ -662,11 +668,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -803,15 +808,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
--
cgit v1.2.3-70-g09d2
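
Note (not part of the commit): a minimal, self-contained sketch of how the new --lr_cycles option feeds diffusers' hard-restart cosine schedule once --lr_scheduler=cosine_with_restarts is selected. The dummy parameter, learning rate, and print cadence below are illustrative assumptions; only the scheduler call and the defaults mirror the patched scripts.

# Illustrative sketch, not from the patch: the stand-in parameter and the
# print interval are assumptions; the scheduler call mirrors the scripts.
import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for unet/text-encoder params
optimizer = torch.optim.AdamW(params, lr=5e-6)

gradient_accumulation_steps = 1
lr_warmup_steps = 300    # new default in this commit
lr_cycles = 2            # new --lr_cycles default in the dreambooth scripts
max_train_steps = 1200

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
    num_cycles=lr_cycles,
)

for step in range(max_train_steps):
    optimizer.step()        # a real run would compute a loss and call backward() first
    lr_scheduler.step()
    if step % 300 == 0:
        print(step, lr_scheduler.get_last_lr())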
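
Similarly, a standalone sketch (again not part of the commit) of the embedding-handling change: instead of zeroing the gradients of every token except the placeholder, the scripts now snapshot the original embedding matrix and copy the fixed rows back each iteration so only the placeholder row drifts. The toy vocabulary size, embedding width, dummy loss, and the exact position of the restore relative to the optimizer step are assumptions here.

# Illustrative sketch, not from the patch: toy sizes and a dummy loss;
# only the snapshot/restore idea mirrors the training scripts.
import torch

vocab_size, embed_dim = 1000, 32            # assumed toy sizes
placeholder_token_id = vocab_size - 1       # assume the new token sits at the end

embedding = torch.nn.Embedding(vocab_size, embed_dim)
original_token_embeds = embedding.weight.data.detach().clone()
optimizer = torch.optim.AdamW(embedding.parameters(), lr=5e-4)

for _ in range(3):
    loss = embedding(torch.arange(vocab_size)).pow(2).mean()   # touches every row
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Restore every row except the placeholder, as the patched scripts do
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

# Only the placeholder row has moved away from its original value.
assert torch.equal(embedding.weight.data[index_fixed_tokens], original_token_embeds[index_fixed_tokens])
assert not torch.equal(embedding.weight.data[placeholder_token_id], original_token_embeds[placeholder_token_id])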