diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index e039df0..431ff3d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -288,7 +288,7 @@ def parse_args(): | |||
288 | "--optimizer", | 288 | "--optimizer", |
289 | type=str, | 289 | type=str, |
290 | default="adam", | 290 | default="adam", |
291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 291 | help='Optimizer to use ["adam", "adam8bit"]' |
292 | ) | 292 | ) |
293 | parser.add_argument( | 293 | parser.add_argument( |
294 | "--adam_beta1", | 294 | "--adam_beta1", |
@@ -459,7 +459,7 @@ def main(): | |||
459 | save_args(output_dir, args) | 459 | save_args(output_dir, args) |
460 | 460 | ||
461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
462 | args.pretrained_model_name_or_path, noise_scheduler="deis") | 462 | args.pretrained_model_name_or_path) |
463 | 463 | ||
464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
465 | tokenizer.set_dropout(args.vector_dropout) | 465 | tokenizer.set_dropout(args.vector_dropout) |
@@ -513,13 +513,6 @@ def main(): | |||
513 | eps=args.adam_epsilon, | 513 | eps=args.adam_epsilon, |
514 | amsgrad=args.adam_amsgrad, | 514 | amsgrad=args.adam_amsgrad, |
515 | ) | 515 | ) |
516 | elif args.optimizer == 'lion': | ||
517 | try: | ||
518 | from lion_pytorch import Lion | ||
519 | except ImportError: | ||
520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
521 | |||
522 | create_optimizer = partial(Lion, use_triton=True) | ||
523 | else: | 516 | else: |
524 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 517 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
525 | 518 | ||