summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-09 16:21:52 +0200
committerVolpeon <git@volpeon.ink>2023-04-09 16:21:52 +0200
commit776213e99da4ec389575e797d93de8d8960fa010 (patch)
treea21a76e32dbacb707c3d251c56e92d618d5e921b /train_ti.py
parentFix (diff)
downloadtextual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.gz
textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.bz2
textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index ca5b113..ebac302 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -64,6 +64,12 @@ def parse_args():
64 help="The name of the current project.", 64 help="The name of the current project.",
65 ) 65 )
66 parser.add_argument( 66 parser.add_argument(
67 "--auto_cycles",
68 type=int,
69 default=1,
70 help="How many cycles to run automatically."
71 )
72 parser.add_argument(
67 "--placeholder_tokens", 73 "--placeholder_tokens",
68 type=str, 74 type=str,
69 nargs='*', 75 nargs='*',
@@ -869,10 +875,15 @@ def main():
869 mid_point=args.lr_mid_point, 875 mid_point=args.lr_mid_point,
870 ) 876 )
871 877
872 continue_training = True 878 training_iter = 0
873 training_iter = 1 879
880 while True:
881 training_iter += 1
882 if training_iter > args.auto_cycles:
883 response = input("Run another cycle? [y/n] ")
884 if response.lower().strip() == "n":
885 break
874 886
875 while continue_training:
876 print("") 887 print("")
877 print(f"------------ TI cycle {training_iter} ------------") 888 print(f"------------ TI cycle {training_iter} ------------")
878 print("") 889 print("")