summary | refs | log | tree | commit | diff | stats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 13
1 files changed, 9 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index b00b0d7..84ca296 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -906,9 +906,6 @@ def main():
906 if args.sample_num is not None: 906 if args.sample_num is not None:
907 sample_frequency = math.ceil(num_train_epochs / args.sample_num) 907 sample_frequency = math.ceil(num_train_epochs / args.sample_num)
908 908
909 training_iter = 0
910 learning_rate = args.learning_rate
911
912 project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" 909 project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
913 910
914 if accelerator.is_main_process: 911 if accelerator.is_main_process:
@@ -916,7 +913,9 @@ def main():
916 913
917 sample_output_dir = output_dir / project / "samples" 914 sample_output_dir = output_dir / project / "samples"
918 915
916 training_iter = 0
919 auto_cycles = list(args.auto_cycles) 917 auto_cycles = list(args.auto_cycles)
918 learning_rate = args.learning_rate
920 lr_scheduler = args.lr_scheduler 919 lr_scheduler = args.lr_scheduler
921 lr_warmup_epochs = args.lr_warmup_epochs 920 lr_warmup_epochs = args.lr_warmup_epochs
922 lr_cycles = args.lr_cycles 921 lr_cycles = args.lr_cycles
@@ -929,6 +928,12 @@ def main():
929 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 928 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
930 929
931 if response.lower().strip() == "o": 930 if response.lower().strip() == "o":
931 if args.learning_rate is not None:
932 learning_rate = args.learning_rate * 2
933 else:
934 learning_rate = args.learning_rate
935
936 if response.lower().strip() == "o":
932 lr_scheduler = "one_cycle" 937 lr_scheduler = "one_cycle"
933 lr_warmup_epochs = args.lr_warmup_epochs 938 lr_warmup_epochs = args.lr_warmup_epochs
934 lr_cycles = args.lr_cycles 939 lr_cycles = args.lr_cycles
@@ -945,7 +950,7 @@ def main():
945 break 950 break
946 951
947 print("") 952 print("")
948 print(f"------------ TI cycle {training_iter + 1} ------------") 953 print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
949 print("") 954 print("")
950 955
951 optimizer = create_optimizer( 956 optimizer = create_optimizer(