diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/train_lora.py b/train_lora.py index 54c9e7a..e81742a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -81,6 +81,12 @@ def parse_args(): | |||
81 | help="The name of the current project.", | 81 | help="The name of the current project.", |
82 | ) | 82 | ) |
83 | parser.add_argument( | 83 | parser.add_argument( |
84 | "--auto_cycles", | ||
85 | type=int, | ||
86 | default=1, | ||
87 | help="How many cycles to run automatically." | ||
88 | ) | ||
89 | parser.add_argument( | ||
84 | "--placeholder_tokens", | 90 | "--placeholder_tokens", |
85 | type=str, | 91 | type=str, |
86 | nargs='*', | 92 | nargs='*', |
@@ -933,10 +939,15 @@ def main(): | |||
933 | train_epochs=num_train_epochs, | 939 | train_epochs=num_train_epochs, |
934 | ) | 940 | ) |
935 | 941 | ||
936 | continue_training = True | 942 | training_iter = 0 |
937 | training_iter = 1 | 943 | |
944 | while True: | ||
945 | training_iter += 1 | ||
946 | if training_iter > args.auto_cycles: | ||
947 | response = input("Run another cycle? [y/n] ") | ||
948 | if response.lower().strip() == "n": | ||
949 | break | ||
938 | 950 | ||
939 | while continue_training: | ||
940 | print("") | 951 | print("") |
941 | print(f"============ LoRA cycle {training_iter} ============") | 952 | print(f"============ LoRA cycle {training_iter} ============") |
942 | print("") | 953 | print("") |
@@ -961,10 +972,6 @@ def main(): | |||
961 | sample_frequency=lora_sample_frequency, | 972 | sample_frequency=lora_sample_frequency, |
962 | ) | 973 | ) |
963 | 974 | ||
964 | response = input("Run another cycle? [y/n] ") | ||
965 | continue_training = response.lower().strip() != "n" | ||
966 | training_iter += 1 | ||
967 | |||
968 | 975 | ||
969 | if __name__ == "__main__": | 976 | if __name__ == "__main__": |
970 | main() | 977 | main() |