diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index d89b18d..59beb09 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -272,6 +272,12 @@ def parse_args(): | |||
272 | help="Number of steps for the warmup in the lr scheduler." | 272 | help="Number of steps for the warmup in the lr scheduler." |
273 | ) | 273 | ) |
274 | parser.add_argument( | 274 | parser.add_argument( |
275 | "--lr_mid_point", | ||
276 | type=float, | ||
277 | default=0.3, | ||
278 | help="OneCycle schedule mid point." | ||
279 | ) | ||
280 | parser.add_argument( | ||
275 | "--lr_cycles", | 281 | "--lr_cycles", |
276 | type=int, | 282 | type=int, |
277 | default=None, | 283 | default=None, |
@@ -662,6 +668,7 @@ def main(): | |||
662 | end_lr=1e2, | 668 | end_lr=1e2, |
663 | train_epochs=num_train_epochs, | 669 | train_epochs=num_train_epochs, |
664 | warmup_epochs=args.lr_warmup_epochs, | 670 | warmup_epochs=args.lr_warmup_epochs, |
671 | mid_point=args.lr_mid_point, | ||
665 | ) | 672 | ) |
666 | 673 | ||
667 | metrics = trainer( | 674 | metrics = trainer( |