Diffstat (limited to 'dreambooth_plus.py')
 dreambooth_plus.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 42994af..73225de 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -135,7 +135,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -168,7 +168,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -721,13 +721,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
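
Note (not part of the commit): a minimal sketch of what the new num_cycles fallback computes when --lr_cycles is left at its new default of None, i.e. roughly one cosine restart per two post-warmup epochs. The numeric values below are assumptions for illustration, not taken from the repository.

    # Sketch only; all values here are assumed example inputs.
    import math

    max_train_steps = 3000            # assumed total optimizer steps
    lr_warmup_steps = 300             # assumed warmup steps
    num_update_steps_per_epoch = 100  # assumed update steps per epoch

    # Same fallback expression as in the diff above.
    num_cycles = math.ceil(
        ((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch) / 2)
    print(num_cycles)  # ceil((2700 / 100) / 2) == 14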