Diffstat (limited to 'train_ti.py')
 -rw-r--r--  train_ti.py | 42
 1 file changed, 36 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
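The new "prodigy" choice is validated by argparse before main() ever runs. A minimal standalone sketch of that behavior (not the full parse_args() from train_ti.py; this parser is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    default="adan",
    choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor", "prodigy"],
    help="Optimizer to use",
)

# Accepted: the value is in `choices`.
print(parser.parse_args(["--optimizer", "prodigy"]).optimizer)
# Rejected: argparse exits with "invalid choice" before main() is reached.
# parser.parse_args(["--optimizer", "sgd"])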
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')

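For context, partial() pre-binds the hyperparameters so the training loop can later call create_optimizer(params, lr=...) uniformly for every backend. A self-contained sketch of the same pattern, with made-up values standing in for the args.* fields; Prodigy, like the D-Adaptation optimizers, adapts its own step size and expects the outer learning rate to stay at 1.0, which is why the hunk overwrites args.learning_rate:

from functools import partial

import prodigyopt  # pip install prodigyopt
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the trainable embeddings

create_optimizer = partial(
    prodigyopt.Prodigy,
    betas=(0.9, 0.999),  # stand-in for (args.adam_beta1, args.adam_beta2)
    weight_decay=1e-2,   # stand-in for args.adam_weight_decay
    eps=1e-8,            # stand-in for args.adam_epsilon
    d0=1e-6,             # stand-in for args.dadaptation_d0
)

# Same call shape the trainer uses later; lr stays at 1.0 for Prodigy.
optimizer = create_optimizer(model.parameters(), lr=1.0)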
@@ -959,7 +985,11 @@ def main():
     avg_acc_val = AverageMeter()

     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        (
+            param
+            for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+            if param.requires_grad
+        ),
         lr=args.learning_rate,
     )

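The generator expression keeps frozen parameters out of the optimizer, which matters for textual inversion since only the newly added token embeddings should train. A small illustrative sketch (the modules here are hypothetical, not the ones from train_ti.py):

import torch

embedding = torch.nn.Embedding(10, 4)  # trainable, like the new token rows
frozen = torch.nn.Linear(4, 4)
frozen.requires_grad_(False)           # frozen, like the rest of the model

params = (
    p
    for p in list(embedding.parameters()) + list(frozen.parameters())
    if p.requires_grad
)
# torch optimizers accept any iterable of parameters, including a generator.
optimizer = torch.optim.AdamW(params, lr=1e-3)
print(len(optimizer.param_groups[0]["params"]))  # 1 -> only the embedding weight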
@@ -973,9 +1003,11 @@ def main():

         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)

         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
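With this change the learning rate for each continue-training cycle is recomputed from the cycle counter instead of being decayed in place. A quick sketch of the resulting schedule, using cycle_decay=0.9 and a base rate of 1e-3 as example values (the real values come from the CLI flags):

base_lr, cycle_decay = 1e-3, 0.9

for training_iter in range(3):
    # one_cycle branch: the base rate is doubled, then decayed per cycle
    learning_rate = base_lr * 2 * (cycle_decay**training_iter)
    print(training_iter, learning_rate)
# 0 -> 0.002, 1 -> 0.0018, 2 -> 0.00162 (up to float rounding)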
@@ -1045,8 +1077,6 @@ def main():
             )

             training_iter += 1
-            if learning_rate is not None:
-                learning_rate *= args.cycle_decay

     accelerator.end_training()

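This removal is the other half of the previous hunk: the in-place `learning_rate *= args.cycle_decay` after each cycle is replaced by the closed-form `cycle_decay**training_iter`, which yields the same sequence without mutating state across cycles. A quick equivalence check with illustrative numbers:

import math

lr, decay = 2e-3, 0.9
mutated = []
for i in range(3):
    mutated.append(lr)
    lr *= decay  # the removed in-loop update

closed_form = [2e-3 * decay**i for i in range(3)]  # the new formulation
assert all(math.isclose(m, c) for m, c in zip(mutated, closed_form))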