summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py10
1 file changed, 6 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 9e2645b..42d3980 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
112 parser.add_argument( 112 parser.add_argument(
113 "--max_train_steps", 113 "--max_train_steps",
114 type=int, 114 type=int,
115 default=2000, 115 default=1200,
116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
117 ) 117 )
118 parser.add_argument( 118 parser.add_argument(
@@ -129,7 +129,7 @@ def parse_args():
129 parser.add_argument( 129 parser.add_argument(
130 "--learning_rate", 130 "--learning_rate",
131 type=float, 131 type=float,
132 default=5e-6, 132 default=5e-5,
133 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
@@ -156,7 +156,7 @@ def parse_args():
156 parser.add_argument( 156 parser.add_argument(
157 "--lr_cycles", 157 "--lr_cycles",
158 type=int, 158 type=int,
159 default=2, 159 default=None,
160 help="Number of restart cycles in the lr scheduler." 160 help="Number of restart cycles in the lr scheduler."
161 ) 161 )
162 parser.add_argument( 162 parser.add_argument(
@@ -628,13 +628,15 @@ def main():
628 if args.max_train_steps is None: 628 if args.max_train_steps is None:
629 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 629 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
630 overrode_max_train_steps = True 630 overrode_max_train_steps = True
631 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
631 632
632 if args.lr_scheduler == "cosine_with_restarts": 633 if args.lr_scheduler == "cosine_with_restarts":
633 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 634 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
634 optimizer=optimizer, 635 optimizer=optimizer,
635 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 636 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
636 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 637 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
637 num_cycles=args.lr_cycles, 638 num_cycles=args.lr_cycles or math.ceil(
639 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
638 ) 640 )
639 else: 641 else:
640 lr_scheduler = get_scheduler( 642 lr_scheduler = get_scheduler(