summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 10:31:55 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 10:31:55 +0100
commit89afcfda3f824cc44221e877182348f9b09687d2 (patch)
tree804b84322e5caa8fb861322ce6970bef4b532c61 /train_dreambooth.py
parentExtended Dreambooth: Train TI tokens separately (diff)
downloadtextual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.tar.gz
textual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.tar.bz2
textual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.zip
Handle empty validation dataset
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05777d0..4e41f77 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -564,7 +564,7 @@ def main():
564 embeddings=embeddings, 564 embeddings=embeddings,
565 placeholder_tokens=[placeholder_token], 565 placeholder_tokens=[placeholder_token],
566 initializer_tokens=[initializer_token], 566 initializer_tokens=[initializer_token],
567 num_vectors=num_vectors 567 num_vectors=[num_vectors]
568 ) 568 )
569 569
570 datamodule = VlpnDataModule( 570 datamodule = VlpnDataModule(
@@ -579,7 +579,7 @@ def main():
579 valid_set_size=args.valid_set_size, 579 valid_set_size=args.valid_set_size,
580 valid_set_repeat=args.valid_set_repeat, 580 valid_set_repeat=args.valid_set_repeat,
581 seed=args.seed, 581 seed=args.seed,
582 filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), 582 filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
583 dtype=weight_dtype 583 dtype=weight_dtype
584 ) 584 )
585 datamodule.setup() 585 datamodule.setup()
@@ -654,7 +654,7 @@ def main():
654 valid_set_size=args.valid_set_size, 654 valid_set_size=args.valid_set_size,
655 valid_set_repeat=args.valid_set_repeat, 655 valid_set_repeat=args.valid_set_repeat,
656 seed=args.seed, 656 seed=args.seed,
657 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), 657 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
658 dtype=weight_dtype 658 dtype=weight_dtype
659 ) 659 )
660 datamodule.setup() 660 datamodule.setup()