From fe3113451fdde72ddccfc71639f0a2a1e146209a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Mar 2023 07:11:51 +0100 Subject: Update --- train_ti.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index b9d6e56..81938c8 100644 --- a/train_ti.py +++ b/train_ti.py @@ -476,13 +476,10 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.initializer_tokens): raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") - if args.num_vectors is None: - args.num_vectors = 1 - if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) - if len(args.placeholder_tokens) != len(args.num_vectors): + if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") if args.sequential: @@ -491,6 +488,9 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.train_data_template): raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + + if args.num_vectors is None: + args.num_vectors = [None] * len(args.placeholder_tokens) else: if isinstance(args.train_data_template, list): raise ValueError("--train_data_template can't be a list in simultaneous mode") -- cgit v1.2.3-54-g00ecf