summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index b182a72..83043ad 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -258,6 +258,12 @@ def parse_args():
258 help="Number of steps for the warmup in the lr scheduler." 258 help="Number of steps for the warmup in the lr scheduler."
259 ) 259 )
260 parser.add_argument( 260 parser.add_argument(
261 "--lr_mid_point",
262 type=float,
263 default=0.3,
264 help="OneCycle schedule mid point."
265 )
266 parser.add_argument(
261 "--lr_cycles", 267 "--lr_cycles",
262 type=int, 268 type=int,
263 default=None, 269 default=None,
@@ -790,6 +796,7 @@ def main():
790 end_lr=1e3, 796 end_lr=1e3,
791 train_epochs=num_train_epochs, 797 train_epochs=num_train_epochs,
792 warmup_epochs=args.lr_warmup_epochs, 798 warmup_epochs=args.lr_warmup_epochs,
799 mid_point=args.lr_mid_point,
793 ) 800 )
794 801
795 metrics = trainer( 802 metrics = trainer(