author    Volpeon <git@volpeon.ink>  2023-04-03 12:39:17 +0200
committer Volpeon <git@volpeon.ink>  2023-04-03 12:39:17 +0200
commit    a08cd4b1581ca195f619e8bdb6cb6448287d4d2f (patch)
tree      49c753924b44f8b82403ee52ed21f3ab60c76748 /train_lora.py
parent    Fix memory leak (diff)
Bring back Lion optimizer
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  30
1 file changed, 27 insertions(+), 3 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index cf73645..a0cd174 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -318,7 +318,7 @@ def parse_args():
318 "--optimizer", 318 "--optimizer",
319 type=str, 319 type=str,
320 default="dadan", 320 default="dadan",
321 help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' 321 help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]'
322 ) 322 )
323 parser.add_argument( 323 parser.add_argument(
324 "--dadaptation_d0", 324 "--dadaptation_d0",
@@ -329,13 +329,13 @@ def parse_args():
     parser.add_argument(
         "--adam_beta1",
         type=float,
-        default=0.9,
+        default=None,
         help="The beta1 parameter for the Adam optimizer."
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
-        default=0.999,
+        default=None,
         help="The beta2 parameter for the Adam optimizer."
     )
     parser.add_argument(
@@ -468,6 +468,18 @@ def parse_args():
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
+    if args.adam_beta1 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta1 = 0.9
+        elif args.optimizer == 'lion':
+            args.adam_beta1 = 0.95
+
+    if args.adam_beta2 is None:
+        if args.optimizer in ('adam', 'adam8bit'):
+            args.adam_beta2 = 0.999
+        elif args.optimizer == 'lion':
+            args.adam_beta2 = 0.98
+
     return args
 
 
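Taken together, the two parse_args() hunks above defer the beta defaults until the optimizer is known: left unset, --adam_beta1/--adam_beta2 resolve to (0.9, 0.999) for the Adam variants and (0.95, 0.98) for Lion, while explicitly passed values win; the remaining optimizer choices leave the betas as None. A minimal self-contained sketch of that resolution, using a stripped-down parser that mirrors only the arguments involved here:

import argparse

# Illustrative mini-parser; the real script defines many more arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--optimizer", type=str, default="dadan")
parser.add_argument("--adam_beta1", type=float, default=None)
parser.add_argument("--adam_beta2", type=float, default=None)
args = parser.parse_args(["--optimizer", "lion"])

# Same per-optimizer default resolution as the hunk above: only fill
# in a beta when the user did not pass one on the command line.
if args.adam_beta1 is None:
    if args.optimizer in ('adam', 'adam8bit'):
        args.adam_beta1 = 0.9
    elif args.optimizer == 'lion':
        args.adam_beta1 = 0.95

if args.adam_beta2 is None:
    if args.optimizer in ('adam', 'adam8bit'):
        args.adam_beta2 = 0.999
    elif args.optimizer == 'lion':
        args.adam_beta2 = 0.98

print(args.adam_beta1, args.adam_beta2)  # -> 0.95 0.98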
@@ -568,6 +580,18 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'lion':
+        try:
+            import lion_pytorch
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+
+        create_optimizer = partial(
+            lion_pytorch.Lion,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            use_triton=True,
+        )
     elif args.optimizer == 'adafactor':
         create_optimizer = partial(
             transformers.optimization.Adafactor,
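create_optimizer is a factory: functools.partial binds Lion's hyperparameters up front, and the training loop presumably invokes it later with the parameters to optimize and a learning rate. A usage sketch under that assumption; the model, weight_decay value, and lr below are illustrative stand-ins, not taken from the diff:

from functools import partial

import torch
import lion_pytorch  # pip install lion-pytorch

# Hyperparameters bound as in the hunk above, with the Lion beta
# defaults (0.95, 0.98) that parse_args() resolves when none are given.
create_optimizer = partial(
    lion_pytorch.Lion,
    betas=(0.95, 0.98),
    weight_decay=1e-2,   # stand-in for args.adam_weight_decay
    use_triton=True,     # fused Triton update kernel, as in the diff
)

model = torch.nn.Linear(4, 4)  # stand-in for the LoRA parameters
optimizer = create_optimizer(model.parameters(), lr=1e-4)

Lion updates weights with the sign of an interpolated momentum and is typically run with a learning rate several times smaller than Adam's, which is presumably why it gets its own beta defaults here instead of inheriting Adam's (0.9, 0.999).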