diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-05 16:02:04 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-05 16:02:04 +0200 |
| commit | f4f996681ca340e940315ca0ebc162c655904a7d (patch) | |
| tree | 86379774ae04c4a89f831255e436daac3c067cd1 /train_lora.py | |
| parent | Fix choice args (diff) | |
| download | textual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.tar.gz textual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.tar.bz2 textual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.zip | |
Add color jitter
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 538a7f7..73b3e19 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -317,7 +317,7 @@ def parse_args(): | |||
| 317 | "--optimizer", | 317 | "--optimizer", |
| 318 | type=str, | 318 | type=str, |
| 319 | default="dadan", | 319 | default="dadan", |
| 320 | choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], | 320 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 321 | help='Optimizer to use' | 321 | help='Optimizer to use' |
| 322 | ) | 322 | ) |
| 323 | parser.add_argument( | 323 | parser.add_argument( |
| @@ -544,8 +544,6 @@ def main(): | |||
| 544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 545 | raise ValueError("--embeddings_dir must point to an existing directory") | 545 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 546 | 546 | ||
| 547 | embeddings.persist() | ||
| 548 | |||
| 549 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 547 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 550 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 548 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 551 | 549 | ||
| @@ -580,6 +578,17 @@ def main(): | |||
| 580 | eps=args.adam_epsilon, | 578 | eps=args.adam_epsilon, |
| 581 | amsgrad=args.adam_amsgrad, | 579 | amsgrad=args.adam_amsgrad, |
| 582 | ) | 580 | ) |
| 581 | elif args.optimizer == 'adan': | ||
| 582 | try: | ||
| 583 | import timm.optim | ||
| 584 | except ImportError: | ||
| 585 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | ||
| 586 | |||
| 587 | create_optimizer = partial( | ||
| 588 | timm.optim.Adan, | ||
| 589 | weight_decay=args.adam_weight_decay, | ||
| 590 | eps=args.adam_epsilon, | ||
| 591 | ) | ||
| 583 | elif args.optimizer == 'lion': | 592 | elif args.optimizer == 'lion': |
| 584 | try: | 593 | try: |
| 585 | import lion_pytorch | 594 | import lion_pytorch |
