| author | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
| commit | f894dfecfaa3ec17903b2ac37ac4f071408613db (patch) | |
| tree | 02bf8439315c832528651186285f8b1fbd649f32 /train_lora.py | |
| parent | Inference script: Better scheduler config (diff) | |
Added Lion optimizer
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 38 |

1 file changed, 27 insertions(+), 11 deletions(-)
```diff
diff --git a/train_lora.py b/train_lora.py
index 330bcd6..368c29b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,9 +245,10 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -466,15 +467,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     trainer = partial(
         train,
@@ -516,13 +536,9 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         lora_layers.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
```
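The key refactor is replacing the bare `optimizer_class` with a pre-bound `functools.partial`: the Adam-specific hyperparameters (`betas`, `weight_decay`, `eps`, `amsgrad`) are now bound at selection time, so the call site only supplies the parameters and learning rate and no longer needs to know which optimizer family was chosen. Lion takes a different argument set, which is why the old shared call site no longer worked. Below is a minimal sketch of that pattern; the helper name `make_optimizer_factory` and the default hyperparameter values are illustrative stand-ins, not code from this commit.

```python
# Sketch of the optimizer-factory pattern from the diff above. Assumptions:
# the helper name `make_optimizer_factory` and its default hyperparameter
# values are illustrative, not part of the commit.
from functools import partial

import torch


def make_optimizer_factory(optimizer_name, betas=(0.9, 0.999),
                           weight_decay=1e-2, eps=1e-8, amsgrad=False):
    """Return a callable that only needs the parameters and `lr`."""
    if optimizer_name == "adam8bit":
        import bitsandbytes as bnb  # optional dependency
        return partial(bnb.optim.AdamW8bit, betas=betas,
                       weight_decay=weight_decay, eps=eps, amsgrad=amsgrad)
    elif optimizer_name == "adam":
        return partial(torch.optim.AdamW, betas=betas,
                       weight_decay=weight_decay, eps=eps, amsgrad=amsgrad)
    else:
        # "lion": Lion has its own hyperparameters, so none of the
        # Adam-specific arguments are bound here.
        from lion_pytorch import Lion  # optional dependency
        return partial(Lion, use_triton=True)


# The call site is now optimizer-agnostic, mirroring `create_optimizer(...)`
# in the diff.
model = torch.nn.Linear(4, 4)
create_optimizer = make_optimizer_factory("adam")
optimizer = create_optimizer(model.parameters(), lr=1e-4)
```

Note that `use_triton=True` requests `lion_pytorch`'s fused Triton update kernel, which additionally requires the `triton` package; and since the commit also makes `lion` the default for `--optimizer`, a plain run of `train_lora.py` now needs `lion_pytorch` installed.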
