diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index a894ee7..7aecdef 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -360,10 +360,16 @@ def parse_args(): | |||
| 360 | help="Number of images in the validation dataset." | 360 | help="Number of images in the validation dataset." |
| 361 | ) | 361 | ) |
| 362 | parser.add_argument( | 362 | parser.add_argument( |
| 363 | "--valid_set_repeat", | 363 | "--train_set_pad", |
| 364 | type=int, | 364 | type=int, |
| 365 | default=1, | 365 | default=None, |
| 366 | help="Times the images in the validation dataset are repeated." | 366 | help="The number to fill train dataset items up to." |
| 367 | ) | ||
| 368 | parser.add_argument( | ||
| 369 | "--valid_set_pad", | ||
| 370 | type=int, | ||
| 371 | default=None, | ||
| 372 | help="The number to fill validation dataset items up to." | ||
| 367 | ) | 373 | ) |
| 368 | parser.add_argument( | 374 | parser.add_argument( |
| 369 | "--train_batch_size", | 375 | "--train_batch_size", |
| @@ -575,7 +581,8 @@ def main(): | |||
| 575 | shuffle=not args.no_tag_shuffle, | 581 | shuffle=not args.no_tag_shuffle, |
| 576 | template_key=args.train_data_template, | 582 | template_key=args.train_data_template, |
| 577 | valid_set_size=args.valid_set_size, | 583 | valid_set_size=args.valid_set_size, |
| 578 | valid_set_repeat=args.valid_set_repeat, | 584 | train_set_pad=args.train_set_pad, |
| 585 | valid_set_pad=args.valid_set_pad, | ||
| 579 | seed=args.seed, | 586 | seed=args.seed, |
| 580 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), |
| 581 | dtype=weight_dtype | 588 | dtype=weight_dtype |
| @@ -590,7 +597,7 @@ def main(): | |||
| 590 | unet, | 597 | unet, |
| 591 | tokenizer, | 598 | tokenizer, |
| 592 | sample_scheduler, | 599 | sample_scheduler, |
| 593 | datamodule.data_train, | 600 | datamodule.train_dataset, |
| 594 | args.sample_batch_size, | 601 | args.sample_batch_size, |
| 595 | args.sample_image_size, | 602 | args.sample_image_size, |
| 596 | args.sample_steps | 603 | args.sample_steps |
