From 35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Oct 2022 17:55:08 +0200
Subject: Added EMA support to Textual Inversion

---
 dreambooth.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
     parser.add_argument(
         "--ema_inv_gamma",
         type=float,
-        default=0.1
+        default=1.0
     )
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1
+        default=1.0
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -783,7 +783,12 @@ def main():
             if global_step % args.sample_frequency == 0:
                 sample_checkpoint = True

-            logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+            if args.use_ema:
+                logs["ema_decay"] = ema_unet.decay
+
+            accelerator.log(logs, step=global_step)
+
             local_progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@ def main():
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

-                logs = {"mode": "validation", "loss": loss}
+                logs = {"val/loss": loss}
                 local_progress_bar.set_postfix(**logs)

         val_loss /= len(val_dataloader)

-        accelerator.log({
-            "train/loss": train_loss,
-            "val/loss": val_loss,
-            "lr": lr_scheduler.get_last_lr()[0]
-        }, step=global_step)
+        accelerator.log({"val/loss": val_loss}, step=global_step)

         local_progress_bar.clear()
         global_progress_bar.clear()
-- 
cgit v1.2.3-54-g00ecf
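
A note on the retuned EMA defaults: --ema_inv_gamma, --ema_power and --ema_max_decay drive the warmup schedule behind the ema_unet.decay value the patch now logs each step. Below is a minimal sketch of that schedule, assuming the 1 - (1 + step / inv_gamma) ** -power warmup formula used by diffusers-style EMAModel implementations of this period; the helper name ema_decay is hypothetical, not part of this repo:

    def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 1.0,
                  max_decay: float = 0.9999) -> float:
        # Warmup schedule for the EMA decay factor (illustrative sketch, not
        # the repo's code): starts near 0 so noisy early weights are quickly
        # forgotten, then ramps toward max_decay as training stabilizes.
        if step <= 0:
            return 0.0
        value = 1.0 - (1.0 + step / inv_gamma) ** -power
        return max(0.0, min(value, max_decay))

With the patch's defaults (inv_gamma=1.0, power=1.0) this reduces to step / (step + 1): 0.9 after 9 steps, 0.99 after 99, eventually capped at --ema_max_decay. The old inv_gamma=0.1 reached the same decay values roughly ten times sooner, so the new defaults average over a longer window early in training.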