diff options
author | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-16 19:00:08 +0200 |
commit | 25ba3b38e2c605bf90f838f99b3b30d489b48222 (patch) | |
tree | b65f0289f0b447fd76a974a8f4213fccfd408c2d /dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.gz textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.tar.bz2 textual-inversion-diff-25ba3b38e2c605bf90f838f99b3b30d489b48222.zip |
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 10 |
1 file changed, 6 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 9e2645b..42d3980 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -112,7 +112,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 112 | parser.add_argument( |
113 | "--max_train_steps", | 113 | "--max_train_steps", |
114 | type=int, | 114 | type=int, |
115 | default=2000, | 115 | default=1200, |
116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
117 | ) | 117 | ) |
118 | parser.add_argument( | 118 | parser.add_argument( |
@@ -129,7 +129,7 @@ def parse_args(): | |||
129 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--learning_rate", | 130 | "--learning_rate", |
131 | type=float, | 131 | type=float, |
132 | default=5e-6, | 132 | default=5e-5, |
133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
@@ -156,7 +156,7 @@ def parse_args(): | |||
156 | parser.add_argument( | 156 | parser.add_argument( |
157 | "--lr_cycles", | 157 | "--lr_cycles", |
158 | type=int, | 158 | type=int, |
159 | default=2, | 159 | default=None, |
160 | help="Number of restart cycles in the lr scheduler." | 160 | help="Number of restart cycles in the lr scheduler." |
161 | ) | 161 | ) |
162 | parser.add_argument( | 162 | parser.add_argument( |
@@ -628,13 +628,15 @@ def main(): | |||
628 | if args.max_train_steps is None: | 628 | if args.max_train_steps is None: |
629 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 629 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
630 | overrode_max_train_steps = True | 630 | overrode_max_train_steps = True |
631 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | ||
631 | 632 | ||
632 | if args.lr_scheduler == "cosine_with_restarts": | 633 | if args.lr_scheduler == "cosine_with_restarts": |
633 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 634 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
634 | optimizer=optimizer, | 635 | optimizer=optimizer, |
635 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 636 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
636 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 637 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
637 | num_cycles=args.lr_cycles, | 638 | num_cycles=args.lr_cycles or math.ceil( |
639 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | ||
638 | ) | 640 | ) |
639 | else: | 641 | else: |
640 | lr_scheduler = get_scheduler( | 642 | lr_scheduler = get_scheduler( |