summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py11
1 files changed, 2 insertions, 9 deletions
diff --git a/train_lora.py b/train_lora.py
index db5330a..a06591d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -248,7 +248,7 @@ def parse_args():
248 "--optimizer", 248 "--optimizer",
249 type=str, 249 type=str,
250 default="adam", 250 default="adam",
251 help='Optimizer to use ["adam", "adam8bit", "lion"]' 251 help='Optimizer to use ["adam", "adam8bit"]'
252 ) 252 )
253 parser.add_argument( 253 parser.add_argument(
254 "--adam_beta1", 254 "--adam_beta1",
@@ -419,7 +419,7 @@ def main():
419 save_args(output_dir, args) 419 save_args(output_dir, args)
420 420
421 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 421 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
422 args.pretrained_model_name_or_path, noise_scheduler="deis") 422 args.pretrained_model_name_or_path)
423 423
424 vae.enable_slicing() 424 vae.enable_slicing()
425 vae.set_use_memory_efficient_attention_xformers(True) 425 vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,6 @@ def main():
488 eps=args.adam_epsilon, 488 eps=args.adam_epsilon,
489 amsgrad=args.adam_amsgrad, 489 amsgrad=args.adam_amsgrad,
490 ) 490 )
491 elif args.optimizer == 'lion':
492 try:
493 from lion_pytorch import Lion
494 except ImportError:
495 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
496
497 create_optimizer = partial(Lion, use_triton=True)
498 else: 491 else:
499 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 492 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
500 493