From 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 12:18:07 +0100
Subject: Various updates

---
 train_dreambooth.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5e6e35d..2e0696b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -268,6 +268,12 @@ def parse_args():
         default=3,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=None,
+        help="Minimum learning rate in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -799,6 +805,7 @@ def main():
     warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
 
     if args.lr_scheduler == "one_cycle":
+        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -806,6 +813,7 @@ def main():
             annealing=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
+            min_lr=lr_min_lr,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
--
cgit v1.2.3-54-g00ecf
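
The new --lr_min_lr value is normalised into a fraction of the base learning rate before it reaches the scheduler: if the flag is omitted the floor defaults to 4% of args.learning_rate, otherwise the absolute value passed on the command line is divided by args.learning_rate. The repository's get_one_cycle_schedule is not shown in this patch, so the snippet below is only a minimal sketch of how such a floor could act in a one-cycle-style schedule built on torch.optim.lr_scheduler.LambdaLR; the helper name one_cycle_with_floor and the concrete values (learning_rate, lr_min_lr_arg, step counts) are illustrative assumptions, not code from the repository.

    import math
    import torch
    from torch.optim.lr_scheduler import LambdaLR

    def one_cycle_with_floor(optimizer, num_training_steps, warmup_steps, min_lr=0.04):
        """Sketch of a one-cycle-style schedule with a floor: linear warmup, then
        cosine annealing whose multiplier never drops below min_lr (a fraction of
        the base learning rate, matching how the patch normalises --lr_min_lr)."""
        def lr_lambda(step):
            if step < warmup_steps:
                return step / max(1, warmup_steps)  # multiplier rises 0 -> 1
            progress = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
            cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # 1 -> 0
            return min_lr + (1.0 - min_lr) * cosine  # multiplier falls 1 -> min_lr
        return LambdaLR(optimizer, lr_lambda)

    # Mirroring the patch: --lr_min_lr is absolute, the scheduler expects a fraction.
    learning_rate = 1e-4   # stands in for args.learning_rate (assumed value)
    lr_min_lr_arg = 4e-6   # stands in for args.lr_min_lr (assumed value)
    min_lr = 0.04 if lr_min_lr_arg is None else lr_min_lr_arg / learning_rate

    model = torch.nn.Linear(4, 4)  # toy model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = one_cycle_with_floor(optimizer, num_training_steps=1000,
                                     warmup_steps=100, min_lr=min_lr)

Expressing the floor as a multiplier keeps the schedule closure independent of the optimizer's absolute base rate, which is presumably why the patch divides args.lr_min_lr by args.learning_rate rather than passing the absolute value through unchanged.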