diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index b182a72..83043ad 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -258,6 +258,12 @@ def parse_args(): | |||
258 | help="Number of steps for the warmup in the lr scheduler." | 258 | help="Number of steps for the warmup in the lr scheduler." |
259 | ) | 259 | ) |
260 | parser.add_argument( | 260 | parser.add_argument( |
261 | "--lr_mid_point", | ||
262 | type=float, | ||
263 | default=0.3, | ||
264 | help="OneCycle schedule mid point." | ||
265 | ) | ||
266 | parser.add_argument( | ||
261 | "--lr_cycles", | 267 | "--lr_cycles", |
262 | type=int, | 268 | type=int, |
263 | default=None, | 269 | default=None, |
@@ -790,6 +796,7 @@ def main(): | |||
790 | end_lr=1e3, | 796 | end_lr=1e3, |
791 | train_epochs=num_train_epochs, | 797 | train_epochs=num_train_epochs, |
792 | warmup_epochs=args.lr_warmup_epochs, | 798 | warmup_epochs=args.lr_warmup_epochs, |
799 | mid_point=args.lr_mid_point, | ||
793 | ) | 800 | ) |
794 | 801 | ||
795 | metrics = trainer( | 802 | metrics = trainer( |