diff options
author | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-21 09:09:50 +0100 |
commit | 16b92605a59d59c65789c89b54bb97da51908056 (patch) | |
tree | b0cbf8677897c3f44c736b710fd034eb2c5de6a0 /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.gz textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.bz2 textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.zip |
Embedding normalization: Ignore tensors with grad = 0
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/train_lora.py b/train_lora.py index db5330a..a06591d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -248,7 +248,7 @@ def parse_args(): | |||
248 | "--optimizer", | 248 | "--optimizer", |
249 | type=str, | 249 | type=str, |
250 | default="adam", | 250 | default="adam", |
251 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 251 | help='Optimizer to use ["adam", "adam8bit"]' |
252 | ) | 252 | ) |
253 | parser.add_argument( | 253 | parser.add_argument( |
254 | "--adam_beta1", | 254 | "--adam_beta1", |
@@ -419,7 +419,7 @@ def main(): | |||
419 | save_args(output_dir, args) | 419 | save_args(output_dir, args) |
420 | 420 | ||
421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
422 | args.pretrained_model_name_or_path, noise_scheduler="deis") | 422 | args.pretrained_model_name_or_path) |
423 | 423 | ||
424 | vae.enable_slicing() | 424 | vae.enable_slicing() |
425 | vae.set_use_memory_efficient_attention_xformers(True) | 425 | vae.set_use_memory_efficient_attention_xformers(True) |
@@ -488,13 +488,6 @@ def main(): | |||
488 | eps=args.adam_epsilon, | 488 | eps=args.adam_epsilon, |
489 | amsgrad=args.adam_amsgrad, | 489 | amsgrad=args.adam_amsgrad, |
490 | ) | 490 | ) |
491 | elif args.optimizer == 'lion': | ||
492 | try: | ||
493 | from lion_pytorch import Lion | ||
494 | except ImportError: | ||
495 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
496 | |||
497 | create_optimizer = partial(Lion, use_triton=True) | ||
498 | else: | 491 | else: |
499 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 492 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
500 | 493 | ||