-rw-r--r--  dreambooth.py         10
-rw-r--r--  dreambooth_plus.py     8
-rw-r--r--  textual_inversion.py   6
3 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9e2645b..42d3980 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2000,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -156,7 +156,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -628,13 +628,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
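The substantive change in this hunk is the derived default for num_cycles: with --lr_cycles now defaulting to None, the scheduler restarts roughly once every two post-warmup epochs. A minimal sketch of that arithmetic, with illustrative inputs (the warmup and per-epoch step counts below are hypothetical, not taken from this commit):

import math

max_train_steps = 1200            # new default from this commit
lr_warmup_steps = 500             # hypothetical, for illustration only
num_update_steps_per_epoch = 100  # hypothetical; depends on dataset and batch size

lr_cycles = None                  # new default: derive the cycle count
num_cycles = lr_cycles or math.ceil(
    ((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch) / 2)
print(num_cycles)  # (1200 - 500) / 100 = 7 post-warmup epochs -> ceil(7 / 2) = 4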
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 42994af..73225de 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -135,7 +135,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -168,7 +168,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -721,13 +721,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/textual_inversion.py b/textual_inversion.py
index 61c96b7..0d5a742 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=15,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -665,13 +665,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
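One caveat with the args.lr_cycles or ... fallback used in all three files: or falls through on any falsy value, so an explicit --lr_cycles 0 behaves the same as leaving the flag unset. A quick demonstration of the Python semantics:

for lr_cycles in (None, 0, 2):
    print(lr_cycles or 4)  # prints 4, 4, 2: both None and 0 take the fallback

An explicit is None test would distinguish the two, but zero restart cycles is arguably not a useful setting, so the shorter form seems harmless here.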