From 35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Oct 2022 17:55:08 +0200
Subject: Added EMA support to Textual Inversion

---
 dreambooth.py        | 23 +++++++++++-----------
 textual_inversion.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
     parser.add_argument(
         "--ema_inv_gamma",
         type=float,
-        default=0.1
+        default=1.0
     )
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1
+        default=1.0
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -783,7 +783,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True

-                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                if args.use_ema:
+                    logs["ema_decay"] = ema_unet.decay
+
+                accelerator.log(logs, step=global_step)
+
                 local_progress_bar.set_postfix(**logs)

                 if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@ def main():
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

-                    logs = {"mode": "validation", "loss": loss}
+                    logs = {"val/loss": loss}
                     local_progress_bar.set_postfix(**logs)

             val_loss /= len(val_dataloader)

-            accelerator.log({
-                "train/loss": train_loss,
-                "val/loss": val_loss,
-                "lr": lr_scheduler.get_last_lr()[0]
-            }, step=global_step)
+            accelerator.log({"val/loss": val_loss}, step=global_step)

             local_progress_bar.clear()
             global_progress_bar.clear()
diff --git a/textual_inversion.py b/textual_inversion.py
index b01bdbc..e6d856a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,6 +17,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -149,9 +150,30 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -326,6 +348,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -340,6 +363,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -356,7 +380,8 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -375,7 +400,8 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -681,6 +707,13 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

+    ema_text_encoder = EMAModel(
+        text_encoder,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -724,6 +757,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -825,6 +859,9 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_text_encoder.step(text_encoder)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

@@ -843,7 +880,12 @@ def main():
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                     })

-                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                if args.use_ema:
+                    logs["ema_decay"] = ema_text_encoder.decay
+
+                accelerator.log(logs, step=global_step)
+
                 local_progress_bar.set_postfix(**logs)

                 if global_step >= args.max_train_steps:
@@ -884,12 +926,12 @@ def main():
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

-                    logs = {"mode": "validation", "loss": loss}
+                    logs = {"val/loss": loss}
                     local_progress_bar.set_postfix(**logs)

             val_loss /= len(val_dataloader)

-            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+            accelerator.log({"val/loss": val_loss}, step=global_step)

             local_progress_bar.clear()
             global_progress_bar.clear()
-- 
cgit v1.2.3-54-g00ecf
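Note: the new --ema_inv_gamma, --ema_power and --ema_max_decay flags feed the decay
warmup of diffusers' EMAModel. The sketch below is illustrative only (not part of the
patch) and assumes the decay schedule used by diffusers.training_utils.EMAModel around
the time of this commit:

    # EMA keeps a shadow copy of the weights, updated each optimizer step as
    #   ema = decay * ema + (1 - decay) * current,
    # where decay itself warms up from 0 toward max_decay over training.
    def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 1.0,
                  max_decay: float = 0.9999) -> float:
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(0.0, min(value, max_decay))

    # With the defaults set in this patch (inv_gamma=1.0, power=1.0), decay is roughly
    # 0.5 after the first step, about 0.999 after 1000 steps, and is capped at 0.9999.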