diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/train_lora.py b/train_lora.py index 54c9e7a..e81742a 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -81,6 +81,12 @@ def parse_args(): | |||
| 81 | help="The name of the current project.", | 81 | help="The name of the current project.", |
| 82 | ) | 82 | ) |
| 83 | parser.add_argument( | 83 | parser.add_argument( |
| 84 | "--auto_cycles", | ||
| 85 | type=int, | ||
| 86 | default=1, | ||
| 87 | help="How many cycles to run automatically." | ||
| 88 | ) | ||
| 89 | parser.add_argument( | ||
| 84 | "--placeholder_tokens", | 90 | "--placeholder_tokens", |
| 85 | type=str, | 91 | type=str, |
| 86 | nargs='*', | 92 | nargs='*', |
| @@ -933,10 +939,15 @@ def main(): | |||
| 933 | train_epochs=num_train_epochs, | 939 | train_epochs=num_train_epochs, |
| 934 | ) | 940 | ) |
| 935 | 941 | ||
| 936 | continue_training = True | 942 | training_iter = 0 |
| 937 | training_iter = 1 | 943 | |
| 944 | while True: | ||
| 945 | training_iter += 1 | ||
| 946 | if training_iter > args.auto_cycles: | ||
| 947 | response = input("Run another cycle? [y/n] ") | ||
| 948 | if response.lower().strip() == "n": | ||
| 949 | break | ||
| 938 | 950 | ||
| 939 | while continue_training: | ||
| 940 | print("") | 951 | print("") |
| 941 | print(f"============ LoRA cycle {training_iter} ============") | 952 | print(f"============ LoRA cycle {training_iter} ============") |
| 942 | print("") | 953 | print("") |
| @@ -961,10 +972,6 @@ def main(): | |||
| 961 | sample_frequency=lora_sample_frequency, | 972 | sample_frequency=lora_sample_frequency, |
| 962 | ) | 973 | ) |
| 963 | 974 | ||
| 964 | response = input("Run another cycle? [y/n] ") | ||
| 965 | continue_training = response.lower().strip() != "n" | ||
| 966 | training_iter += 1 | ||
| 967 | |||
| 968 | 975 | ||
| 969 | if __name__ == "__main__": | 976 | if __name__ == "__main__": |
| 970 | main() | 977 | main() |
