diff options
author | Volpeon <git@volpeon.ink> | 2023-04-03 12:39:17 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-03 12:39:17 +0200 |
commit | a08cd4b1581ca195f619e8bdb6cb6448287d4d2f (patch) | |
tree | 49c753924b44f8b82403ee52ed21f3ab60c76748 /train_dreambooth.py | |
parent | Fix memory leak (diff) | |
download | textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.gz textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.tar.bz2 textual-inversion-diff-a08cd4b1581ca195f619e8bdb6cb6448287d4d2f.zip |
Bring back Lion optimizer
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 30 |
1 file changed, 27 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48b7926..be7d6fe 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -306,7 +306,7 @@ def parse_args(): | |||
306 | "--optimizer", | 306 | "--optimizer", |
307 | type=str, | 307 | type=str, |
308 | default="dadan", | 308 | default="dadan", |
309 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' | 309 | help='Optimizer to use ["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"]' |
310 | ) | 310 | ) |
311 | parser.add_argument( | 311 | parser.add_argument( |
312 | "--dadaptation_d0", | 312 | "--dadaptation_d0", |
@@ -317,13 +317,13 @@ def parse_args(): | |||
317 | parser.add_argument( | 317 | parser.add_argument( |
318 | "--adam_beta1", | 318 | "--adam_beta1", |
319 | type=float, | 319 | type=float, |
320 | default=0.9, | 320 | default=None, |
321 | help="The beta1 parameter for the Adam optimizer." | 321 | help="The beta1 parameter for the Adam optimizer." |
322 | ) | 322 | ) |
323 | parser.add_argument( | 323 | parser.add_argument( |
324 | "--adam_beta2", | 324 | "--adam_beta2", |
325 | type=float, | 325 | type=float, |
326 | default=0.999, | 326 | default=None, |
327 | help="The beta2 parameter for the Adam optimizer." | 327 | help="The beta2 parameter for the Adam optimizer." |
328 | ) | 328 | ) |
329 | parser.add_argument( | 329 | parser.add_argument( |
@@ -450,6 +450,18 @@ def parse_args(): | |||
450 | if args.output_dir is None: | 450 | if args.output_dir is None: |
451 | raise ValueError("You must specify --output_dir") | 451 | raise ValueError("You must specify --output_dir") |
452 | 452 | ||
453 | if args.adam_beta1 is None: | ||
454 | if args.optimizer in ('adam', 'adam8bit'): | ||
455 | args.adam_beta1 = 0.9 | ||
456 | elif args.optimizer == 'lion': | ||
457 | args.adam_beta1 = 0.95 | ||
458 | |||
459 | if args.adam_beta2 is None: | ||
460 | if args.optimizer in ('adam', 'adam8bit'): | ||
461 | args.adam_beta2 = 0.999 | ||
462 | elif args.optimizer == 'lion': | ||
463 | args.adam_beta2 = 0.98 | ||
464 | |||
453 | return args | 465 | return args |
454 | 466 | ||
455 | 467 | ||
@@ -536,6 +548,18 @@ def main(): | |||
536 | eps=args.adam_epsilon, | 548 | eps=args.adam_epsilon, |
537 | amsgrad=args.adam_amsgrad, | 549 | amsgrad=args.adam_amsgrad, |
538 | ) | 550 | ) |
551 | elif args.optimizer == 'lion': | ||
552 | try: | ||
553 | import lion_pytorch | ||
554 | except ImportError: | ||
555 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | ||
556 | |||
557 | create_optimizer = partial( | ||
558 | lion_pytorch.Lion, | ||
559 | betas=(args.adam_beta1, args.adam_beta2), | ||
560 | weight_decay=args.adam_weight_decay, | ||
561 | use_triton=True, | ||
562 | ) | ||
539 | elif args.optimizer == 'adafactor': | 563 | elif args.optimizer == 'adafactor': |
540 | create_optimizer = partial( | 564 | create_optimizer = partial( |
541 | transformers.optimization.Adafactor, | 565 | transformers.optimization.Adafactor, |