summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-03 12:39:17 +0200
committerVolpeon <git@volpeon.ink>2023-04-03 12:39:17 +0200
commita08cd4b1581ca195f619e8bdb6cb6448287d4d2f (patch)
tree49c753924b44f8b82403ee52ed21f3ab60c76748 /train_ti.py
parentFix memory leak (diff)
downloadtextual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.gz
textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.bz2
textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.zip
Bring back Lion optimizer
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py30
1 file changed, 27 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 651dfbe..c242625 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
330 "--optimizer", 330 "--optimizer",
331 type=str, 331 type=str,
332 default="dadan", 332 default="dadan",
333 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' 333 help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
334 ) 334 )
335 parser.add_argument( 335 parser.add_argument(
336 "--dadaptation_d0", 336 "--dadaptation_d0",
@@ -341,13 +341,13 @@ def parse_args():
341 parser.add_argument( 341 parser.add_argument(
342 "--adam_beta1", 342 "--adam_beta1",
343 type=float, 343 type=float,
344 default=0.9, 344 default=None,
345 help="The beta1 parameter for the Adam optimizer." 345 help="The beta1 parameter for the Adam optimizer."
346 ) 346 )
347 parser.add_argument( 347 parser.add_argument(
348 "--adam_beta2", 348 "--adam_beta2",
349 type=float, 349 type=float,
350 default=0.999, 350 default=None,
351 help="The beta2 parameter for the Adam optimizer." 351 help="The beta2 parameter for the Adam optimizer."
352 ) 352 )
353 parser.add_argument( 353 parser.add_argument(
@@ -566,6 +566,18 @@ def parse_args():
566 if args.output_dir is None: 566 if args.output_dir is None:
567 raise ValueError("You must specify --output_dir") 567 raise ValueError("You must specify --output_dir")
568 568
569 if args.adam_beta1 is None:
570 if args.optimizer in ('adam', 'adam8bit'):
571 args.adam_beta1 = 0.9
572 elif args.optimizer == 'lion':
573 args.adam_beta1 = 0.95
574
575 if args.adam_beta2 is None:
576 if args.optimizer in ('adam', 'adam8bit'):
577 args.adam_beta2 = 0.999
578 elif args.optimizer == 'lion':
579 args.adam_beta2 = 0.98
580
569 return args 581 return args
570 582
571 583
@@ -666,6 +678,18 @@ def main():
666 eps=args.adam_epsilon, 678 eps=args.adam_epsilon,
667 amsgrad=args.adam_amsgrad, 679 amsgrad=args.adam_amsgrad,
668 ) 680 )
681 elif args.optimizer == 'lion':
682 try:
683 import lion_pytorch
684 except ImportError:
685 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
686
687 create_optimizer = partial(
688 lion_pytorch.Lion,
689 betas=(args.adam_beta1, args.adam_beta2),
690 weight_decay=args.adam_weight_decay,
691 use_triton=True,
692 )
669 elif args.optimizer == 'adafactor': 693 elif args.optimizer == 'adafactor':
670 create_optimizer = partial( 694 create_optimizer = partial(
671 transformers.optimization.Adafactor, 695 transformers.optimization.Adafactor,