summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-27 11:58:47 +0200
committerVolpeon <git@volpeon.ink>2023-03-27 11:58:47 +0200
commit1c63552a20f34bccd461ac0dfa46405f853cbc7c (patch)
treedf26b48ff4c2ef79349b0a4025cdde05b0ed8518 /train_ti.py
parentFix TI (diff)
downloadtextual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.gz
textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.bz2
textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.zip
Fix TI
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index ef39c38..9ae8d1b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,7 +155,7 @@ def parse_args():
155 parser.add_argument( 155 parser.add_argument(
156 "--num_buckets", 156 "--num_buckets",
157 type=int, 157 type=int,
158 default=2, 158 default=0,
159 help="Number of aspect ratio buckets in either direction.", 159 help="Number of aspect ratio buckets in either direction.",
160 ) 160 )
161 parser.add_argument( 161 parser.add_argument(
@@ -507,9 +507,18 @@ def parse_args():
507 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): 507 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
508 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 508 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
509 509
510 if args.alias_tokens is None:
511 args.alias_tokens = []
512
510 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: 513 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
511 raise ValueError("--alias_tokens must be a list with an even number of items") 514 raise ValueError("--alias_tokens must be a list with an even number of items")
512 515
516 args.alias_tokens += [
517 item
518 for pair in zip(args.placeholder_tokens, args.initializer_tokens)
519 for item in pair
520 ]
521
513 if args.sequential: 522 if args.sequential:
514 if isinstance(args.train_data_template, str): 523 if isinstance(args.train_data_template, str):
515 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) 524 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)