summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
committerVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
commitfe3113451fdde72ddccfc71639f0a2a1e146209a (patch)
treeba4114faf1bd00a642f97b5e7729ad74213c3b80 /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index b9d6e56..81938c8 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -476,13 +476,10 @@ def parse_args():
476 if len(args.placeholder_tokens) != len(args.initializer_tokens): 476 if len(args.placeholder_tokens) != len(args.initializer_tokens):
477 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 477 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
478 478
479 if args.num_vectors is None:
480 args.num_vectors = 1
481
482 if isinstance(args.num_vectors, int): 479 if isinstance(args.num_vectors, int):
483 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 480 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
484 481
485 if len(args.placeholder_tokens) != len(args.num_vectors): 482 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
486 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 483 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
487 484
488 if args.sequential: 485 if args.sequential:
@@ -491,6 +488,9 @@ def parse_args():
491 488
492 if len(args.placeholder_tokens) != len(args.train_data_template): 489 if len(args.placeholder_tokens) != len(args.train_data_template):
493 raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") 490 raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
491
492 if args.num_vectors is None:
493 args.num_vectors = [None] * len(args.placeholder_tokens)
494 else: 494 else:
495 if isinstance(args.train_data_template, list): 495 if isinstance(args.train_data_template, list):
496 raise ValueError("--train_data_template can't be a list in simultaneous mode") 496 raise ValueError("--train_data_template can't be a list in simultaneous mode")