diff options
author | Volpeon <git@volpeon.ink> | 2023-04-03 12:39:17 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-03 12:39:17 +0200 |
commit | a08cd4b1581ca195f619e8bdb6cb6448287d4d2f (patch) | |
tree | 49c753924b44f8b82403ee52ed21f3ab60c76748 /train_ti.py | |
parent | Fix memory leak (diff) | |
download | textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.gz textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.bz2 textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.zip |
Bring back Lion optimizer
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 30 |
1 file changed, 27 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 651dfbe..c242625 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -330,7 +330,7 @@ def parse_args(): | |||
330 | "--optimizer", | 330 | "--optimizer", |
331 | type=str, | 331 | type=str, |
332 | default="dadan", | 332 | default="dadan", |
333 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' | 333 | help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]' |
334 | ) | 334 | ) |
335 | parser.add_argument( | 335 | parser.add_argument( |
336 | "--dadaptation_d0", | 336 | "--dadaptation_d0", |
@@ -341,13 +341,13 @@ def parse_args(): | |||
341 | parser.add_argument( | 341 | parser.add_argument( |
342 | "--adam_beta1", | 342 | "--adam_beta1", |
343 | type=float, | 343 | type=float, |
344 | default=0.9, | 344 | default=None, |
345 | help="The beta1 parameter for the Adam optimizer." | 345 | help="The beta1 parameter for the Adam optimizer." |
346 | ) | 346 | ) |
347 | parser.add_argument( | 347 | parser.add_argument( |
348 | "--adam_beta2", | 348 | "--adam_beta2", |
349 | type=float, | 349 | type=float, |
350 | default=0.999, | 350 | default=None, |
351 | help="The beta2 parameter for the Adam optimizer." | 351 | help="The beta2 parameter for the Adam optimizer." |
352 | ) | 352 | ) |
353 | parser.add_argument( | 353 | parser.add_argument( |
@@ -566,6 +566,18 @@ def parse_args(): | |||
566 | if args.output_dir is None: | 566 | if args.output_dir is None: |
567 | raise ValueError("You must specify --output_dir") | 567 | raise ValueError("You must specify --output_dir") |
568 | 568 | ||
569 | if args.adam_beta1 is None: | ||
570 | if args.optimizer in ('adam', 'adam8bit'): | ||
571 | args.adam_beta1 = 0.9 | ||
572 | elif args.optimizer == 'lion': | ||
573 | args.adam_beta1 = 0.95 | ||
574 | |||
575 | if args.adam_beta2 is None: | ||
576 | if args.optimizer in ('adam', 'adam8bit'): | ||
577 | args.adam_beta2 = 0.999 | ||
578 | elif args.optimizer == 'lion': | ||
579 | args.adam_beta2 = 0.98 | ||
580 | |||
569 | return args | 581 | return args |
570 | 582 | ||
571 | 583 | ||
@@ -666,6 +678,18 @@ def main(): | |||
666 | eps=args.adam_epsilon, | 678 | eps=args.adam_epsilon, |
667 | amsgrad=args.adam_amsgrad, | 679 | amsgrad=args.adam_amsgrad, |
668 | ) | 680 | ) |
681 | elif args.optimizer == 'lion': | ||
682 | try: | ||
683 | import lion_pytorch | ||
684 | except ImportError: | ||
685 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | ||
686 | |||
687 | create_optimizer = partial( | ||
688 | lion_pytorch.Lion, | ||
689 | betas=(args.adam_beta1, args.adam_beta2), | ||
690 | weight_decay=args.adam_weight_decay, | ||
691 | use_triton=True, | ||
692 | ) | ||
669 | elif args.optimizer == 'adafactor': | 693 | elif args.optimizer == 'adafactor': |
670 | create_optimizer = partial( | 694 | create_optimizer = partial( |
671 | transformers.optimization.Adafactor, | 695 | transformers.optimization.Adafactor, |