summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 84ca296..d1e5467 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -937,17 +937,19 @@ def main():
937 lr_scheduler = "one_cycle" 937 lr_scheduler = "one_cycle"
938 lr_warmup_epochs = args.lr_warmup_epochs 938 lr_warmup_epochs = args.lr_warmup_epochs
939 lr_cycles = args.lr_cycles 939 lr_cycles = args.lr_cycles
940 if response.lower().strip() == "w": 940 elif response.lower().strip() == "w":
941 lr_scheduler = "constant_with_warmup" 941 lr_scheduler = "constant_with_warmup"
942 lr_warmup_epochs = num_train_epochs 942 lr_warmup_epochs = num_train_epochs
943 if response.lower().strip() == "c": 943 elif response.lower().strip() == "c":
944 lr_scheduler = "constant" 944 lr_scheduler = "constant"
945 if response.lower().strip() == "d": 945 elif response.lower().strip() == "d":
946 lr_scheduler = "cosine" 946 lr_scheduler = "cosine"
947 lr_warmup_epochs = 0 947 lr_warmup_epochs = 0
948 lr_cycles = 1 948 lr_cycles = 1
949 elif response.lower().strip() == "s": 949 elif response.lower().strip() == "s":
950 break 950 break
951 else:
952 continue
951 953
952 print("") 954 print("")
953 print(f"------------ TI cycle {training_iter + 1}: {response} ------------") 955 print(f"------------ TI cycle {training_iter + 1}: {response} ------------")