From 25ba3b38e2c605bf90f838f99b3b30d489b48222 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 19:00:08 +0200
Subject: Update

---
 dreambooth_plus.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

(limited to 'dreambooth_plus.py')

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 42994af..73225de 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -135,7 +135,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -168,7 +168,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -721,13 +721,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
-- 
cgit v1.2.3-54-g00ecf
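
Note on the scheduler change: the minimal sketch below (not part of the commit) illustrates
what the new --lr_cycles fallback computes when the flag is left at its None default: roughly
one hard restart per two post-warmup epochs. The step counts are made-up example values, and
the scheduler helper is assumed to be get_cosine_with_hard_restarts_schedule_with_warmup as
provided by both diffusers.optimization and transformers; the hunk does not show which one
the script actually imports.

    import math

    import torch
    from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

    # Hypothetical values for illustration only; the real script derives these
    # from the dataset size, batch size, and gradient accumulation settings.
    max_train_steps = 1000
    lr_warmup_steps = 100
    num_update_steps_per_epoch = 50
    gradient_accumulation_steps = 1
    lr_cycles = None  # the new default: derive the cycle count below

    # The commit's fallback: roughly one restart cycle per two post-warmup epochs.
    # (1000 - 100) / 50 = 18 epochs after warmup -> ceil(18 / 2) = 9 cycles.
    num_cycles = lr_cycles or math.ceil(
        ((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch) / 2)

    # Dummy model/optimizer so the scheduler call is self-contained.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
        num_cycles=num_cycles,
    )

Because the fallback sits behind `or`, an explicitly passed --lr_cycles value still takes
precedence; only the None default triggers the epoch-based heuristic.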