From 1c63552a20f34bccd461ac0dfa46405f853cbc7c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 27 Mar 2023 11:58:47 +0200 Subject: Fix TI --- train_ti.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index ef39c38..9ae8d1b 100644 --- a/train_ti.py +++ b/train_ti.py @@ -155,7 +155,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=2, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -507,9 +507,18 @@ def parse_args(): if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if args.alias_tokens is None: + args.alias_tokens = [] + if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: raise ValueError("--alias_tokens must be a list with an even number of items") + args.alias_tokens += [ + item + for pair in zip(args.placeholder_tokens, args.initializer_tokens) + for item in pair + ] + if args.sequential: if isinstance(args.train_data_template, str): args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) -- cgit v1.2.3-54-g00ecf