summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
committerVolpeon <git@volpeon.ink>2023-02-18 13:00:13 +0100
commit2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch)
treee08741c9df3b30a05ade472da45d7410bbf972ae /train_dreambooth.py
parentAdded Lion optimizer (diff)
downloadtextual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2
textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8f0c6ea..e039df0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,7 +287,7 @@ def parse_args():
287 parser.add_argument( 287 parser.add_argument(
288 "--optimizer", 288 "--optimizer",
289 type=str, 289 type=str,
290 default="lion", 290 default="adam",
291 help='Optimizer to use ["adam", "adam8bit", "lion"]' 291 help='Optimizer to use ["adam", "adam8bit", "lion"]'
292 ) 292 )
293 parser.add_argument( 293 parser.add_argument(
@@ -459,7 +459,7 @@ def main():
459 save_args(output_dir, args) 459 save_args(output_dir, args)
460 460
461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
462 args.pretrained_model_name_or_path) 462 args.pretrained_model_name_or_path, noise_scheduler="deis")
463 463
464 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 464 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
465 tokenizer.set_dropout(args.vector_dropout) 465 tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,15 @@ def main():
513 eps=args.adam_epsilon, 513 eps=args.adam_epsilon,
514 amsgrad=args.adam_amsgrad, 514 amsgrad=args.adam_amsgrad,
515 ) 515 )
516 else: 516 elif args.optimizer == 'lion':
517 try: 517 try:
518 from lion_pytorch import Lion 518 from lion_pytorch import Lion
519 except ImportError: 519 except ImportError:
520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") 520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
521 521
522 create_optimizer = partial(Lion, use_triton=True) 522 create_optimizer = partial(Lion, use_triton=True)
523 else:
524 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
523 525
524 trainer = partial( 526 trainer = partial(
525 train, 527 train,