diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-26 14:27:54 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-26 14:27:54 +0200 |
| commit | 19ae465203c8dcc0b1179584db632015362b5e44 (patch) | |
| tree | ad6d45e78826f525c336927e4269197667f1f354 /train_ti.py | |
| parent | Fix training with guidance (diff) | |
| download | textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.gz textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.bz2 textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.zip | |
Improved inverted tokens
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 83ad46d..6c35d41 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -81,6 +81,12 @@ def parse_args(): | |||
| 81 | help="Tokens to create an alias for." | 81 | help="Tokens to create an alias for." |
| 82 | ) | 82 | ) |
| 83 | parser.add_argument( | 83 | parser.add_argument( |
| 84 | "--inverted_initializer_tokens", | ||
| 85 | type=str, | ||
| 86 | nargs='*', | ||
| 87 | help="A token to use as initializer word." | ||
| 88 | ) | ||
| 89 | parser.add_argument( | ||
| 84 | "--num_vectors", | 90 | "--num_vectors", |
| 85 | type=int, | 91 | type=int, |
| 86 | nargs='*', | 92 | nargs='*', |
| @@ -149,7 +155,7 @@ def parse_args(): | |||
| 149 | parser.add_argument( | 155 | parser.add_argument( |
| 150 | "--num_buckets", | 156 | "--num_buckets", |
| 151 | type=int, | 157 | type=int, |
| 152 | default=0, | 158 | default=2, |
| 153 | help="Number of aspect ratio buckets in either direction.", | 159 | help="Number of aspect ratio buckets in either direction.", |
| 154 | ) | 160 | ) |
| 155 | parser.add_argument( | 161 | parser.add_argument( |
| @@ -488,6 +494,13 @@ def parse_args(): | |||
| 488 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 494 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 489 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 495 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
| 490 | 496 | ||
| 497 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 498 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | ||
| 499 | |||
| 500 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | ||
| 501 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 502 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 503 | |||
| 491 | if isinstance(args.num_vectors, int): | 504 | if isinstance(args.num_vectors, int): |
| 492 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 505 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 493 | 506 | ||
| @@ -720,6 +733,7 @@ def main(): | |||
| 720 | dropout=args.tag_dropout, | 733 | dropout=args.tag_dropout, |
| 721 | shuffle=not args.no_tag_shuffle, | 734 | shuffle=not args.no_tag_shuffle, |
| 722 | template_key=data_template, | 735 | template_key=data_template, |
| 736 | placeholder_tokens=args.placeholder_tokens, | ||
| 723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |
| 724 | train_set_pad=args.train_set_pad, | 738 | train_set_pad=args.train_set_pad, |
| 725 | valid_set_pad=args.valid_set_pad, | 739 | valid_set_pad=args.valid_set_pad, |
