path: root/textual_inversion.py
author      Volpeon <git@volpeon.ink>   2022-10-16 19:00:08 +0200
committer   Volpeon <git@volpeon.ink>   2022-10-16 19:00:08 +0200
commit      25ba3b38e2c605bf90f838f99b3b30d489b48222 (patch)
tree        b65f0289f0b447fd76a974a8f4213fccfd408c2d /textual_inversion.py
parent      Update (diff)
download    textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.gz
            textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.bz2
            textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.zip
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  6
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 61c96b7..0d5a742 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=15,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -665,13 +665,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
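
The change makes --lr_cycles optional: with the new default of None, the number
of restart cycles is derived from the training schedule instead of the fixed 15.
A minimal sketch of that fallback computation follows, using hypothetical values
for max_train_steps, lr_warmup_steps, and num_update_steps_per_epoch (none of
these numbers appear in the commit):

    import math

    # Hypothetical training configuration, for illustration only.
    max_train_steps = 6000
    lr_warmup_steps = 200
    num_update_steps_per_epoch = 100
    lr_cycles = None  # --lr_cycles left at its new default

    # Mirrors the patched line: when --lr_cycles is not given, schedule
    # roughly one hard restart per two epochs of post-warmup training.
    num_cycles = lr_cycles or math.ceil(
        ((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch) / 2)

    print(num_cycles)  # ceil((6000 - 200) / 100 / 2) = 29

Passing --lr_cycles explicitly still takes precedence, since any nonzero integer
short-circuits the `or`; an explicit 0, being falsy, would fall through to the
computed value.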