summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py11
1 files changed, 2 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e039df0..431ff3d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -288,7 +288,7 @@ def parse_args():
288 "--optimizer", 288 "--optimizer",
289 type=str, 289 type=str,
290 default="adam", 290 default="adam",
291 help='Optimizer to use ["adam", "adam8bit", "lion"]' 291 help='Optimizer to use ["adam", "adam8bit"]'
292 ) 292 )
293 parser.add_argument( 293 parser.add_argument(
294 "--adam_beta1", 294 "--adam_beta1",
@@ -459,7 +459,7 @@ def main():
459 save_args(output_dir, args) 459 save_args(output_dir, args)
460 460
461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 461 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
462 args.pretrained_model_name_or_path, noise_scheduler="deis") 462 args.pretrained_model_name_or_path)
463 463
464 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 464 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
465 tokenizer.set_dropout(args.vector_dropout) 465 tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,6 @@ def main():
513 eps=args.adam_epsilon, 513 eps=args.adam_epsilon,
514 amsgrad=args.adam_amsgrad, 514 amsgrad=args.adam_amsgrad,
515 ) 515 )
516 elif args.optimizer == 'lion':
517 try:
518 from lion_pytorch import Lion
519 except ImportError:
520 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
521
522 create_optimizer = partial(Lion, use_triton=True)
523 else: 516 else:
524 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 517 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
525 518