diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
| commit | 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch) | |
| tree | e08741c9df3b30a05ade472da45d7410bbf972ae /train_ti.py | |
| parent | Added Lion optimizer (diff) | |
| download | textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2 textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index 507d710..12e3644 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -292,7 +292,7 @@ def parse_args(): | |||
| 292 | parser.add_argument( | 292 | parser.add_argument( |
| 293 | "--optimizer", | 293 | "--optimizer", |
| 294 | type=str, | 294 | type=str, |
| 295 | default="lion", | 295 | default="adam", |
| 296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
| 297 | ) | 297 | ) |
| 298 | parser.add_argument( | 298 | parser.add_argument( |
| @@ -586,13 +586,15 @@ def main(): | |||
| 586 | eps=args.adam_epsilon, | 586 | eps=args.adam_epsilon, |
| 587 | amsgrad=args.adam_amsgrad, | 587 | amsgrad=args.adam_amsgrad, |
| 588 | ) | 588 | ) |
| 589 | else: | 589 | elif args.optimizer == 'lion': |
| 590 | try: | 590 | try: |
| 591 | from lion_pytorch import Lion | 591 | from lion_pytorch import Lion |
| 592 | except ImportError: | 592 | except ImportError: |
| 593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
| 594 | 594 | ||
| 595 | create_optimizer = partial(Lion, use_triton=True) | 595 | create_optimizer = partial(Lion, use_triton=True) |
| 596 | else: | ||
| 597 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
| 596 | 598 | ||
| 597 | checkpoint_output_dir = output_dir/"checkpoints" | 599 | checkpoint_output_dir = output_dir/"checkpoints" |
| 598 | 600 | ||
