summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-26 14:27:54 +0200
committerVolpeon <git@volpeon.ink>2023-03-26 14:27:54 +0200
commit19ae465203c8dcc0b1179584db632015362b5e44 (patch)
treead6d45e78826f525c336927e4269197667f1f354 /train_ti.py
parentFix training with guidance (diff)
downloadtextual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.gz
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.bz2
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.zip
Improved inverted tokens
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 83ad46d..6c35d41 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -81,6 +81,12 @@ def parse_args():
81 help="Tokens to create an alias for." 81 help="Tokens to create an alias for."
82 ) 82 )
83 parser.add_argument( 83 parser.add_argument(
84 "--inverted_initializer_tokens",
85 type=str,
86 nargs='*',
87 help="A token to use as initializer word."
88 )
89 parser.add_argument(
84 "--num_vectors", 90 "--num_vectors",
85 type=int, 91 type=int,
86 nargs='*', 92 nargs='*',
@@ -149,7 +155,7 @@ def parse_args():
149 parser.add_argument( 155 parser.add_argument(
150 "--num_buckets", 156 "--num_buckets",
151 type=int, 157 type=int,
152 default=0, 158 default=2,
153 help="Number of aspect ratio buckets in either direction.", 159 help="Number of aspect ratio buckets in either direction.",
154 ) 160 )
155 parser.add_argument( 161 parser.add_argument(
@@ -488,6 +494,13 @@ def parse_args():
488 if len(args.placeholder_tokens) != len(args.initializer_tokens): 494 if len(args.placeholder_tokens) != len(args.initializer_tokens):
489 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 495 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
490 496
497 if isinstance(args.inverted_initializer_tokens, str):
498 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
499
500 if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
501 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
502 args.initializer_tokens += args.inverted_initializer_tokens
503
491 if isinstance(args.num_vectors, int): 504 if isinstance(args.num_vectors, int):
492 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 505 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
493 506
@@ -720,6 +733,7 @@ def main():
720 dropout=args.tag_dropout, 733 dropout=args.tag_dropout,
721 shuffle=not args.no_tag_shuffle, 734 shuffle=not args.no_tag_shuffle,
722 template_key=data_template, 735 template_key=data_template,
736 placeholder_tokens=args.placeholder_tokens,
723 valid_set_size=args.valid_set_size, 737 valid_set_size=args.valid_set_size,
724 train_set_pad=args.train_set_pad, 738 train_set_pad=args.train_set_pad,
725 valid_set_pad=args.valid_set_pad, 739 valid_set_pad=args.valid_set_pad,