Diffstat (limited to 'train_ti.py')
 train_ti.py | 42 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 6 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
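Note: Prodigy, like the D-Adaptation optimizers handled just above it, estimates its own step size, which is why this hunk pins args.learning_rate to 1.0 afterwards; the lr passed in later acts only as a multiplier on the internally estimated step. A minimal sketch of how the factory built here is invoked downstream (the hyperparameter values and the embedding_params name are illustrative placeholders, not taken from this file):

    # Illustrative sketch; concrete values and `embedding_params` are
    # placeholders, not values from train_ti.py.
    from functools import partial

    import prodigyopt

    create_optimizer = partial(
        prodigyopt.Prodigy,
        betas=(0.9, 0.999),   # args.adam_beta1, args.adam_beta2
        weight_decay=1e-2,    # args.adam_weight_decay
        eps=1e-8,             # args.adam_epsilon
        d0=1e-6,              # args.dadaptation_d0: initial step-size estimate
    )
    optimizer = create_optimizer(embedding_params, lr=1.0)  # lr is a multiplier here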
@@ -959,7 +985,11 @@ def main():
         avg_acc_val = AverageMeter()
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_embedding.parameters(),
+            (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
+            ),
             lr=args.learning_rate,
         )
 
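Note: the new generator expression hands the optimizer only parameters that have requires_grad set, so frozen embedding weights never accumulate optimizer state. A standalone sketch of the same pattern, using a generic module and torch.optim.Adam as stand-ins (neither is taken from this diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(768, 768)
    model.bias.requires_grad_(False)  # freeze part of the model

    # Pass only the trainable parameters, as the hunk above does:
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad),
        lr=1e-4,
    )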
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
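Note: the last two hunks move the cycle decay from an in-place multiply at the end of each cycle to an explicit args.cycle_decay**training_iter factor at the point where the learning rate is computed. The decayed rate is now derived from the cycle count rather than from mutated state, so it is no longer discarded when the one_cycle branch reassigns learning_rate from args.learning_rate. A worked example of the new rule (values illustrative, not from this file):

    # base lr 1e-3, decay 0.9; the one_cycle branch doubles the base
    base_lr, cycle_decay = 1e-3, 0.9
    for training_iter in range(3):
        lr = base_lr * 2 * (cycle_decay ** training_iter)
        print(training_iter, lr)  # ~0.002, ~0.0018, ~0.00162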
