diff options
author | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
commit | b57ca669a150d9313447612fb8c37668f4f2a80d (patch) | |
tree | b0ebfedc33c26847838850416b96fd2623cf6ba5 /train_dreambooth.py | |
parent | No cache after all (diff) | |
download | textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.gz textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.bz2 textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.zip |
Add --valid_set_repeat
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d396249..aa5ff01 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -384,6 +384,12 @@ def parse_args(): | |||
384 | help="Number of images in the validation dataset." | 384 | help="Number of images in the validation dataset." |
385 | ) | 385 | ) |
386 | parser.add_argument( | 386 | parser.add_argument( |
387 | "--valid_set_repeat", | ||
388 | type=int, | ||
389 | default=None, | ||
390 | help="Times the images in the validation dataset are repeated." | ||
391 | ) | ||
392 | parser.add_argument( | ||
387 | "--train_batch_size", | 393 | "--train_batch_size", |
388 | type=int, | 394 | type=int, |
389 | default=1, | 395 | default=1, |
@@ -451,6 +457,9 @@ def parse_args(): | |||
451 | if isinstance(args.exclude_collections, str): | 457 | if isinstance(args.exclude_collections, str): |
452 | args.exclude_collections = [args.exclude_collections] | 458 | args.exclude_collections = [args.exclude_collections] |
453 | 459 | ||
460 | if args.valid_set_repeat is None: | ||
461 | args.valid_set_repeat = args.train_batch_size | ||
462 | |||
454 | if args.output_dir is None: | 463 | if args.output_dir is None: |
455 | raise ValueError("You must specify --output_dir") | 464 | raise ValueError("You must specify --output_dir") |
456 | 465 | ||
@@ -764,6 +773,7 @@ def main(): | |||
764 | dropout=args.tag_dropout, | 773 | dropout=args.tag_dropout, |
765 | template_key=args.train_data_template, | 774 | template_key=args.train_data_template, |
766 | valid_set_size=args.valid_set_size, | 775 | valid_set_size=args.valid_set_size, |
776 | valid_set_repeat=args.valid_set_repeat, | ||
767 | num_workers=args.dataloader_num_workers, | 777 | num_workers=args.dataloader_num_workers, |
768 | seed=args.seed, | 778 | seed=args.seed, |
769 | filter=keyword_filter, | 779 | filter=keyword_filter, |