Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  42
1 file changed, 36 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
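For reference, a minimal standalone sketch of the same Prodigy setup pattern. This is illustrative only: it assumes `prodigyopt` is installed, and the toy nn.Embedding plus the literal hyperparameter values stand in for the script's text-encoder token embedding and its --adam_* / --dadaptation_d0 arguments.

# Sketch only: mirrors the Prodigy construction above with made-up values.
# The nn.Embedding stands in for text_encoder.text_model.embeddings.token_embedding.
import torch.nn as nn
from prodigyopt import Prodigy

embedding = nn.Embedding(num_embeddings=49408, embedding_dim=768)

optimizer = Prodigy(
    (p for p in embedding.parameters() if p.requires_grad),
    lr=1.0,              # Prodigy estimates the effective step size itself, hence lr = 1.0
    betas=(0.9, 0.999),  # stands in for (args.adam_beta1, args.adam_beta2)
    weight_decay=1e-2,   # stands in for args.adam_weight_decay
    eps=1e-8,            # stands in for args.adam_epsilon
    d0=1e-6,             # stands in for args.dadaptation_d0 (initial step-size estimate)
)

Because Prodigy adapts its own step size, the change above also pins args.learning_rate to 1.0 after building the optimizer factory.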
@@ -959,7 +985,11 @@ def main():
     avg_acc_val = AverageMeter()
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        (
+            param
+            for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+            if param.requires_grad
+        ),
         lr=args.learning_rate,
     )
 
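The generator above hands only trainable parameters to the optimizer. A tiny illustrative sketch of the same filter; the frozen embedding is a made-up example:

# Illustration only: frozen parameters are skipped by the requires_grad filter.
import torch.nn as nn

embedding = nn.Embedding(10, 4)
embedding.weight.requires_grad_(False)  # pretend the embedding is frozen

trainable = [p for p in embedding.parameters() if p.requires_grad]
print(len(trainable))  # prints 0 -- nothing reaches the optimizer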
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
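With this change the restart learning rate is scaled by cycle_decay**training_iter at the point where it is computed, instead of being decayed in place at the end of each cycle (that in-place decay is removed in the final hunk below). A small worked example with made-up values for --learning_rate and --cycle_decay:

# Illustration only: how the one-cycle restart rate shrinks across cycles.
base_lr = 1e-4      # stands in for args.learning_rate
cycle_decay = 0.9   # stands in for args.cycle_decay

for training_iter in range(4):
    lr = base_lr * 2 * (cycle_decay ** training_iter)
    print(f"cycle {training_iter}: lr = {lr:.2e}")
# cycle 0: lr = 2.00e-04
# cycle 1: lr = 1.80e-04
# cycle 2: lr = 1.62e-04
# cycle 3: lr = 1.46e-04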
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 