diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 13 | 
1 files changed, 9 insertions, 4 deletions
| diff --git a/train_ti.py b/train_ti.py index b00b0d7..84ca296 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -906,9 +906,6 @@ def main(): | |||
| 906 | if args.sample_num is not None: | 906 | if args.sample_num is not None: | 
| 907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 
| 908 | 908 | ||
| 909 | training_iter = 0 | ||
| 910 | learning_rate = args.learning_rate | ||
| 911 | |||
| 912 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | 909 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | 
| 913 | 910 | ||
| 914 | if accelerator.is_main_process: | 911 | if accelerator.is_main_process: | 
| @@ -916,7 +913,9 @@ def main(): | |||
| 916 | 913 | ||
| 917 | sample_output_dir = output_dir / project / "samples" | 914 | sample_output_dir = output_dir / project / "samples" | 
| 918 | 915 | ||
| 916 | training_iter = 0 | ||
| 919 | auto_cycles = list(args.auto_cycles) | 917 | auto_cycles = list(args.auto_cycles) | 
| 918 | learning_rate = args.learning_rate | ||
| 920 | lr_scheduler = args.lr_scheduler | 919 | lr_scheduler = args.lr_scheduler | 
| 921 | lr_warmup_epochs = args.lr_warmup_epochs | 920 | lr_warmup_epochs = args.lr_warmup_epochs | 
| 922 | lr_cycles = args.lr_cycles | 921 | lr_cycles = args.lr_cycles | 
| @@ -929,6 +928,12 @@ def main(): | |||
| 929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 
| 930 | 929 | ||
| 931 | if response.lower().strip() == "o": | 930 | if response.lower().strip() == "o": | 
| 931 | if args.learning_rate is not None: | ||
| 932 | learning_rate = args.learning_rate * 2 | ||
| 933 | else: | ||
| 934 | learning_rate = args.learning_rate | ||
| 935 | |||
| 936 | if response.lower().strip() == "o": | ||
| 932 | lr_scheduler = "one_cycle" | 937 | lr_scheduler = "one_cycle" | 
| 933 | lr_warmup_epochs = args.lr_warmup_epochs | 938 | lr_warmup_epochs = args.lr_warmup_epochs | 
| 934 | lr_cycles = args.lr_cycles | 939 | lr_cycles = args.lr_cycles | 
| @@ -945,7 +950,7 @@ def main(): | |||
| 945 | break | 950 | break | 
| 946 | 951 | ||
| 947 | print("") | 952 | print("") | 
| 948 | print(f"------------ TI cycle {training_iter + 1} ------------") | 953 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 
| 949 | print("") | 954 | print("") | 
| 950 | 955 | ||
| 951 | optimizer = create_optimizer( | 956 | optimizer = create_optimizer( | 
