diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-07 07:11:51 +0100 |
| commit | fe3113451fdde72ddccfc71639f0a2a1e146209a (patch) | |
| tree | ba4114faf1bd00a642f97b5e7729ad74213c3b80 /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2 textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index b9d6e56..81938c8 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -476,13 +476,10 @@ def parse_args(): | |||
| 476 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 476 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 477 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 477 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
| 478 | 478 | ||
| 479 | if args.num_vectors is None: | ||
| 480 | args.num_vectors = 1 | ||
| 481 | |||
| 482 | if isinstance(args.num_vectors, int): | 479 | if isinstance(args.num_vectors, int): |
| 483 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 480 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 484 | 481 | ||
| 485 | if len(args.placeholder_tokens) != len(args.num_vectors): | 482 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): |
| 486 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
| 487 | 484 | ||
| 488 | if args.sequential: | 485 | if args.sequential: |
| @@ -491,6 +488,9 @@ def parse_args(): | |||
| 491 | 488 | ||
| 492 | if len(args.placeholder_tokens) != len(args.train_data_template): | 489 | if len(args.placeholder_tokens) != len(args.train_data_template): |
| 493 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 490 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
| 491 | |||
| 492 | if args.num_vectors is None: | ||
| 493 | args.num_vectors = [None] * len(args.placeholder_tokens) | ||
| 494 | else: | 494 | else: |
| 495 | if isinstance(args.train_data_template, list): | 495 | if isinstance(args.train_data_template, list): |
| 496 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | 496 | raise ValueError("--train_data_template can't be a list in simultaneous mode") |
