summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py
index a894ee7..7aecdef 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -360,10 +360,16 @@ def parse_args():
360 help="Number of images in the validation dataset." 360 help="Number of images in the validation dataset."
361 ) 361 )
362 parser.add_argument( 362 parser.add_argument(
363 "--valid_set_repeat", 363 "--train_set_pad",
364 type=int, 364 type=int,
365 default=1, 365 default=None,
366 help="Times the images in the validation dataset are repeated." 366 help="The number to fill train dataset items up to."
367 )
368 parser.add_argument(
369 "--valid_set_pad",
370 type=int,
371 default=None,
372 help="The number to fill validation dataset items up to."
367 ) 373 )
368 parser.add_argument( 374 parser.add_argument(
369 "--train_batch_size", 375 "--train_batch_size",
@@ -575,7 +581,8 @@ def main():
575 shuffle=not args.no_tag_shuffle, 581 shuffle=not args.no_tag_shuffle,
576 template_key=args.train_data_template, 582 template_key=args.train_data_template,
577 valid_set_size=args.valid_set_size, 583 valid_set_size=args.valid_set_size,
578 valid_set_repeat=args.valid_set_repeat, 584 train_set_pad=args.train_set_pad,
585 valid_set_pad=args.valid_set_pad,
579 seed=args.seed, 586 seed=args.seed,
580 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), 587 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
581 dtype=weight_dtype 588 dtype=weight_dtype
@@ -590,7 +597,7 @@ def main():
590 unet, 597 unet,
591 tokenizer, 598 tokenizer,
592 sample_scheduler, 599 sample_scheduler,
593 datamodule.data_train, 600 datamodule.train_dataset,
594 args.sample_batch_size, 601 args.sample_batch_size,
595 args.sample_image_size, 602 args.sample_image_size,
596 args.sample_steps 603 args.sample_steps