From dee4c7135754543f1eb7ea616ee3847d34a85b51 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 14:39:39 +0200
Subject: Update

---
 dreambooth.py | 41 ++++++++++++++++++++++++++++++++---------
 1 file changed, 32 insertions(+), 9 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 1ba8dc0..9e2645b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -150,9 +150,15 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -167,7 +173,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -296,6 +302,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -455,6 +468,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -614,12 +629,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
-- 
cgit v1.2.3-54-g00ecf