diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-09 16:21:52 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-09 16:21:52 +0200 |
| commit | 776213e99da4ec389575e797d93de8d8960fa010 (patch) | |
| tree | a21a76e32dbacb707c3d251c56e92d618d5e921b /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.gz textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.tar.bz2 textual-inversion-diff-776213e99da4ec389575e797d93de8d8960fa010.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index ca5b113..ebac302 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -64,6 +64,12 @@ def parse_args(): | |||
| 64 | help="The name of the current project.", | 64 | help="The name of the current project.", |
| 65 | ) | 65 | ) |
| 66 | parser.add_argument( | 66 | parser.add_argument( |
| 67 | "--auto_cycles", | ||
| 68 | type=int, | ||
| 69 | default=1, | ||
| 70 | help="How many cycles to run automatically." | ||
| 71 | ) | ||
| 72 | parser.add_argument( | ||
| 67 | "--placeholder_tokens", | 73 | "--placeholder_tokens", |
| 68 | type=str, | 74 | type=str, |
| 69 | nargs='*', | 75 | nargs='*', |
| @@ -869,10 +875,15 @@ def main(): | |||
| 869 | mid_point=args.lr_mid_point, | 875 | mid_point=args.lr_mid_point, |
| 870 | ) | 876 | ) |
| 871 | 877 | ||
| 872 | continue_training = True | 878 | training_iter = 0 |
| 873 | training_iter = 1 | 879 | |
| 880 | while True: | ||
| 881 | training_iter += 1 | ||
| 882 | if training_iter > args.auto_cycles: | ||
| 883 | response = input("Run another cycle? [y/n] ") | ||
| 884 | if response.lower().strip() == "n": | ||
| 885 | break | ||
| 874 | 886 | ||
| 875 | while continue_training: | ||
| 876 | print("") | 887 | print("") |
| 877 | print(f"------------ TI cycle {training_iter} ------------") | 888 | print(f"------------ TI cycle {training_iter} ------------") |
| 878 | print("") | 889 | print("") |
