diff options
Diffstat (limited to 'dreambooth_plus.py')
| -rw-r--r-- | dreambooth_plus.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index fa3a22b..06ff45b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
| @@ -125,7 +125,7 @@ def parse_args(): | |||
| 125 | parser.add_argument( | 125 | parser.add_argument( |
| 126 | "--max_train_steps", | 126 | "--max_train_steps", |
| 127 | type=int, | 127 | type=int, |
| 128 | default=1400, | 128 | default=2400, |
| 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 130 | ) | 130 | ) |
| 131 | parser.add_argument( | 131 | parser.add_argument( |
| @@ -752,8 +752,8 @@ def main(): | |||
| 752 | optimizer=optimizer, | 752 | optimizer=optimizer, |
| 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 755 | num_cycles=args.lr_cycles or math.ceil( | 755 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
| 757 | ) | 757 | ) |
| 758 | else: | 758 | else: |
| 759 | lr_scheduler = get_scheduler( | 759 | lr_scheduler = get_scheduler( |
