diff options
author | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-17 21:06:11 +0100 |
commit | f894dfecfaa3ec17903b2ac37ac4f071408613db (patch) | |
tree | 02bf8439315c832528651186285f8b1fbd649f32 /train_ti.py | |
parent | Inference script: Better scheduler config (diff) | |
download | textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.gz textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.tar.bz2 textual-inversion-diff-f894dfecfaa3ec17903b2ac37ac4f071408613db.zip |
Added Lion optimizer
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 38 |
1 file changed, 27 insertions, 11 deletions
diff --git a/train_ti.py b/train_ti.py index 3aa1027..507d710 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -290,9 +290,10 @@ def parse_args(): | |||
290 | default=0.9999 | 290 | default=0.9999 |
291 | ) | 291 | ) |
292 | parser.add_argument( | 292 | parser.add_argument( |
293 | "--use_8bit_adam", | 293 | "--optimizer", |
294 | action="store_true", | 294 | type=str, |
295 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 295 | default="lion", |
296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | ||
296 | ) | 297 | ) |
297 | parser.add_argument( | 298 | parser.add_argument( |
298 | "--adam_beta1", | 299 | "--adam_beta1", |
@@ -564,15 +565,34 @@ def main(): | |||
564 | args.learning_rate = 1e-5 | 565 | args.learning_rate = 1e-5 |
565 | args.lr_scheduler = "exponential_growth" | 566 | args.lr_scheduler = "exponential_growth" |
566 | 567 | ||
567 | if args.use_8bit_adam: | 568 | if args.optimizer == 'adam8bit': |
568 | try: | 569 | try: |
569 | import bitsandbytes as bnb | 570 | import bitsandbytes as bnb |
570 | except ImportError: | 571 | except ImportError: |
571 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 572 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") |
572 | 573 | ||
573 | optimizer_class = bnb.optim.AdamW8bit | 574 | create_optimizer = partial( |
575 | bnb.optim.AdamW8bit, | ||
576 | betas=(args.adam_beta1, args.adam_beta2), | ||
577 | weight_decay=args.adam_weight_decay, | ||
578 | eps=args.adam_epsilon, | ||
579 | amsgrad=args.adam_amsgrad, | ||
580 | ) | ||
581 | elif args.optimizer == 'adam': | ||
582 | create_optimizer = partial( | ||
583 | torch.optim.AdamW, | ||
584 | betas=(args.adam_beta1, args.adam_beta2), | ||
585 | weight_decay=args.adam_weight_decay, | ||
586 | eps=args.adam_epsilon, | ||
587 | amsgrad=args.adam_amsgrad, | ||
588 | ) | ||
574 | else: | 589 | else: |
575 | optimizer_class = torch.optim.AdamW | 590 | try: |
591 | from lion_pytorch import Lion | ||
592 | except ImportError: | ||
593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
594 | |||
595 | create_optimizer = partial(Lion, use_triton=True) | ||
576 | 596 | ||
577 | checkpoint_output_dir = output_dir/"checkpoints" | 597 | checkpoint_output_dir = output_dir/"checkpoints" |
578 | 598 | ||
@@ -658,13 +678,9 @@ def main(): | |||
658 | ) | 678 | ) |
659 | datamodule.setup() | 679 | datamodule.setup() |
660 | 680 | ||
661 | optimizer = optimizer_class( | 681 | optimizer = create_optimizer( |
662 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 682 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
663 | lr=args.learning_rate, | 683 | lr=args.learning_rate, |
664 | betas=(args.adam_beta1, args.adam_beta2), | ||
665 | weight_decay=args.adam_weight_decay, | ||
666 | eps=args.adam_epsilon, | ||
667 | amsgrad=args.adam_amsgrad, | ||
668 | ) | 684 | ) |
669 | 685 | ||
670 | lr_scheduler = get_scheduler( | 686 | lr_scheduler = get_scheduler( |