summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/train_lora.py b/train_lora.py
index 54c9e7a..e81742a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -81,6 +81,12 @@ def parse_args():
81 help="The name of the current project.", 81 help="The name of the current project.",
82 ) 82 )
83 parser.add_argument( 83 parser.add_argument(
84 "--auto_cycles",
85 type=int,
86 default=1,
87 help="How many cycles to run automatically."
88 )
89 parser.add_argument(
84 "--placeholder_tokens", 90 "--placeholder_tokens",
85 type=str, 91 type=str,
86 nargs='*', 92 nargs='*',
@@ -933,10 +939,15 @@ def main():
933 train_epochs=num_train_epochs, 939 train_epochs=num_train_epochs,
934 ) 940 )
935 941
936 continue_training = True 942 training_iter = 0
937 training_iter = 1 943
944 while True:
945 training_iter += 1
946 if training_iter > args.auto_cycles:
947 response = input("Run another cycle? [y/n] ")
948 if response.lower().strip() == "n":
949 break
938 950
939 while continue_training:
940 print("") 951 print("")
941 print(f"============ LoRA cycle {training_iter} ============") 952 print(f"============ LoRA cycle {training_iter} ============")
942 print("") 953 print("")
@@ -961,10 +972,6 @@ def main():
961 sample_frequency=lora_sample_frequency, 972 sample_frequency=lora_sample_frequency,
962 ) 973 )
963 974
964 response = input("Run another cycle? [y/n] ")
965 continue_training = response.lower().strip() != "n"
966 training_iter += 1
967
968 975
969if __name__ == "__main__": 976if __name__ == "__main__":
970 main() 977 main()