From cd80af823d31148f9c0fa4d8045b773adfe1e6c3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Oct 2022 12:46:57 +0200
Subject: Dreambooth: Add EMA support

---
 dreambooth.py | 50 ++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 44 insertions(+), 6 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 9f1b7af..f7d31d2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -16,6 +16,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -111,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=600,
+        default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-7,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -140,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="linear",
+        default="cosine",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -152,9 +153,31 @@
         default=200,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=0.1
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
+        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
@@ -172,7 +195,7 @@
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -298,6 +321,7 @@ class Checkpointer:
         accelerator,
         vae,
         unet,
+        ema_unet,
         tokenizer,
         text_encoder,
         output_dir: Path,
@@ -311,6 +335,7 @@ class Checkpointer:
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
+        self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.output_dir = output_dir
@@ -324,7 +349,8 @@ class Checkpointer:
     def checkpoint(self):
         print("Saving model...")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
@@ -346,7 +372,8 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -602,6 +629,13 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
@@ -643,6 +677,7 @@ def main():
         accelerator=accelerator,
         vae=vae,
         unet=unet,
+        ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
@@ -737,6 +772,9 @@ def main():
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
--
cgit v1.2.3-70-g09d2
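
Note (not part of the commit): the sketch below shows how the EMA pieces introduced by this patch fit together, using only the EMAModel interface the diff itself relies on (constructor with inv_gamma/power/max_value, step(), averaged_model). The model id, hyperparameter values, and training-loop stub are illustrative assumptions, not code from the repository.

# Illustrative sketch only -- assumes the diffusers.training_utils.EMAModel API
# used in the diff above; values mirror the new --ema_* defaults.
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Hypothetical model source; the real script obtains its UNet elsewhere.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# Shadow copy that tracks an exponential moving average of the UNet weights,
# mirroring --ema_inv_gamma, --ema_power and --ema_max_decay.
ema_unet = EMAModel(unet, inv_gamma=0.1, power=1, max_value=0.9999)

for step in range(1000):  # placeholder for the real training loop
    ...  # forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad()
    # After each real optimizer step (accelerator.sync_gradients in the patch),
    # fold the current weights into the running average.
    ema_unet.step(unet)

# At checkpoint/sampling time the smoothed weights are used instead of the live
# ones, as Checkpointer.checkpoint() and save_samples() do in the diff.
model_to_save = ema_unet.averaged_model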