diff options
author | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
commit | 25ba3b38e2c605bf90f838f99b3b30d489b48222 (patch) | |
tree | b65f0289f0b447fd76a974a8f4213fccfd408c2d /textual_inversion.py | |
parent | Update (diff) | |
download | textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.gz textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.bz2 textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.zip |
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 6 |
1 file changed, 4 insertions, 2 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 61c96b7..0d5a742 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -161,7 +161,7 @@ def parse_args(): | |||
161 | parser.add_argument( | 161 | parser.add_argument( |
162 | "--lr_cycles", | 162 | "--lr_cycles", |
163 | type=int, | 163 | type=int, |
164 | default=15, | 164 | default=None, |
165 | help="Number of restart cycles in the lr scheduler." | 165 | help="Number of restart cycles in the lr scheduler." |
166 | ) | 166 | ) |
167 | parser.add_argument( | 167 | parser.add_argument( |
@@ -665,13 +665,15 @@ def main(): | |||
665 | if args.max_train_steps is None: | 665 | if args.max_train_steps is None: |
666 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 666 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
667 | overrode_max_train_steps = True | 667 | overrode_max_train_steps = True |
668 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
668 | 669 | ||
669 | if args.lr_scheduler == "cosine_with_restarts": | 670 | if args.lr_scheduler == "cosine_with_restarts": |
670 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 671 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
671 | optimizer=optimizer, | 672 | optimizer=optimizer, |
672 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 673 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
673 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 674 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
674 | num_cycles=args.lr_cycles, | 675 | num_cycles=args.lr_cycles or math.ceil( |
676 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | ||
675 | ) | 677 | ) |
676 | else: | 678 | else: |
677 | lr_scheduler = get_scheduler( | 679 | lr_scheduler = get_scheduler( |