Diffstat (limited to 'train_lora.py')

 train_lora.py | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
+        help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
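The two hunks above work together: --adam_beta1 and --adam_beta2 now default to None as a "not set on the command line" sentinel, and this post-parse block fills in optimizer-specific values, so an explicitly passed beta always wins while Lion picks up (0.95, 0.98) instead of Adam's conventional (0.9, 0.999). A minimal standalone sketch of the same sentinel pattern (the option names come from the diff; the dict-based lookup is an equivalent rewrite, not the committed code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--optimizer", type=str, default="dadan")
    # None is a sentinel meaning "the flag was not given on the command line"
    parser.add_argument("--adam_beta1", type=float, default=None)
    parser.add_argument("--adam_beta2", type=float, default=None)
    args = parser.parse_args(["--optimizer", "lion"])

    # Fill in optimizer-specific defaults only where the user set nothing.
    if args.adam_beta1 is None:
        args.adam_beta1 = {"adam": 0.9, "adam8bit": 0.9, "lion": 0.95}.get(args.optimizer)
    if args.adam_beta2 is None:
        args.adam_beta2 = {"adam": 0.999, "adam8bit": 0.999, "lion": 0.98}.get(args.optimizer)

    print(args.adam_beta1, args.adam_beta2)  # -> 0.95 0.98

For the remaining optimizer choices both betas simply stay None, exactly as the committed if/elif chains leave them.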
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
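For context on the new branch: lion_pytorch.Lion (the lucidrains implementation this import targets) follows the usual torch.optim constructor shape, Lion(params, lr=..., betas=..., weight_decay=...), and use_triton=True swaps in a fused Triton kernel, which additionally requires the triton package at runtime. Below is a minimal sketch of how the partial built here would later be applied; the model, learning rate, and weight decay are placeholder assumptions, not values from the diff:

    from functools import partial

    import torch
    from lion_pytorch import Lion  # pip install lion-pytorch

    # Bind the hyperparameters up front, as the diff does; params and lr come later.
    create_optimizer = partial(
        Lion,
        betas=(0.95, 0.98),   # the post-parse defaults for --optimizer lion
        weight_decay=1e-2,    # placeholder for args.adam_weight_decay
        use_triton=False,     # the diff hard-codes True, which also needs triton installed
    )

    model = torch.nn.Linear(16, 16)  # placeholder model
    optimizer = create_optimizer(model.parameters(), lr=1e-4)

    optimizer.zero_grad()
    model(torch.randn(4, 16)).sum().backward()
    optimizer.step()

The (0.95, 0.98) pair chosen as the lion default matches the betas the Lion paper reports using for language-model training, which is presumably where these values come from.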