From 25539ea3979321d7968c88f5be9cf19d49c8e59e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 20 Apr 2023 12:48:18 +0200 Subject: Fix PTI --- train_lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_lora.py b/train_lora.py index 0ae8b31..0d8b8cb 100644 --- a/train_lora.py +++ b/train_lora.py @@ -873,7 +873,6 @@ def main(): seed=args.seed, guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, - no_val=args.valid_set_size == 0, offset_noise_strength=args.offset_noise_strength, sample_scheduler=sample_scheduler, sample_batch_size=args.sample_batch_size, @@ -905,7 +904,6 @@ def main(): bucket_max_pixels=args.bucket_max_pixels, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, - valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, dtype=weight_dtype, @@ -931,6 +929,7 @@ def main(): filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens] pti_datamodule = create_datamodule( + valid_set_size=0, batch_size=args.train_batch_size, filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), ) @@ -992,6 +991,7 @@ def main(): sample_output_dir=pti_sample_output_dir, checkpoint_output_dir=pti_checkpoint_output_dir, sample_frequency=pti_sample_frequency, + no_val=True, ) embeddings.persist() @@ -1000,6 +1000,7 @@ def main(): # -------------------------------------------------------------------------------- lora_datamodule = create_datamodule( + valid_set_size=args.valid_set_size, batch_size=args.train_batch_size, dropout=args.tag_dropout, filter=partial(keyword_filter, None, args.collection, args.exclude_collections), @@ -1136,6 +1137,7 @@ def main(): sample_output_dir=lora_sample_output_dir, checkpoint_output_dir=lora_checkpoint_output_dir, sample_frequency=lora_sample_frequency, + no_val=args.valid_set_size == 0, ) training_iter += 1 -- cgit v1.2.3-70-g09d2