diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-18 13:00:13 +0100 |
| commit | 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 (patch) | |
| tree | e08741c9df3b30a05ade472da45d7410bbf972ae /train_dreambooth.py | |
| parent | Added Lion optimizer (diff) | |
| download | textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.gz textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.tar.bz2 textual-inversion-diff-2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8f0c6ea..e039df0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -287,7 +287,7 @@ def parse_args(): | |||
| 287 | parser.add_argument( | 287 | parser.add_argument( |
| 288 | "--optimizer", | 288 | "--optimizer", |
| 289 | type=str, | 289 | type=str, |
| 290 | default="lion", | 290 | default="adam", |
| 291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' |
| 292 | ) | 292 | ) |
| 293 | parser.add_argument( | 293 | parser.add_argument( |
| @@ -459,7 +459,7 @@ def main(): | |||
| 459 | save_args(output_dir, args) | 459 | save_args(output_dir, args) |
| 460 | 460 | ||
| 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 462 | args.pretrained_model_name_or_path) | 462 | args.pretrained_model_name_or_path, noise_scheduler="deis") |
| 463 | 463 | ||
| 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 465 | tokenizer.set_dropout(args.vector_dropout) | 465 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -513,13 +513,15 @@ def main(): | |||
| 513 | eps=args.adam_epsilon, | 513 | eps=args.adam_epsilon, |
| 514 | amsgrad=args.adam_amsgrad, | 514 | amsgrad=args.adam_amsgrad, |
| 515 | ) | 515 | ) |
| 516 | else: | 516 | elif args.optimizer == 'lion': |
| 517 | try: | 517 | try: |
| 518 | from lion_pytorch import Lion | 518 | from lion_pytorch import Lion |
| 519 | except ImportError: | 519 | except ImportError: |
| 520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | 520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") |
| 521 | 521 | ||
| 522 | create_optimizer = partial(Lion, use_triton=True) | 522 | create_optimizer = partial(Lion, use_triton=True) |
| 523 | else: | ||
| 524 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | ||
| 523 | 525 | ||
| 524 | trainer = partial( | 526 | trainer = partial( |
| 525 | train, | 527 | train, |
