From 25ba3b38e2c605bf90f838f99b3b30d489b48222 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Oct 2022 19:00:08 +0200 Subject: Update --- textual_inversion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 61c96b7..0d5a742 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -161,7 +161,7 @@ def parse_args(): parser.add_argument( "--lr_cycles", type=int, - default=15, + default=None, help="Number of restart cycles in the lr scheduler." ) parser.add_argument( @@ -665,13 +665,15 @@ def main(): if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if args.lr_scheduler == "cosine_with_restarts": lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_cycles, + num_cycles=args.lr_cycles or math.ceil( + ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), ) else: lr_scheduler = get_scheduler( -- cgit v1.2.3-54-g00ecf