summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d396249..aa5ff01 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -384,6 +384,12 @@ def parse_args():
384 help="Number of images in the validation dataset." 384 help="Number of images in the validation dataset."
385 ) 385 )
386 parser.add_argument( 386 parser.add_argument(
387 "--valid_set_repeat",
388 type=int,
389 default=None,
390 help="Times the images in the validation dataset are repeated."
391 )
392 parser.add_argument(
387 "--train_batch_size", 393 "--train_batch_size",
388 type=int, 394 type=int,
389 default=1, 395 default=1,
@@ -451,6 +457,9 @@ def parse_args():
451 if isinstance(args.exclude_collections, str): 457 if isinstance(args.exclude_collections, str):
452 args.exclude_collections = [args.exclude_collections] 458 args.exclude_collections = [args.exclude_collections]
453 459
460 if args.valid_set_repeat is None:
461 args.valid_set_repeat = args.train_batch_size
462
454 if args.output_dir is None: 463 if args.output_dir is None:
455 raise ValueError("You must specify --output_dir") 464 raise ValueError("You must specify --output_dir")
456 465
@@ -764,6 +773,7 @@ def main():
764 dropout=args.tag_dropout, 773 dropout=args.tag_dropout,
765 template_key=args.train_data_template, 774 template_key=args.train_data_template,
766 valid_set_size=args.valid_set_size, 775 valid_set_size=args.valid_set_size,
776 valid_set_repeat=args.valid_set_repeat,
767 num_workers=args.dataloader_num_workers, 777 num_workers=args.dataloader_num_workers,
768 seed=args.seed, 778 seed=args.seed,
769 filter=keyword_filter, 779 filter=keyword_filter,