diff options
author | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
commit | fe3113451fdde72ddccfc71639f0a2a1e146209a (patch) | |
tree | ba4114faf1bd00a642f97b5e7729ad74213c3b80 /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2 textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index b9d6e56..81938c8 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -476,13 +476,10 @@ def parse_args(): | |||
476 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 476 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
477 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 477 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
478 | 478 | ||
479 | if args.num_vectors is None: | ||
480 | args.num_vectors = 1 | ||
481 | |||
482 | if isinstance(args.num_vectors, int): | 479 | if isinstance(args.num_vectors, int): |
483 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 480 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
484 | 481 | ||
485 | if len(args.placeholder_tokens) != len(args.num_vectors): | 482 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): |
486 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
487 | 484 | ||
488 | if args.sequential: | 485 | if args.sequential: |
@@ -491,6 +488,9 @@ def parse_args(): | |||
491 | 488 | ||
492 | if len(args.placeholder_tokens) != len(args.train_data_template): | 489 | if len(args.placeholder_tokens) != len(args.train_data_template): |
493 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 490 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
491 | |||
492 | if args.num_vectors is None: | ||
493 | args.num_vectors = [None] * len(args.placeholder_tokens) | ||
494 | else: | 494 | else: |
495 | if isinstance(args.train_data_template, list): | 495 | if isinstance(args.train_data_template, list): |
496 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | 496 | raise ValueError("--train_data_template can't be a list in simultaneous mode") |