diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index dd2bf6e..b706d07 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -512,6 +512,34 @@ def main(): | |||
| 512 | eps=args.adam_epsilon, | 512 | eps=args.adam_epsilon, |
| 513 | amsgrad=args.adam_amsgrad, | 513 | amsgrad=args.adam_amsgrad, |
| 514 | ) | 514 | ) |
| 515 | elif args.optimizer == 'dadam': | ||
| 516 | try: | ||
| 517 | import dadaptation | ||
| 518 | except ImportError: | ||
| 519 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
| 520 | |||
| 521 | create_optimizer = partial( | ||
| 522 | dadaptation.DAdaptAdam, | ||
| 523 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 524 | weight_decay=args.adam_weight_decay, | ||
| 525 | eps=args.adam_epsilon, | ||
| 526 | decouple=True, | ||
| 527 | ) | ||
| 528 | |||
| 529 | args.learning_rate = 1.0 | ||
| 530 | elif args.optimizer == 'dadan': | ||
| 531 | try: | ||
| 532 | import dadaptation | ||
| 533 | except ImportError: | ||
| 534 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
| 535 | |||
| 536 | create_optimizer = partial( | ||
| 537 | dadaptation.DAdaptAdan, | ||
| 538 | weight_decay=args.adam_weight_decay, | ||
| 539 | eps=args.adam_epsilon, | ||
| 540 | ) | ||
| 541 | |||
| 542 | args.learning_rate = 1.0 | ||
| 515 | else: | 543 | else: |
| 516 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 544 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 517 | 545 | ||
