summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-05 16:02:04 +0200
committerVolpeon <git@volpeon.ink>2023-04-05 16:02:04 +0200
commitf4f996681ca340e940315ca0ebc162c655904a7d (patch)
tree86379774ae04c4a89f831255e436daac3c067cd1 /train_dreambooth.py
parentFix choice args (diff)
downloadtextual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.tar.gz
textual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.tar.bz2
textual-inversion-diff-f4f996681ca340e940315ca0ebc162c655904a7d.zip
Add color jitter
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4c36ae4..48921d4 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
306 "--optimizer", 306 "--optimizer",
307 type=str, 307 type=str,
308 default="dadan", 308 default="dadan",
309 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], 309 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
310 help='Optimizer to use' 310 help='Optimizer to use'
311 ) 311 )
312 parser.add_argument( 312 parser.add_argument(
@@ -513,8 +513,6 @@ def main():
513 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 513 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
514 raise ValueError("--embeddings_dir must point to an existing directory") 514 raise ValueError("--embeddings_dir must point to an existing directory")
515 515
516 embeddings.persist()
517
518 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 516 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
519 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 517 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
520 518
@@ -549,6 +547,17 @@ def main():
549 eps=args.adam_epsilon, 547 eps=args.adam_epsilon,
550 amsgrad=args.adam_amsgrad, 548 amsgrad=args.adam_amsgrad,
551 ) 549 )
550 elif args.optimizer == 'adan':
551 try:
552 import timm.optim
553 except ImportError:
554 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
555
556 create_optimizer = partial(
557 timm.optim.Adan,
558 weight_decay=args.adam_weight_decay,
559 eps=args.adam_epsilon,
560 )
552 elif args.optimizer == 'lion': 561 elif args.optimizer == 'lion':
553 try: 562 try:
554 import lion_pytorch 563 import lion_pytorch