summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index aa5ff01..1a1f516 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -386,7 +386,7 @@ def parse_args():
386 parser.add_argument( 386 parser.add_argument(
387 "--valid_set_repeat", 387 "--valid_set_repeat",
388 type=int, 388 type=int,
389 default=None, 389 default=1,
390 help="Times the images in the validation dataset are repeated." 390 help="Times the images in the validation dataset are repeated."
391 ) 391 )
392 parser.add_argument( 392 parser.add_argument(
@@ -457,9 +457,6 @@ def parse_args():
457 if isinstance(args.exclude_collections, str): 457 if isinstance(args.exclude_collections, str):
458 args.exclude_collections = [args.exclude_collections] 458 args.exclude_collections = [args.exclude_collections]
459 459
460 if args.valid_set_repeat is None:
461 args.valid_set_repeat = args.train_batch_size
462
463 if args.output_dir is None: 460 if args.output_dir is None:
464 raise ValueError("You must specify --output_dir") 461 raise ValueError("You must specify --output_dir")
465 462