diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index ba5aee1..d0313fe 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -967,18 +967,17 @@ def main(): | |||
967 | if len(auto_cycles) != 0: | 967 | if len(auto_cycles) != 0: |
968 | response = auto_cycles.pop(0) | 968 | response = auto_cycles.pop(0) |
969 | else: | 969 | else: |
970 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 970 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
971 | 971 | ||
972 | if response.lower().strip() == "o": | 972 | if response.lower().strip() == "o": |
973 | lr_scheduler = "one_cycle" | 973 | lr_scheduler = "one_cycle" |
974 | lr_warmup_epochs = args.lr_warmup_epochs | 974 | lr_warmup_epochs = args.lr_warmup_epochs |
975 | lr_cycles = args.lr_cycles | 975 | lr_cycles = args.lr_cycles |
976 | if response.lower().strip() == "w": | 976 | if response.lower().strip() == "w": |
977 | lr_scheduler = "constant" | 977 | lr_scheduler = "constant_with_warmup" |
978 | lr_warmup_epochs = num_train_epochs | 978 | lr_warmup_epochs = num_train_epochs |
979 | if response.lower().strip() == "c": | 979 | if response.lower().strip() == "c": |
980 | lr_scheduler = "constant" | 980 | lr_scheduler = "constant" |
981 | lr_warmup_epochs = 0 | ||
982 | if response.lower().strip() == "d": | 981 | if response.lower().strip() == "d": |
983 | lr_scheduler = "cosine" | 982 | lr_scheduler = "cosine" |
984 | lr_warmup_epochs = 0 | 983 | lr_warmup_epochs = 0 |