From fcbc11be99c011ab1003451ef72c95ca587902d8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Oct 2022 18:42:27 +0200
Subject: Update

---
 dreambooth_plus.py | 34 ++++++++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index b5ec2fc..eeee424 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -16,7 +16,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2300,
+        default=1300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -317,6 +317,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -503,6 +510,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -706,12 +715,21 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=num_update_steps_per_epoch,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
--
cgit v1.2.3-54-g00ecf
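
For reference, a minimal sketch of the scheduler selection introduced above, written against the diffusers optimization helpers. Note that get_cosine_with_hard_restarts_schedule_with_warmup takes the optimizer as its first positional argument and has no scheduler-name parameter (only the generic get_scheduler factory does), so the extra args.lr_scheduler argument in the patched call would raise a TypeError at runtime. The build_lr_scheduler helper and its parameter names are illustrative only; args mirrors the flags defined in parse_args.

    # Sketch: pick between the hard-restarts cosine schedule and the generic
    # diffusers scheduler factory, using the same warmup/step arithmetic as the patch.
    from diffusers.optimization import (
        get_scheduler,
        get_cosine_with_hard_restarts_schedule_with_warmup,
    )


    def build_lr_scheduler(args, optimizer, num_update_steps_per_epoch):
        warmup_steps = args.lr_warmup_steps * args.gradient_accumulation_steps
        total_steps = args.max_train_steps * args.gradient_accumulation_steps

        if args.lr_scheduler == "cosine_with_restarts":
            # Hard restarts: one cosine cycle per epoch, matching num_cycles in the patch.
            return get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
                num_cycles=num_update_steps_per_epoch,
            )

        # All other scheduler names go through the generic factory, which does
        # accept the scheduler name as its first argument.
        return get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )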