From 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 25 Jun 2023 08:40:05 +0200
Subject: Update

---
 train_ti.py | 42 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 6 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
@@ -959,7 +985,11 @@ def main():
         avg_acc_val = AverageMeter()
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_embedding.parameters(),
+            (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
+            ),
             lr=args.learning_rate,
         )
 
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
-- 
cgit v1.2.3-54-g00ecf
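
For context, a minimal standalone sketch of how the Prodigy optimizer wired in above is typically driven, assuming the published `prodigyopt` package. The toy embedding and the concrete hyperparameter values are illustrative; the keyword arguments mirror the ones the patch forwards from `args`. Prodigy estimates its own step size via D-adaptation, which is why the patch pins `args.learning_rate` to 1.0 and lets the scheduler only scale that constant:

import torch
from prodigyopt import Prodigy  # pip install prodigyopt

# Toy stand-in for text_encoder.text_model.embeddings.token_embedding.
embedding = torch.nn.Embedding(16, 8)

optimizer = Prodigy(
    (p for p in embedding.parameters() if p.requires_grad),
    lr=1.0,              # Prodigy adapts the step size itself; 1.0 is the recommended base LR
    betas=(0.9, 0.999),  # the patch passes args.adam_beta1 / args.adam_beta2 here
    weight_decay=1e-2,   # args.adam_weight_decay (illustrative value)
    eps=1e-8,            # args.adam_epsilon (illustrative value)
    d0=1e-6,             # args.dadaptation_d0, the initial step-size estimate
)

# One illustrative optimization step.
loss = embedding(torch.tensor([1, 2, 3])).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

Note also that the patch folds the per-cycle decay into the point where the rate is computed (`args.learning_rate * args.cycle_decay**training_iter`) instead of mutating `learning_rate` after each cycle, so the decay stays correct when the rate is re-derived from `args.learning_rate` at the start of every cycle.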