diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index a894ee7..7aecdef 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -360,10 +360,16 @@ def parse_args(): | |||
360 | help="Number of images in the validation dataset." | 360 | help="Number of images in the validation dataset." |
361 | ) | 361 | ) |
362 | parser.add_argument( | 362 | parser.add_argument( |
363 | "--valid_set_repeat", | 363 | "--train_set_pad", |
364 | type=int, | 364 | type=int, |
365 | default=1, | 365 | default=None, |
366 | help="Times the images in the validation dataset are repeated." | 366 | help="The number to fill train dataset items up to." |
367 | ) | ||
368 | parser.add_argument( | ||
369 | "--valid_set_pad", | ||
370 | type=int, | ||
371 | default=None, | ||
372 | help="The number to fill validation dataset items up to." | ||
367 | ) | 373 | ) |
368 | parser.add_argument( | 374 | parser.add_argument( |
369 | "--train_batch_size", | 375 | "--train_batch_size", |
@@ -575,7 +581,8 @@ def main(): | |||
575 | shuffle=not args.no_tag_shuffle, | 581 | shuffle=not args.no_tag_shuffle, |
576 | template_key=args.train_data_template, | 582 | template_key=args.train_data_template, |
577 | valid_set_size=args.valid_set_size, | 583 | valid_set_size=args.valid_set_size, |
578 | valid_set_repeat=args.valid_set_repeat, | 584 | train_set_pad=args.train_set_pad, |
585 | valid_set_pad=args.valid_set_pad, | ||
579 | seed=args.seed, | 586 | seed=args.seed, |
580 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), |
581 | dtype=weight_dtype | 588 | dtype=weight_dtype |
@@ -590,7 +597,7 @@ def main(): | |||
590 | unet, | 597 | unet, |
591 | tokenizer, | 598 | tokenizer, |
592 | sample_scheduler, | 599 | sample_scheduler, |
593 | datamodule.data_train, | 600 | datamodule.train_dataset, |
594 | args.sample_batch_size, | 601 | args.sample_batch_size, |
595 | args.sample_image_size, | 602 | args.sample_image_size, |
596 | args.sample_steps | 603 | args.sample_steps |