summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0634376..2c884d2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -240,6 +240,12 @@ def parse_args():
240 help="Number of steps for the warmup in the lr scheduler." 240 help="Number of steps for the warmup in the lr scheduler."
241 ) 241 )
242 parser.add_argument( 242 parser.add_argument(
243 "--lr_mid_point",
244 type=float,
245 default=0.3,
246 help="OneCycle schedule mid point."
247 )
248 parser.add_argument(
243 "--lr_cycles", 249 "--lr_cycles",
244 type=int, 250 type=int,
245 default=None, 251 default=None,
@@ -634,6 +640,7 @@ def main():
634 end_lr=1e2, 640 end_lr=1e2,
635 train_epochs=num_train_epochs, 641 train_epochs=num_train_epochs,
636 warmup_epochs=args.lr_warmup_epochs, 642 warmup_epochs=args.lr_warmup_epochs,
643 mid_point=args.lr_mid_point,
637 ) 644 )
638 645
639 metrics = trainer( 646 metrics = trainer(