From 744b87831f5e854d86c9f39c131386c3b26e9304 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 21 Mar 2023 14:18:08 +0100 Subject: Added dadaptation --- train_ti.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 2e92ae4..ee65b44 100644 --- a/train_ti.py +++ b/train_ti.py @@ -607,6 +607,34 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'dadam': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdam, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=True, + ) + + args.learning_rate = 1.0 + elif args.optimizer == 'dadan': + try: + import dadaptation + except ImportError: + raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + + create_optimizer = partial( + dadaptation.DAdaptAdan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + args.learning_rate = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") -- cgit v1.2.3-54-g00ecf