diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 880320f..b00b0d7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -925,18 +925,18 @@ def main(): | |||
925 | if len(auto_cycles) != 0: | 925 | if len(auto_cycles) != 0: |
926 | response = auto_cycles.pop(0) | 926 | response = auto_cycles.pop(0) |
927 | else: | 927 | else: |
928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | response = input( |
929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | ||
929 | 930 | ||
930 | if response.lower().strip() == "o": | 931 | if response.lower().strip() == "o": |
931 | lr_scheduler = "one_cycle" | 932 | lr_scheduler = "one_cycle" |
932 | lr_warmup_epochs = args.lr_warmup_epochs | 933 | lr_warmup_epochs = args.lr_warmup_epochs |
933 | lr_cycles = args.lr_cycles | 934 | lr_cycles = args.lr_cycles |
934 | if response.lower().strip() == "w": | 935 | if response.lower().strip() == "w": |
935 | lr_scheduler = "constant" | 936 | lr_scheduler = "constant_with_warmup" |
936 | lr_warmup_epochs = num_train_epochs | 937 | lr_warmup_epochs = num_train_epochs |
937 | if response.lower().strip() == "c": | 938 | if response.lower().strip() == "c": |
938 | lr_scheduler = "constant" | 939 | lr_scheduler = "constant" |
939 | lr_warmup_epochs = 0 | ||
940 | if response.lower().strip() == "d": | 940 | if response.lower().strip() == "d": |
941 | lr_scheduler = "cosine" | 941 | lr_scheduler = "cosine" |
942 | lr_warmup_epochs = 0 | 942 | lr_warmup_epochs = 0 |