diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index b182a72..83043ad 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -258,6 +258,12 @@ def parse_args(): | |||
| 258 | help="Number of steps for the warmup in the lr scheduler." | 258 | help="Number of steps for the warmup in the lr scheduler." |
| 259 | ) | 259 | ) |
| 260 | parser.add_argument( | 260 | parser.add_argument( |
| 261 | "--lr_mid_point", | ||
| 262 | type=float, | ||
| 263 | default=0.3, | ||
| 264 | help="OneCycle schedule mid point." | ||
| 265 | ) | ||
| 266 | parser.add_argument( | ||
| 261 | "--lr_cycles", | 267 | "--lr_cycles", |
| 262 | type=int, | 268 | type=int, |
| 263 | default=None, | 269 | default=None, |
| @@ -790,6 +796,7 @@ def main(): | |||
| 790 | end_lr=1e3, | 796 | end_lr=1e3, |
| 791 | train_epochs=num_train_epochs, | 797 | train_epochs=num_train_epochs, |
| 792 | warmup_epochs=args.lr_warmup_epochs, | 798 | warmup_epochs=args.lr_warmup_epochs, |
| 799 | mid_point=args.lr_mid_point, | ||
| 793 | ) | 800 | ) |
| 794 | 801 | ||
| 795 | metrics = trainer( | 802 | metrics = trainer( |
