diff options
author | Volpeon <git@volpeon.ink> | 2023-03-21 14:18:08 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-21 14:18:08 +0100 |
commit | 744b87831f5e854d86c9f39c131386c3b26e9304 (patch) | |
tree | 66226b7a8dfe5403b2dacf2c7397833d981ab3c1 /train_ti.py | |
parent | Fixed SNR weighting, re-enabled xformers (diff) | |
download | textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.tar.gz textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.tar.bz2 textual-inversion-diff-744b87831f5e854d86c9f39c131386c3b26e9304.zip |
Added dadaptation
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index 2e92ae4..ee65b44 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -607,6 +607,34 @@ def main(): | |||
607 | eps=args.adam_epsilon, | 607 | eps=args.adam_epsilon, |
608 | amsgrad=args.adam_amsgrad, | 608 | amsgrad=args.adam_amsgrad, |
609 | ) | 609 | ) |
610 | elif args.optimizer == 'dadam': | ||
611 | try: | ||
612 | import dadaptation | ||
613 | except ImportError: | ||
614 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
615 | |||
616 | create_optimizer = partial( | ||
617 | dadaptation.DAdaptAdam, | ||
618 | betas=(args.adam_beta1, args.adam_beta2), | ||
619 | weight_decay=args.adam_weight_decay, | ||
620 | eps=args.adam_epsilon, | ||
621 | decouple=True, | ||
622 | ) | ||
623 | |||
624 | args.learning_rate = 1.0 | ||
625 | elif args.optimizer == 'dadan': | ||
626 | try: | ||
627 | import dadaptation | ||
628 | except ImportError: | ||
629 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
630 | |||
631 | create_optimizer = partial( | ||
632 | dadaptation.DAdaptAdan, | ||
633 | weight_decay=args.adam_weight_decay, | ||
634 | eps=args.adam_epsilon, | ||
635 | ) | ||
636 | |||
637 | args.learning_rate = 1.0 | ||
610 | else: | 638 | else: |
611 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 639 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
612 | 640 | ||