diff options
author | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
commit | f894dfecfaa3ec17903b2ac37ac4f071408613db (patch) | |
tree | 02bf8439315c832528651186285f8b1fbd649f32 /train_lora.py | |
parent | Inference script: Better scheduler config (diff) | |
download | textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.gz textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.bz2 textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.zip |
Added Lion optimizer
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 38 |
1 file changed, 27 insertions, 11 deletions
diff --git a/train_lora.py b/train_lora.py index 330bcd6..368c29b 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -245,9 +245,10 @@ def parse_args(): | |||
245 | help="Minimum learning rate in the lr scheduler." | 245 | help="Minimum learning rate in the lr scheduler." |
246 | ) | 246 | ) |
247 | parser.add_argument( | 247 | parser.add_argument( |
248 | "--use_8bit_adam", | 248 | "--optimizer", |
249 | action="store_true", | 249 | type=str, |
250 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 250 | default="lion", |
251 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | ||
251 | ) | 252 | ) |
252 | parser.add_argument( | 253 | parser.add_argument( |
253 | "--adam_beta1", | 254 | "--adam_beta1", |
@@ -466,15 +467,34 @@ def main(): | |||
466 | args.learning_rate = 1e-6 | 467 | args.learning_rate = 1e-6 |
467 | args.lr_scheduler = "exponential_growth" | 468 | args.lr_scheduler = "exponential_growth" |
468 | 469 | ||
469 | if args.use_8bit_adam: | 470 | if args.optimizer == 'adam8bit': |
470 | try: | 471 | try: |
471 | import bitsandbytes as bnb | 472 | import bitsandbytes as bnb |
472 | except ImportError: | 473 | except ImportError: |
473 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 474 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") |
474 | 475 | ||
475 | optimizer_class = bnb.optim.AdamW8bit | 476 | create_optimizer = partial( |
477 | bnb.optim.AdamW8bit, | ||
478 | betas=(args.adam_beta1, args.adam_beta2), | ||
479 | weight_decay=args.adam_weight_decay, | ||
480 | eps=args.adam_epsilon, | ||
481 | amsgrad=args.adam_amsgrad, | ||
482 | ) | ||
483 | elif args.optimizer == 'adam': | ||
484 | create_optimizer = partial( | ||
485 | torch.optim.AdamW, | ||
486 | betas=(args.adam_beta1, args.adam_beta2), | ||
487 | weight_decay=args.adam_weight_decay, | ||
488 | eps=args.adam_epsilon, | ||
489 | amsgrad=args.adam_amsgrad, | ||
490 | ) | ||
476 | else: | 491 | else: |
477 | optimizer_class = torch.optim.AdamW | 492 | try: |
493 | from lion_pytorch import Lion | ||
494 | except ImportError: | ||
495 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
496 | |||
497 | create_optimizer = partial(Lion, use_triton=True) | ||
478 | 498 | ||
479 | trainer = partial( | 499 | trainer = partial( |
480 | train, | 500 | train, |
@@ -516,13 +536,9 @@ def main(): | |||
516 | ) | 536 | ) |
517 | datamodule.setup() | 537 | datamodule.setup() |
518 | 538 | ||
519 | optimizer = optimizer_class( | 539 | optimizer = create_optimizer( |
520 | lora_layers.parameters(), | 540 | lora_layers.parameters(), |
521 | lr=args.learning_rate, | 541 | lr=args.learning_rate, |
522 | betas=(args.adam_beta1, args.adam_beta2), | ||
523 | weight_decay=args.adam_weight_decay, | ||
524 | eps=args.adam_epsilon, | ||
525 | amsgrad=args.adam_amsgrad, | ||
526 | ) | 542 | ) |
527 | 543 | ||
528 | lr_scheduler = get_scheduler( | 544 | lr_scheduler = get_scheduler( |