diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
| commit | 25ba3b38e2c605bf90f838f99b3b30d489b48222 (patch) | |
| tree | b65f0289f0b447fd76a974a8f4213fccfd408c2d /textual_inversion.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.gz textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.bz2 textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.zip | |
Update
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 6 |
1 file changed, 4 insertions, 2 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 61c96b7..0d5a742 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -161,7 +161,7 @@ def parse_args(): | |||
| 161 | parser.add_argument( | 161 | parser.add_argument( |
| 162 | "--lr_cycles", | 162 | "--lr_cycles", |
| 163 | type=int, | 163 | type=int, |
| 164 | default=15, | 164 | default=None, |
| 165 | help="Number of restart cycles in the lr scheduler." | 165 | help="Number of restart cycles in the lr scheduler." |
| 166 | ) | 166 | ) |
| 167 | parser.add_argument( | 167 | parser.add_argument( |
| @@ -665,13 +665,15 @@ def main(): | |||
| 665 | if args.max_train_steps is None: | 665 | if args.max_train_steps is None: |
| 666 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 666 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| 667 | overrode_max_train_steps = True | 667 | overrode_max_train_steps = True |
| 668 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
| 668 | 669 | ||
| 669 | if args.lr_scheduler == "cosine_with_restarts": | 670 | if args.lr_scheduler == "cosine_with_restarts": |
| 670 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 671 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
| 671 | optimizer=optimizer, | 672 | optimizer=optimizer, |
| 672 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 673 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| 673 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 674 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 674 | num_cycles=args.lr_cycles, | 675 | num_cycles=args.lr_cycles or math.ceil( |
| 676 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | ||
| 675 | ) | 677 | ) |
| 676 | else: | 678 | else: |
| 677 | lr_scheduler = get_scheduler( | 679 | lr_scheduler = get_scheduler( |
