diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0634376..2c884d2 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -240,6 +240,12 @@ def parse_args(): | |||
240 | help="Number of steps for the warmup in the lr scheduler." | 240 | help="Number of steps for the warmup in the lr scheduler." |
241 | ) | 241 | ) |
242 | parser.add_argument( | 242 | parser.add_argument( |
243 | "--lr_mid_point", | ||
244 | type=float, | ||
245 | default=0.3, | ||
246 | help="OneCycle schedule mid point." | ||
247 | ) | ||
248 | parser.add_argument( | ||
243 | "--lr_cycles", | 249 | "--lr_cycles", |
244 | type=int, | 250 | type=int, |
245 | default=None, | 251 | default=None, |
@@ -634,6 +640,7 @@ def main(): | |||
634 | end_lr=1e2, | 640 | end_lr=1e2, |
635 | train_epochs=num_train_epochs, | 641 | train_epochs=num_train_epochs, |
636 | warmup_epochs=args.lr_warmup_epochs, | 642 | warmup_epochs=args.lr_warmup_epochs, |
643 | mid_point=args.lr_mid_point, | ||
637 | ) | 644 | ) |
638 | 645 | ||
639 | metrics = trainer( | 646 | metrics = trainer( |