diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index dd2bf6e..b706d07 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -512,6 +512,34 @@ def main(): | |||
512 | eps=args.adam_epsilon, | 512 | eps=args.adam_epsilon, |
513 | amsgrad=args.adam_amsgrad, | 513 | amsgrad=args.adam_amsgrad, |
514 | ) | 514 | ) |
515 | elif args.optimizer == 'dadam': | ||
516 | try: | ||
517 | import dadaptation | ||
518 | except ImportError: | ||
519 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | ||
520 | |||
521 | create_optimizer = partial( | ||
522 | dadaptation.DAdaptAdam, | ||
523 | betas=(args.adam_beta1, args.adam_beta2), | ||
524 | weight_decay=args.adam_weight_decay, | ||
525 | eps=args.adam_epsilon, | ||
526 | decouple=True, | ||
527 | ) | ||
528 | |||
529 | args.learning_rate = 1.0 | ||
530 | elif args.optimizer == 'dadan': | ||
531 | try: | ||
532 | import dadaptation | ||
533 | except ImportError: | ||
534 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | ||
535 | |||
536 | create_optimizer = partial( | ||
537 | dadaptation.DAdaptAdan, | ||
538 | weight_decay=args.adam_weight_decay, | ||
539 | eps=args.adam_epsilon, | ||
540 | ) | ||
541 | |||
542 | args.learning_rate = 1.0 | ||
515 | else: | 543 | else: |
516 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 544 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
517 | 545 | ||