From ca914af018632b6231fb3ee4fcd5cdbdc467c784 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 21 Oct 2022 09:50:46 +0200 Subject: Add optional TI functionality to Dreambooth --- dreambooth.py | 101 ++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 66 insertions(+), 35 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index da8399f..72c56cd 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -170,14 +170,14 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, - default=300, + default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler." + help="Number of restart cycles in the lr scheduler (if supported)." ) parser.add_argument( "--use_ema", @@ -506,11 +506,10 @@ def main(): logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) - save_args(basepath, args) + args.seed = args.seed or (torch.random.seed() >> 32) + set_seed(args.seed) - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) + save_args(basepath, args) # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: @@ -523,13 +522,22 @@ def main(): vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') - ema_unet = EMAModel( - unet, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - device=accelerator.device - ) if args.use_ema else None + ema_unet = None + if args.use_ema: + ema_unet = EMAModel( + unet, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + device=accelerator.device + ) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # Freeze text_encoder and vae + freeze_params(vae.parameters()) if args.initializer_token is not None: # Convert the initializer_token, placeholder_token to ids @@ -545,22 +553,22 @@ def main(): print(f"Training new token {args.placeholder_token}.") placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings() - initializer_token_embeddings = token_embeds(initializer_token_ids) - token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings - - prompt_processor = PromptProcessor(tokenizer, text_encoder) - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + original_token_embeds = token_embeds.detach().clone().to(accelerator.device) + initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) + token_embeds[placeholder_token_id] = initializer_token_embeddings - # slice_size = unet.config.attention_head_dim // 2 - # unet.set_attention_slice(slice_size) + freeze_params(itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + )) - # Freeze text_encoder and vae - freeze_params(vae.parameters()) + prompt_processor = PromptProcessor(tokenizer, text_encoder) if args.scale_lr: args.learning_rate_unet = ( @@ -583,6 +591,11 @@ def main(): else: optimizer_class = torch.optim.AdamW + if args.initializer_token is not None: + text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() + else: + text_encoder_params_to_optimize = text_encoder.parameters() + # Initialize the optimizer optimizer = optimizer_class( [ @@ -591,7 +604,7 @@ def main(): 'lr': args.learning_rate_unet, }, { - 'params': text_encoder.parameters(), + 'params': text_encoder_params_to_optimize, 'lr': args.learning_rate_text, } ], @@ -849,9 +862,27 @@ def main(): loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") accelerator.backward(loss) + + if args.initializer_token is not None: + # Keep the token embeddings fixed except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + if accelerator.num_processes > 1: + token_embeds = text_encoder.module.get_input_embeddings().weight + else: + token_embeds = text_encoder.get_input_embeddings().weight + + # Get the index for tokens that we want to freeze + index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id + token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] + if accelerator.sync_gradients: - accelerator.clip_grad_norm_(itertools.chain( - unet.parameters(), text_encoder.parameters()), args.max_grad_norm) + params_to_clip = ( + unet.parameters() + if args.initializer_token is not None + else itertools.chain(unet.parameters(), text_encoder.parameters()) + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() @@ -896,8 +927,8 @@ def main(): text_encoder.eval() val_loss = 0.0 - for step, batch in enumerate(val_dataloader): - with torch.no_grad(): + with torch.inference_mode(): + for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -920,12 +951,12 @@ def main(): loss = loss.detach().item() val_loss += loss - if accelerator.sync_gradients: - local_progress_bar.update(1) - global_progress_bar.update(1) + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = {"val/loss": loss} - local_progress_bar.set_postfix(**logs) + logs = {"val/loss": loss} + local_progress_bar.set_postfix(**logs) val_loss /= len(val_dataloader) -- cgit v1.2.3-54-g00ecf