From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 15:52:43 +0100 Subject: More training adjustments --- train_ti.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index a894ee7..7aecdef 100644 --- a/train_ti.py +++ b/train_ti.py @@ -360,10 +360,16 @@ def parse_args(): help="Number of images in the validation dataset." ) parser.add_argument( - "--valid_set_repeat", + "--train_set_pad", type=int, - default=1, - help="Times the images in the validation dataset are repeated." + default=None, + help="The number to fill train dataset items up to." + ) + parser.add_argument( + "--valid_set_pad", + type=int, + default=None, + help="The number to fill validation dataset items up to." ) parser.add_argument( "--train_batch_size", @@ -575,7 +581,8 @@ def main(): shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, - valid_set_repeat=args.valid_set_repeat, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, seed=args.seed, filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), dtype=weight_dtype @@ -590,7 +597,7 @@ def main(): unet, tokenizer, sample_scheduler, - datamodule.data_train, + datamodule.train_dataset, args.sample_batch_size, args.sample_image_size, args.sample_steps -- cgit v1.2.3-54-g00ecf