From 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 18 Feb 2023 13:00:13 +0100 Subject: Update --- train_dreambooth.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 8f0c6ea..e039df0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -287,7 +287,7 @@ def parse_args(): parser.add_argument( "--optimizer", type=str, - default="lion", + default="adam", help='Optimizer to use ["adam", "adam8bit", "lion"]' ) parser.add_argument( @@ -459,7 +459,7 @@ def main(): save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path) + args.pretrained_model_name_or_path, noise_scheduler="deis") tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -513,13 +513,15 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - else: + elif args.optimizer == 'lion': try: from lion_pytorch import Lion except ImportError: raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") create_optimizer = partial(Lion, use_triton=True) + else: + raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") trainer = partial( train, -- cgit v1.2.3-54-g00ecf