From b4a00845721fbc95819ad888dfd7c24013bbf4d0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 19 Oct 2022 12:19:23 +0200
Subject: Updated Dreambooth training

---
 dreambooth.py | 104 +++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 71 insertions(+), 33 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 9786e0f..d1cf535 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import os
 import datetime
@@ -113,7 +114,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=3600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,9 +129,15 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=3e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=3e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -358,12 +365,14 @@ class Checkpointer:
     def save_model(self):
         print("Saving model...")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -371,7 +380,8 @@ class Checkpointer:
         )
         pipeline.save_pretrained(self.output_dir.joinpath("model"))

-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del pipeline

         if torch.cuda.is_available():
@@ -381,16 +391,18 @@ def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
+        unwrapped_unet = self.accelerator.unwrap_model(
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
+        unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

         pipeline = VlpnStableDiffusion(
-            text_encoder=self.text_encoder,
+            text_encoder=unwrapped_text_encoder,
             vae=self.vae,
-            unet=unwrapped,
+            unet=unwrapped_unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
         ).to(self.accelerator.device)
@@ -416,9 +428,16 @@
                 for i in range(self.sample_batches):
                     batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [prompt.format(self.instance_identifier)
-                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                    prompt = [
+                        prompt.format(self.instance_identifier)
+                        for batch in batches
+                        for prompt in batch["prompts"]
+                    ][:self.sample_batch_size]
+                    nprompt = [
+                        prompt
+                        for batch in batches
+                        for prompt in batch["nprompts"]
+                    ][:self.sample_batch_size]

                     samples = pipeline(
                         prompt=prompt,
@@ -443,7 +462,8 @@ class Checkpointer:
         del all_samples
         del image_grid
-        del unwrapped
+        del unwrapped_unet
+        del unwrapped_text_encoder
         del scheduler
         del pipeline
         del generator
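The Checkpointer hunks above now unwrap the trained text encoder as well as the UNet before building the save/sample pipeline. A minimal sketch of that unwrap-then-save pattern, assuming only torch and accelerate are installed; the Linear modules and the checkpoint.pt path are hypothetical stand-ins for the real UNet2DConditionModel, CLIPTextModel, and pipeline.save_pretrained() call:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the text encoder

unet, text_encoder = accelerator.prepare(unet, text_encoder)

# unwrap_model() strips the wrappers added by prepare(), leaving plain modules
# that can be handed to a pipeline or serialized directly.
unwrapped_unet = accelerator.unwrap_model(unet)
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)

torch.save(
    {
        "unet": unwrapped_unet.state_dict(),
        "text_encoder": unwrapped_text_encoder.state_dict(),
    },
    "checkpoint.pt",
)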
@@ -482,8 +502,7 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
@@ -499,17 +518,21 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()

     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)

     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
-    freeze_params(text_encoder.parameters())

     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
@@ -526,8 +549,16 @@ def main():
     # Initialize the optimizer
     optimizer = optimizer_class(
-        unet.parameters(),  # only optimize unet
-        lr=args.learning_rate,
+        [
+            {
+                'params': unet.parameters(),
+                'lr': args.learning_rate_unet,
+            },
+            {
+                'params': text_encoder.parameters(),
+                'lr': args.learning_rate_text,
+            }
+        ],
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
@@ -592,8 +623,10 @@ def main():
     missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

     if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]

         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -610,9 +643,9 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                nprompt = [item.nprompt for item in batch]

                 images = pipeline(
                     prompt=prompt,
@@ -655,16 +688,14 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )

-    unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

     # Move text_encoder and vae to device
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

     # Keep text_encoder and vae in eval mode as we don't train these
-    text_encoder.eval()
     vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
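The optimizer hunk above replaces the single --learning_rate with one parameter group per model, so the UNet and the text encoder can be trained at different rates from the same optimizer. A small, self-contained sketch of that pattern using plain PyTorch; the Linear modules and the hard-coded 3e-6 values are hypothetical stand-ins for the real models and the --learning_rate_unet / --learning_rate_text arguments:

import torch

unet = torch.nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the text encoder

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 3e-6},
        {"params": text_encoder.parameters(), "lr": 3e-6},
    ],
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

# Each group keeps its own learning rate; schedulers later report one value
# per group, which is what the lr/unet and lr/text log entries rely on.
print([group["lr"] for group in optimizer.param_groups])

Keeping both groups in a single optimizer (rather than using two optimizers) also means one scheduler and one step() call cover both models.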
@@ -736,12 +767,13 @@ def main():
         local_progress_bar.reset()

         unet.train()
+        text_encoder.train()
         train_loss = 0.0

         sample_checkpoint = False

         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -782,7 +814,8 @@ def main():
                 accelerator.backward(loss)

                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+                    accelerator.clip_grad_norm_(itertools.chain(
+                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
@@ -804,7 +837,11 @@ def main():
                     if global_step % args.sample_frequency == 0:
                         sample_checkpoint = True

-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": loss,
+                    "lr/unet": lr_scheduler.get_last_lr()[0],
+                    "lr/text": lr_scheduler.get_last_lr()[1]
+                }
                 if args.use_ema:
                     logs["ema_decay"] = ema_unet.decay
@@ -820,6 +857,7 @@ def main():
         accelerator.wait_for_everyone()

         unet.eval()
+        text_encoder.eval()
         val_loss = 0.0

         for step, batch in enumerate(val_dataloader):
--
cgit v1.2.3-54-g00ecf
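Taken together, the training-loop hunks make the text encoder a trained model: it goes through accelerator.prepare(), is switched between train() and eval() alongside the UNet, has its gradients clipped together with the UNet's via itertools.chain, and its learning rate is logged from the second entry of lr_scheduler.get_last_lr(). A self-contained sketch of that joint update step, under the same assumptions as the earlier sketches (tiny Linear stand-ins for the real models, single-process accelerate, and a constant LambdaLR in place of the real warmup scheduler):

import itertools

import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the text encoder

# One optimizer, one parameter group (and thus one learning rate) per model.
optimizer = torch.optim.AdamW([
    {"params": unet.parameters(), "lr": 3e-6},
    {"params": text_encoder.parameters(), "lr": 3e-6},
])
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

unet, text_encoder, optimizer, lr_scheduler = accelerator.prepare(
    unet, text_encoder, optimizer, lr_scheduler
)

unet.train()
text_encoder.train()

# A dummy batch and loss stand in for the latent-diffusion objective.
x = torch.randn(2, 4, device=accelerator.device)
loss = (unet(x) + text_encoder(x)).pow(2).mean()
accelerator.backward(loss)

if accelerator.sync_gradients:
    # Clip over the chained parameter iterators so both models share one norm.
    accelerator.clip_grad_norm_(
        itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
    )

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

# get_last_lr() returns one value per parameter group, matching the
# "lr/unet" and "lr/text" keys used in the patch.
logs = {
    "train/loss": loss.item(),
    "lr/unet": lr_scheduler.get_last_lr()[0],
    "lr/text": lr_scheduler.get_last_lr()[1],
}
print(logs)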