diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d396249..aa5ff01 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -384,6 +384,12 @@ def parse_args(): | |||
| 384 | help="Number of images in the validation dataset." | 384 | help="Number of images in the validation dataset." |
| 385 | ) | 385 | ) |
| 386 | parser.add_argument( | 386 | parser.add_argument( |
| 387 | "--valid_set_repeat", | ||
| 388 | type=int, | ||
| 389 | default=None, | ||
| 390 | help="Times the images in the validation dataset are repeated." | ||
| 391 | ) | ||
| 392 | parser.add_argument( | ||
| 387 | "--train_batch_size", | 393 | "--train_batch_size", |
| 388 | type=int, | 394 | type=int, |
| 389 | default=1, | 395 | default=1, |
| @@ -451,6 +457,9 @@ def parse_args(): | |||
| 451 | if isinstance(args.exclude_collections, str): | 457 | if isinstance(args.exclude_collections, str): |
| 452 | args.exclude_collections = [args.exclude_collections] | 458 | args.exclude_collections = [args.exclude_collections] |
| 453 | 459 | ||
| 460 | if args.valid_set_repeat is None: | ||
| 461 | args.valid_set_repeat = args.train_batch_size | ||
| 462 | |||
| 454 | if args.output_dir is None: | 463 | if args.output_dir is None: |
| 455 | raise ValueError("You must specify --output_dir") | 464 | raise ValueError("You must specify --output_dir") |
| 456 | 465 | ||
| @@ -764,6 +773,7 @@ def main(): | |||
| 764 | dropout=args.tag_dropout, | 773 | dropout=args.tag_dropout, |
| 765 | template_key=args.train_data_template, | 774 | template_key=args.train_data_template, |
| 766 | valid_set_size=args.valid_set_size, | 775 | valid_set_size=args.valid_set_size, |
| 776 | valid_set_repeat=args.valid_set_repeat, | ||
| 767 | num_workers=args.dataloader_num_workers, | 777 | num_workers=args.dataloader_num_workers, |
| 768 | seed=args.seed, | 778 | seed=args.seed, |
| 769 | filter=keyword_filter, | 779 | filter=keyword_filter, |
