diff options
| -rw-r--r-- | train_lora.py | 8 | ||||
| -rw-r--r-- | train_ti.py | 8 | 
2 files changed, 10 insertions, 6 deletions
| diff --git a/train_lora.py b/train_lora.py index 1d1485d..c197206 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -1057,17 +1057,19 @@ def main(): | |||
| 1057 | lr_scheduler = "one_cycle" | 1057 | lr_scheduler = "one_cycle" | 
| 1058 | lr_warmup_epochs = args.lr_warmup_epochs | 1058 | lr_warmup_epochs = args.lr_warmup_epochs | 
| 1059 | lr_cycles = args.lr_cycles | 1059 | lr_cycles = args.lr_cycles | 
| 1060 | if response.lower().strip() == "w": | 1060 | elif response.lower().strip() == "w": | 
| 1061 | lr_scheduler = "constant_with_warmup" | 1061 | lr_scheduler = "constant_with_warmup" | 
| 1062 | lr_warmup_epochs = num_train_epochs | 1062 | lr_warmup_epochs = num_train_epochs | 
| 1063 | if response.lower().strip() == "c": | 1063 | elif response.lower().strip() == "c": | 
| 1064 | lr_scheduler = "constant" | 1064 | lr_scheduler = "constant" | 
| 1065 | if response.lower().strip() == "d": | 1065 | elif response.lower().strip() == "d": | 
| 1066 | lr_scheduler = "cosine" | 1066 | lr_scheduler = "cosine" | 
| 1067 | lr_warmup_epochs = 0 | 1067 | lr_warmup_epochs = 0 | 
| 1068 | lr_cycles = 1 | 1068 | lr_cycles = 1 | 
| 1069 | elif response.lower().strip() == "s": | 1069 | elif response.lower().strip() == "s": | 
| 1070 | break | 1070 | break | 
| 1071 | else: | ||
| 1072 | continue | ||
| 1071 | 1073 | ||
| 1072 | print("") | 1074 | print("") | 
| 1073 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1075 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 
| diff --git a/train_ti.py b/train_ti.py index 84ca296..d1e5467 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -937,17 +937,19 @@ def main(): | |||
| 937 | lr_scheduler = "one_cycle" | 937 | lr_scheduler = "one_cycle" | 
| 938 | lr_warmup_epochs = args.lr_warmup_epochs | 938 | lr_warmup_epochs = args.lr_warmup_epochs | 
| 939 | lr_cycles = args.lr_cycles | 939 | lr_cycles = args.lr_cycles | 
| 940 | if response.lower().strip() == "w": | 940 | elif response.lower().strip() == "w": | 
| 941 | lr_scheduler = "constant_with_warmup" | 941 | lr_scheduler = "constant_with_warmup" | 
| 942 | lr_warmup_epochs = num_train_epochs | 942 | lr_warmup_epochs = num_train_epochs | 
| 943 | if response.lower().strip() == "c": | 943 | elif response.lower().strip() == "c": | 
| 944 | lr_scheduler = "constant" | 944 | lr_scheduler = "constant" | 
| 945 | if response.lower().strip() == "d": | 945 | elif response.lower().strip() == "d": | 
| 946 | lr_scheduler = "cosine" | 946 | lr_scheduler = "cosine" | 
| 947 | lr_warmup_epochs = 0 | 947 | lr_warmup_epochs = 0 | 
| 948 | lr_cycles = 1 | 948 | lr_cycles = 1 | 
| 949 | elif response.lower().strip() == "s": | 949 | elif response.lower().strip() == "s": | 
| 950 | break | 950 | break | 
| 951 | else: | ||
| 952 | continue | ||
| 951 | 953 | ||
| 952 | print("") | 954 | print("") | 
| 953 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 955 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 
