summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 61c96b7..0d5a742 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -161,7 +161,7 @@ def parse_args():
161 parser.add_argument( 161 parser.add_argument(
162 "--lr_cycles", 162 "--lr_cycles",
163 type=int, 163 type=int,
164 default=15, 164 default=None,
165 help="Number of restart cycles in the lr scheduler." 165 help="Number of restart cycles in the lr scheduler."
166 ) 166 )
167 parser.add_argument( 167 parser.add_argument(
@@ -665,13 +665,15 @@ def main():
665 if args.max_train_steps is None: 665 if args.max_train_steps is None:
666 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 666 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
667 overrode_max_train_steps = True 667 overrode_max_train_steps = True
668 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
668 669
669 if args.lr_scheduler == "cosine_with_restarts": 670 if args.lr_scheduler == "cosine_with_restarts":
670 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 671 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
671 optimizer=optimizer, 672 optimizer=optimizer,
672 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 673 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
673 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 674 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
674 num_cycles=args.lr_cycles, 675 num_cycles=args.lr_cycles or math.ceil(
676 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
675 ) 677 )
676 else: 678 else:
677 lr_scheduler = get_scheduler( 679 lr_scheduler = get_scheduler(