From 25ba3b38e2c605bf90f838f99b3b30d489b48222 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Oct 2022 19:00:08 +0200
Subject: Update

---
 dreambooth.py        | 10 ++++++----
 dreambooth_plus.py   |  8 +++++---
 textual_inversion.py |  6 ++++--
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 9e2645b..42d3980 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2000,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -156,7 +156,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -628,13 +628,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 42994af..73225de 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -135,7 +135,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -168,7 +168,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -721,13 +721,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/textual_inversion.py b/textual_inversion.py
index 61c96b7..0d5a742 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=15,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -665,13 +665,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
        )
     else:
         lr_scheduler = get_scheduler(
--
cgit v1.2.3-70-g09d2
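Note (not part of the patch): when --lr_cycles is left unset, the new num_cycles fallback schedules roughly one hard restart of the cosine scheduler per two post-warmup epochs. A minimal standalone sketch of that heuristic, using the variable names from the training scripts and made-up step counts purely for illustration:

import math

def default_lr_cycles(max_train_steps: int, lr_warmup_steps: int,
                      num_update_steps_per_epoch: int) -> int:
    # Same arithmetic as the patch: count the post-warmup epochs,
    # then plan about one restart cycle for every two of them.
    post_warmup_epochs = (max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch
    return math.ceil(post_warmup_epochs / 2)

# Hypothetical numbers: with the new dreambooth.py default of
# max_train_steps=1200, 200 warmup steps, and 100 update steps per
# epoch, (1200 - 200) / 100 = 10 post-warmup epochs -> ceil(10 / 2) = 5.
print(default_lr_cycles(1200, 200, 100))  # -> 5

One consequence of the `args.lr_cycles or ...` form: since it tests truthiness rather than `is None`, an explicit --lr_cycles 0 would also fall through to the computed default.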