diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index ca5b113..ebac302 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -64,6 +64,12 @@ def parse_args(): | |||
64 | help="The name of the current project.", | 64 | help="The name of the current project.", |
65 | ) | 65 | ) |
66 | parser.add_argument( | 66 | parser.add_argument( |
67 | "--auto_cycles", | ||
68 | type=int, | ||
69 | default=1, | ||
70 | help="How many cycles to run automatically." | ||
71 | ) | ||
72 | parser.add_argument( | ||
67 | "--placeholder_tokens", | 73 | "--placeholder_tokens", |
68 | type=str, | 74 | type=str, |
69 | nargs='*', | 75 | nargs='*', |
@@ -869,10 +875,15 @@ def main(): | |||
869 | mid_point=args.lr_mid_point, | 875 | mid_point=args.lr_mid_point, |
870 | ) | 876 | ) |
871 | 877 | ||
872 | continue_training = True | 878 | training_iter = 0 |
873 | training_iter = 1 | 879 | |
880 | while True: | ||
881 | training_iter += 1 | ||
882 | if training_iter > args.auto_cycles: | ||
883 | response = input("Run another cycle? [y/n] ") | ||
884 | if response.lower().strip() == "n": | ||
885 | break | ||
874 | 886 | ||
875 | while continue_training: | ||
876 | print("") | 887 | print("") |
877 | print(f"------------ TI cycle {training_iter} ------------") | 888 | print(f"------------ TI cycle {training_iter} ------------") |
878 | print("") | 889 | print("") |