From 89afcfda3f824cc44221e877182348f9b09687d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:31:55 +0100 Subject: Handle empty validation dataset --- train_dreambooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 05777d0..4e41f77 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -564,7 +564,7 @@ def main(): embeddings=embeddings, placeholder_tokens=[placeholder_token], initializer_tokens=[initializer_token], - num_vectors=num_vectors + num_vectors=[num_vectors] ) datamodule = VlpnDataModule( @@ -579,7 +579,7 @@ def main(): valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, seed=args.seed, - filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), + filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() @@ -654,7 +654,7 @@ def main(): valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, seed=args.seed, - filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), + filter=partial(keyword_filter, None, args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() -- cgit v1.2.3-54-g00ecf