diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 13 |
1 file changed, 9 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index b00b0d7..84ca296 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -906,9 +906,6 @@ def main(): | |||
906 | if args.sample_num is not None: | 906 | if args.sample_num is not None: |
907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
908 | 908 | ||
909 | training_iter = 0 | ||
910 | learning_rate = args.learning_rate | ||
911 | |||
912 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | 909 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" |
913 | 910 | ||
914 | if accelerator.is_main_process: | 911 | if accelerator.is_main_process: |
@@ -916,7 +913,9 @@ def main(): | |||
916 | 913 | ||
917 | sample_output_dir = output_dir / project / "samples" | 914 | sample_output_dir = output_dir / project / "samples" |
918 | 915 | ||
916 | training_iter = 0 | ||
919 | auto_cycles = list(args.auto_cycles) | 917 | auto_cycles = list(args.auto_cycles) |
918 | learning_rate = args.learning_rate | ||
920 | lr_scheduler = args.lr_scheduler | 919 | lr_scheduler = args.lr_scheduler |
921 | lr_warmup_epochs = args.lr_warmup_epochs | 920 | lr_warmup_epochs = args.lr_warmup_epochs |
922 | lr_cycles = args.lr_cycles | 921 | lr_cycles = args.lr_cycles |
@@ -929,6 +928,12 @@ def main(): | |||
929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
930 | 929 | ||
931 | if response.lower().strip() == "o": | 930 | if response.lower().strip() == "o": |
931 | if args.learning_rate is not None: | ||
932 | learning_rate = args.learning_rate * 2 | ||
933 | else: | ||
934 | learning_rate = args.learning_rate | ||
935 | |||
936 | if response.lower().strip() == "o": | ||
932 | lr_scheduler = "one_cycle" | 937 | lr_scheduler = "one_cycle" |
933 | lr_warmup_epochs = args.lr_warmup_epochs | 938 | lr_warmup_epochs = args.lr_warmup_epochs |
934 | lr_cycles = args.lr_cycles | 939 | lr_cycles = args.lr_cycles |
@@ -945,7 +950,7 @@ def main(): | |||
945 | break | 950 | break |
946 | 951 | ||
947 | print("") | 952 | print("") |
948 | print(f"------------ TI cycle {training_iter + 1} ------------") | 953 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
949 | print("") | 954 | print("") |
950 | 955 | ||
951 | optimizer = create_optimizer( | 956 | optimizer = create_optimizer( |