diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index fa3a22b..06ff45b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -125,7 +125,7 @@ def parse_args(): | |||
125 | parser.add_argument( | 125 | parser.add_argument( |
126 | "--max_train_steps", | 126 | "--max_train_steps", |
127 | type=int, | 127 | type=int, |
128 | default=1400, | 128 | default=2400, |
129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
130 | ) | 130 | ) |
131 | parser.add_argument( | 131 | parser.add_argument( |
@@ -752,8 +752,8 @@ def main(): | |||
752 | optimizer=optimizer, | 752 | optimizer=optimizer, |
753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
755 | num_cycles=args.lr_cycles or math.ceil( | 755 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
757 | ) | 757 | ) |
758 | else: | 758 | else: |
759 | lr_scheduler = get_scheduler( | 759 | lr_scheduler = get_scheduler( |