 dreambooth.py        | 10 ++++++----
 dreambooth_plus.py   |  8 +++++---
 textual_inversion.py |  6 ++++--
 3 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9e2645b..42d3980 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=2000,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -156,7 +156,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -628,13 +628,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 42994af..73225de 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -135,7 +135,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-6,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -168,7 +168,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=2,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -721,13 +721,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/textual_inversion.py b/textual_inversion.py
index 61c96b7..0d5a742 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -161,7 +161,7 @@ def parse_args():
     parser.add_argument(
         "--lr_cycles",
         type=int,
-        default=15,
+        default=None,
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
@@ -665,13 +665,15 @@ def main():
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles,
+            num_cycles=args.lr_cycles or math.ceil(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
         )
     else:
         lr_scheduler = get_scheduler(
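
Note: the common change across all three scripts is the fallback used for num_cycles when --lr_cycles is left at its new default of None: roughly one cosine restart per two post-warmup epochs, rounded up. A minimal sketch of that fallback follows; the helper name and the example numbers are illustrative and not taken from the patch.

import math

def default_num_cycles(max_train_steps, lr_warmup_steps, num_update_steps_per_epoch):
    # One restart cycle per two post-warmup epochs, rounded up,
    # mirroring the expression added at the three patched call sites.
    post_warmup_epochs = (max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch
    return math.ceil(post_warmup_epochs / 2)

# Hypothetical values: 1200 training steps, 100 warmup steps, 110 update steps per epoch
# -> (1200 - 100) / 110 = 10 post-warmup epochs -> ceil(10 / 2) = 5 restart cycles
print(default_num_cycles(1200, 100, 110))  # 5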
