summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
index d89b18d..59beb09 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -272,6 +272,12 @@ def parse_args():
272 help="Number of steps for the warmup in the lr scheduler." 272 help="Number of steps for the warmup in the lr scheduler."
273 ) 273 )
274 parser.add_argument( 274 parser.add_argument(
275 "--lr_mid_point",
276 type=float,
277 default=0.3,
278 help="OneCycle schedule mid point."
279 )
280 parser.add_argument(
275 "--lr_cycles", 281 "--lr_cycles",
276 type=int, 282 type=int,
277 default=None, 283 default=None,
@@ -662,6 +668,7 @@ def main():
662 end_lr=1e2, 668 end_lr=1e2,
663 train_epochs=num_train_epochs, 669 train_epochs=num_train_epochs,
664 warmup_epochs=args.lr_warmup_epochs, 670 warmup_epochs=args.lr_warmup_epochs,
671 mid_point=args.lr_mid_point,
665 ) 672 )
666 673
667 metrics = trainer( 674 metrics = trainer(