diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-27 11:58:47 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-27 11:58:47 +0200 | 
| commit | 1c63552a20f34bccd461ac0dfa46405f853cbc7c (patch) | |
| tree | df26b48ff4c2ef79349b0a4025cdde05b0ed8518 /train_ti.py | |
| parent | Fix TI (diff) | |
| download | textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.gz textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.tar.bz2 textual-inversion-diff-1c63552a20f34bccd461ac0dfa46405f853cbc7c.zip | |
Fix TI
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 11 | 
1 files changed, 10 insertions, 1 deletions
| diff --git a/train_ti.py b/train_ti.py index ef39c38..9ae8d1b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -155,7 +155,7 @@ def parse_args(): | |||
| 155 | parser.add_argument( | 155 | parser.add_argument( | 
| 156 | "--num_buckets", | 156 | "--num_buckets", | 
| 157 | type=int, | 157 | type=int, | 
| 158 | default=2, | 158 | default=0, | 
| 159 | help="Number of aspect ratio buckets in either direction.", | 159 | help="Number of aspect ratio buckets in either direction.", | 
| 160 | ) | 160 | ) | 
| 161 | parser.add_argument( | 161 | parser.add_argument( | 
| @@ -507,9 +507,18 @@ def parse_args(): | |||
| 507 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 507 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 
| 508 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 508 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 
| 509 | 509 | ||
| 510 | if args.alias_tokens is None: | ||
| 511 | args.alias_tokens = [] | ||
| 512 | |||
| 510 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 513 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | 
| 511 | raise ValueError("--alias_tokens must be a list with an even number of items") | 514 | raise ValueError("--alias_tokens must be a list with an even number of items") | 
| 512 | 515 | ||
| 516 | args.alias_tokens += [ | ||
| 517 | item | ||
| 518 | for pair in zip(args.placeholder_tokens, args.initializer_tokens) | ||
| 519 | for item in pair | ||
| 520 | ] | ||
| 521 | |||
| 513 | if args.sequential: | 522 | if args.sequential: | 
| 514 | if isinstance(args.train_data_template, str): | 523 | if isinstance(args.train_data_template, str): | 
| 515 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 524 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 
