diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-21 14:18:08 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-21 14:18:08 +0100 |
| commit | 744b87831f5e854d86c9f39c131386c3b26e9304 (patch) | |
| tree | 66226b7a8dfe5403b2dacf2c7397833d981ab3c1 /train_lora.py | |
| parent | Fixed SNR weighting, re-enabled xformers (diff) | |
| download | textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.tar.gz textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.tar.bz2 textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.zip | |
Added dadaptation
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index 2a798f3..ce8fb50 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -476,6 +476,34 @@ def main(): | |||
| 476 | eps=args.adam_epsilon, | 476 | eps=args.adam_epsilon, |
| 477 | amsgrad=args.adam_amsgrad, | 477 | amsgrad=args.adam_amsgrad, |
| 478 | ) | 478 | ) |
| 479 | elif args.optimizer == 'dadam': | ||
| 480 | try: | ||
| 481 | import dadaptation | ||
| 482 | except ImportError: | ||
| 483 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
| 484 | |||
| 485 | create_optimizer = partial( | ||
| 486 | dadaptation.DAdaptAdam, | ||
| 487 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 488 | weight_decay=args.adam_weight_decay, | ||
| 489 | eps=args.adam_epsilon, | ||
| 490 | decouple=True, | ||
| 491 | ) | ||
| 492 | |||
| 493 | args.learning_rate = 1.0 | ||
| 494 | elif args.optimizer == 'dadan': | ||
| 495 | try: | ||
| 496 | import dadaptation | ||
| 497 | except ImportError: | ||
| 498 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
| 499 | |||
| 500 | create_optimizer = partial( | ||
| 501 | dadaptation.DAdaptAdan, | ||
| 502 | weight_decay=args.adam_weight_decay, | ||
| 503 | eps=args.adam_epsilon, | ||
| 504 | ) | ||
| 505 | |||
| 506 | args.learning_rate = 1.0 | ||
| 479 | else: | 507 | else: |
| 480 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 508 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 481 | 509 | ||
