diff options
author | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
commit | 25ba3b38e2c605bf90f838f99b3b30d489b48222 (patch) | |
tree | b65f0289f0b447fd76a974a8f4213fccfd408c2d /dreambooth_plus.py | |
parent | Update (diff) | |
download | textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.gz textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.bz2 textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.zip |
Update
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 42994af..73225de 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -135,7 +135,7 @@ def parse_args(): | |||
135 | parser.add_argument( | 135 | parser.add_argument( |
136 | "--learning_rate_unet", | 136 | "--learning_rate_unet", |
137 | type=float, | 137 | type=float, |
138 | default=5e-6, | 138 | default=5e-5, |
139 | help="Initial learning rate (after the potential warmup period) to use.", | 139 | help="Initial learning rate (after the potential warmup period) to use.", |
140 | ) | 140 | ) |
141 | parser.add_argument( | 141 | parser.add_argument( |
@@ -168,7 +168,7 @@ def parse_args(): | |||
168 | parser.add_argument( | 168 | parser.add_argument( |
169 | "--lr_cycles", | 169 | "--lr_cycles", |
170 | type=int, | 170 | type=int, |
171 | default=2, | 171 | default=None, |
172 | help="Number of restart cycles in the lr scheduler." | 172 | help="Number of restart cycles in the lr scheduler." |
173 | ) | 173 | ) |
174 | parser.add_argument( | 174 | parser.add_argument( |
@@ -721,13 +721,15 @@ def main(): | |||
721 | if args.max_train_steps is None: | 721 | if args.max_train_steps is None: |
722 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 722 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
723 | overrode_max_train_steps = True | 723 | overrode_max_train_steps = True |
724 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
724 | 725 | ||
725 | if args.lr_scheduler == "cosine_with_restarts": | 726 | if args.lr_scheduler == "cosine_with_restarts": |
726 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 727 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
727 | optimizer=optimizer, | 728 | optimizer=optimizer, |
728 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 729 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
729 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 730 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
730 | num_cycles=args.lr_cycles, | 731 | num_cycles=args.lr_cycles or math.ceil( |
732 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | ||
731 | ) | 733 | ) |
732 | else: | 734 | else: |
733 | lr_scheduler = get_scheduler( | 735 | lr_scheduler = get_scheduler( |