summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index dd2bf6e..b706d07 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -512,6 +512,34 @@ def main():
512 eps=args.adam_epsilon, 512 eps=args.adam_epsilon,
513 amsgrad=args.adam_amsgrad, 513 amsgrad=args.adam_amsgrad,
514 ) 514 )
515 elif args.optimizer == 'dadam':
516 try:
517 import dadaptation
518 except ImportError:
519 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
520
521 create_optimizer = partial(
522 dadaptation.DAdaptAdam,
523 betas=(args.adam_beta1, args.adam_beta2),
524 weight_decay=args.adam_weight_decay,
525 eps=args.adam_epsilon,
526 decouple=True,
527 )
528
529 args.learning_rate = 1.0
530 elif args.optimizer == 'dadan':
531 try:
532 import dadaptation
533 except ImportError:
534 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
535
536 create_optimizer = partial(
537 dadaptation.DAdaptAdan,
538 weight_decay=args.adam_weight_decay,
539 eps=args.adam_epsilon,
540 )
541
542 args.learning_rate = 1.0
515 else: 543 else:
516 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 544 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
517 545