diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 1d1485d..c197206 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -1057,17 +1057,19 @@ def main(): | |||
1057 | lr_scheduler = "one_cycle" | 1057 | lr_scheduler = "one_cycle" |
1058 | lr_warmup_epochs = args.lr_warmup_epochs | 1058 | lr_warmup_epochs = args.lr_warmup_epochs |
1059 | lr_cycles = args.lr_cycles | 1059 | lr_cycles = args.lr_cycles |
1060 | if response.lower().strip() == "w": | 1060 | elif response.lower().strip() == "w": |
1061 | lr_scheduler = "constant_with_warmup" | 1061 | lr_scheduler = "constant_with_warmup" |
1062 | lr_warmup_epochs = num_train_epochs | 1062 | lr_warmup_epochs = num_train_epochs |
1063 | if response.lower().strip() == "c": | 1063 | elif response.lower().strip() == "c": |
1064 | lr_scheduler = "constant" | 1064 | lr_scheduler = "constant" |
1065 | if response.lower().strip() == "d": | 1065 | elif response.lower().strip() == "d": |
1066 | lr_scheduler = "cosine" | 1066 | lr_scheduler = "cosine" |
1067 | lr_warmup_epochs = 0 | 1067 | lr_warmup_epochs = 0 |
1068 | lr_cycles = 1 | 1068 | lr_cycles = 1 |
1069 | elif response.lower().strip() == "s": | 1069 | elif response.lower().strip() == "s": |
1070 | break | 1070 | break |
1071 | else: | ||
1072 | continue | ||
1071 | 1073 | ||
1072 | print("") | 1074 | print("") |
1073 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1075 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |