diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4c36ae4..48921d4 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args(): | |||
306 | "--optimizer", | 306 | "--optimizer", |
307 | type=str, | 307 | type=str, |
308 | default="dadan", | 308 | default="dadan", |
309 | choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], | 309 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
310 | help='Optimizer to use' | 310 | help='Optimizer to use' |
311 | ) | 311 | ) |
312 | parser.add_argument( | 312 | parser.add_argument( |
@@ -513,8 +513,6 @@ def main(): | |||
513 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 513 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
514 | raise ValueError("--embeddings_dir must point to an existing directory") | 514 | raise ValueError("--embeddings_dir must point to an existing directory") |
515 | 515 | ||
516 | embeddings.persist() | ||
517 | |||
518 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 516 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
519 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 517 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
520 | 518 | ||
@@ -549,6 +547,17 @@ def main(): | |||
549 | eps=args.adam_epsilon, | 547 | eps=args.adam_epsilon, |
550 | amsgrad=args.adam_amsgrad, | 548 | amsgrad=args.adam_amsgrad, |
551 | ) | 549 | ) |
550 | elif args.optimizer == 'adan': | ||
551 | try: | ||
552 | import timm.optim | ||
553 | except ImportError: | ||
554 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | ||
555 | |||
556 | create_optimizer = partial( | ||
557 | timm.optim.Adan, | ||
558 | weight_decay=args.adam_weight_decay, | ||
559 | eps=args.adam_epsilon, | ||
560 | ) | ||
552 | elif args.optimizer == 'lion': | 561 | elif args.optimizer == 'lion': |
553 | try: | 562 | try: |
554 | import lion_pytorch | 563 | import lion_pytorch |