diff options
author | Volpeon <git@volpeon.ink> | 2023-03-26 14:27:54 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-26 14:27:54 +0200 |
commit | 19ae465203c8dcc0b1179584db632015362b5e44 (patch) | |
tree | ad6d45e78826f525c336927e4269197667f1f354 /train_ti.py | |
parent | Fix training with guidance (diff) | |
download | textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.gz textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.bz2 textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.zip |
Improved inverted tokens
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 83ad46d..6c35d41 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -81,6 +81,12 @@ def parse_args(): | |||
81 | help="Tokens to create an alias for." | 81 | help="Tokens to create an alias for." |
82 | ) | 82 | ) |
83 | parser.add_argument( | 83 | parser.add_argument( |
84 | "--inverted_initializer_tokens", | ||
85 | type=str, | ||
86 | nargs='*', | ||
87 | help="A token to use as initializer word." | ||
88 | ) | ||
89 | parser.add_argument( | ||
84 | "--num_vectors", | 90 | "--num_vectors", |
85 | type=int, | 91 | type=int, |
86 | nargs='*', | 92 | nargs='*', |
@@ -149,7 +155,7 @@ def parse_args(): | |||
149 | parser.add_argument( | 155 | parser.add_argument( |
150 | "--num_buckets", | 156 | "--num_buckets", |
151 | type=int, | 157 | type=int, |
152 | default=0, | 158 | default=2, |
153 | help="Number of aspect ratio buckets in either direction.", | 159 | help="Number of aspect ratio buckets in either direction.", |
154 | ) | 160 | ) |
155 | parser.add_argument( | 161 | parser.add_argument( |
@@ -488,6 +494,13 @@ def parse_args(): | |||
488 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 494 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
489 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 495 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
490 | 496 | ||
497 | if isinstance(args.inverted_initializer_tokens, str): | ||
498 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | ||
499 | |||
500 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | ||
501 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
502 | args.initializer_tokens += args.inverted_initializer_tokens | ||
503 | |||
491 | if isinstance(args.num_vectors, int): | 504 | if isinstance(args.num_vectors, int): |
492 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 505 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
493 | 506 | ||
@@ -720,6 +733,7 @@ def main(): | |||
720 | dropout=args.tag_dropout, | 733 | dropout=args.tag_dropout, |
721 | shuffle=not args.no_tag_shuffle, | 734 | shuffle=not args.no_tag_shuffle, |
722 | template_key=data_template, | 735 | template_key=data_template, |
736 | placeholder_tokens=args.placeholder_tokens, | ||
723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |
724 | train_set_pad=args.train_set_pad, | 738 | train_set_pad=args.train_set_pad, |
725 | valid_set_pad=args.valid_set_pad, | 739 | valid_set_pad=args.valid_set_pad, |