diff options
author | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
commit | 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch) | |
tree | e08741c9df3b30a05ade472da45d7410bbf972ae /train_ti.py | |
parent | Added Lion optimizer (diff) | |
download | textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2 textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index 507d710..12e3644 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -292,7 +292,7 @@ def parse_args(): | |||
292 | parser.add_argument( | 292 | parser.add_argument( |
293 | "--optimizer", | 293 | "--optimizer", |
294 | type=str, | 294 | type=str, |
295 | default="lion", | 295 | default="adam", |
296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
297 | ) | 297 | ) |
298 | parser.add_argument( | 298 | parser.add_argument( |
@@ -586,13 +586,15 @@ def main(): | |||
586 | eps=args.adam_epsilon, | 586 | eps=args.adam_epsilon, |
587 | amsgrad=args.adam_amsgrad, | 587 | amsgrad=args.adam_amsgrad, |
588 | ) | 588 | ) |
589 | else: | 589 | elif args.optimizer == 'lion': |
590 | try: | 590 | try: |
591 | from lion_pytorch import Lion | 591 | from lion_pytorch import Lion |
592 | except ImportError: | 592 | except ImportError: |
593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
594 | 594 | ||
595 | create_optimizer = partial(Lion, use_triton=True) | 595 | create_optimizer = partial(Lion, use_triton=True) |
596 | else: | ||
597 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
596 | 598 | ||
597 | checkpoint_output_dir = output_dir/"checkpoints" | 599 | checkpoint_output_dir = output_dir/"checkpoints" |
598 | 600 | ||