diff options
author | Volpeon <git@volpeon.ink> | 2023-03-27 11:58:47 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-27 11:58:47 +0200 |
commit | 1c63552a20f34bccd461ac0dfa46405f853cbc7c (patch) | |
tree | df26b48ff4c2ef79349b0a4025cdde05b0ed8518 /train_ti.py | |
parent | Fix TI (diff) | |
download | textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.gz textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.bz2 textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.zip |
Fix TI
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index ef39c38..9ae8d1b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -155,7 +155,7 @@ def parse_args(): | |||
155 | parser.add_argument( | 155 | parser.add_argument( |
156 | "--num_buckets", | 156 | "--num_buckets", |
157 | type=int, | 157 | type=int, |
158 | default=2, | 158 | default=0, |
159 | help="Number of aspect ratio buckets in either direction.", | 159 | help="Number of aspect ratio buckets in either direction.", |
160 | ) | 160 | ) |
161 | parser.add_argument( | 161 | parser.add_argument( |
@@ -507,9 +507,18 @@ def parse_args(): | |||
507 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 507 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): |
508 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 508 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
509 | 509 | ||
510 | if args.alias_tokens is None: | ||
511 | args.alias_tokens = [] | ||
512 | |||
510 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 513 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: |
511 | raise ValueError("--alias_tokens must be a list with an even number of items") | 514 | raise ValueError("--alias_tokens must be a list with an even number of items") |
512 | 515 | ||
516 | args.alias_tokens += [ | ||
517 | item | ||
518 | for pair in zip(args.placeholder_tokens, args.initializer_tokens) | ||
519 | for item in pair | ||
520 | ] | ||
521 | |||
513 | if args.sequential: | 522 | if args.sequential: |
514 | if isinstance(args.train_data_template, str): | 523 | if isinstance(args.train_data_template, str): |
515 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 524 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) |