-rw-r--r--  train_lora.py | 13
-rw-r--r--  train_ti.py   |  9
2 files changed, 20 insertions(+), 2 deletions(-)
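This change adds a --cycle_constant flag to both training scripts: when set, every training cycle after the first switches to a constant learning-rate schedule with no warmup. To make that override possible, the scheduler name and warmup_epochs are no longer baked into the create_lr_scheduler partial and are instead passed per cycle. Below is a minimal, self-contained sketch of the resulting control flow; the get_scheduler stub and the argparse defaults are illustrative assumptions, not the project's real code.

# Sketch only: get_scheduler is a stand-in stub and the Namespace defaults
# are assumptions; the real scheduler factory lives in the training utilities.
from argparse import Namespace
from functools import partial

def get_scheduler(name, *, warmup_epochs, **kwargs):
    # Stand-in for the project's scheduler factory.
    print(f"scheduler={name!r}, warmup_epochs={warmup_epochs}")

args = Namespace(lr_scheduler="one_cycle", lr_warmup_epochs=10,
                 lr_cycles=3, cycle_constant=True, lr_min_lr=1e-6)

# Cycle-independent settings stay bound once, outside the cycle loop ...
create_lr_scheduler = partial(get_scheduler, min_lr=args.lr_min_lr)

for training_iter in range(args.lr_cycles):
    # ... but with --cycle_constant, cycles after the first fall back to a
    # constant LR with no warmup.
    if args.cycle_constant and training_iter == 1:
        args.lr_scheduler = "constant"
        args.lr_warmup_epochs = 0

    # The scheduler name and warmup_epochs are now per-cycle arguments,
    # so the override above takes effect.
    create_lr_scheduler(args.lr_scheduler, warmup_epochs=args.lr_warmup_epochs)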
diff --git a/train_lora.py b/train_lora.py
index 5c78664..4d4c16a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -95,6 +95,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -910,7 +915,6 @@ def main():
 
     create_lr_scheduler = partial(
         get_scheduler,
-        args.lr_scheduler,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -918,7 +922,6 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        warmup_epochs=args.lr_warmup_epochs,
         mid_point=args.lr_mid_point,
     )
 
@@ -971,6 +974,10 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1005,10 +1012,12 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
+            args.lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
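Design note for train_lora.py: create_lr_scheduler is still a partial built once before the cycle loop, so only the arguments removed from it (the positional scheduler name and warmup_epochs) can vary between cycles; moving them to the per-cycle call is what lets the --cycle_constant override on cycles > 1 actually take effect.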
diff --git a/train_ti.py b/train_ti.py
index 45e730a..c452269 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -79,6 +79,11 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
+        "--cycle_constant",
+        action="store_true",
+        help="Use constant LR on cycles > 1."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -926,6 +931,10 @@ def main():
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
+        if args.cycle_constant and training_iter == 1:
+            args.lr_scheduler = "constant"
+            args.lr_warmup_epochs = 0
+
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
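For reference, a hypothetical invocation that enables the new behavior; --cycle_constant and --placeholder_tokens come from this patch, while the token value and any other options are illustrative and may differ:

python train_ti.py --cycle_constant --placeholder_tokens "<concept>"
python train_lora.py --cycle_constant --placeholder_tokens "<concept>"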