diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 0ae8b31..0d8b8cb 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -873,7 +873,6 @@ def main(): | |||
873 | seed=args.seed, | 873 | seed=args.seed, |
874 | guidance_scale=args.guidance_scale, | 874 | guidance_scale=args.guidance_scale, |
875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
876 | no_val=args.valid_set_size == 0, | ||
877 | offset_noise_strength=args.offset_noise_strength, | 876 | offset_noise_strength=args.offset_noise_strength, |
878 | sample_scheduler=sample_scheduler, | 877 | sample_scheduler=sample_scheduler, |
879 | sample_batch_size=args.sample_batch_size, | 878 | sample_batch_size=args.sample_batch_size, |
@@ -905,7 +904,6 @@ def main(): | |||
905 | bucket_max_pixels=args.bucket_max_pixels, | 904 | bucket_max_pixels=args.bucket_max_pixels, |
906 | shuffle=not args.no_tag_shuffle, | 905 | shuffle=not args.no_tag_shuffle, |
907 | template_key=args.train_data_template, | 906 | template_key=args.train_data_template, |
908 | valid_set_size=args.valid_set_size, | ||
909 | train_set_pad=args.train_set_pad, | 907 | train_set_pad=args.train_set_pad, |
910 | valid_set_pad=args.valid_set_pad, | 908 | valid_set_pad=args.valid_set_pad, |
911 | dtype=weight_dtype, | 909 | dtype=weight_dtype, |
@@ -931,6 +929,7 @@ def main(): | |||
931 | filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens] | 929 | filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens] |
932 | 930 | ||
933 | pti_datamodule = create_datamodule( | 931 | pti_datamodule = create_datamodule( |
932 | valid_set_size=0, | ||
934 | batch_size=args.train_batch_size, | 933 | batch_size=args.train_batch_size, |
935 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 934 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
936 | ) | 935 | ) |
@@ -992,6 +991,7 @@ def main(): | |||
992 | sample_output_dir=pti_sample_output_dir, | 991 | sample_output_dir=pti_sample_output_dir, |
993 | checkpoint_output_dir=pti_checkpoint_output_dir, | 992 | checkpoint_output_dir=pti_checkpoint_output_dir, |
994 | sample_frequency=pti_sample_frequency, | 993 | sample_frequency=pti_sample_frequency, |
994 | no_val=True, | ||
995 | ) | 995 | ) |
996 | 996 | ||
997 | embeddings.persist() | 997 | embeddings.persist() |
@@ -1000,6 +1000,7 @@ def main(): | |||
1000 | # -------------------------------------------------------------------------------- | 1000 | # -------------------------------------------------------------------------------- |
1001 | 1001 | ||
1002 | lora_datamodule = create_datamodule( | 1002 | lora_datamodule = create_datamodule( |
1003 | valid_set_size=args.valid_set_size, | ||
1003 | batch_size=args.train_batch_size, | 1004 | batch_size=args.train_batch_size, |
1004 | dropout=args.tag_dropout, | 1005 | dropout=args.tag_dropout, |
1005 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 1006 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
@@ -1136,6 +1137,7 @@ def main(): | |||
1136 | sample_output_dir=lora_sample_output_dir, | 1137 | sample_output_dir=lora_sample_output_dir, |
1137 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1138 | checkpoint_output_dir=lora_checkpoint_output_dir, |
1138 | sample_frequency=lora_sample_frequency, | 1139 | sample_frequency=lora_sample_frequency, |
1140 | no_val=args.valid_set_size == 0, | ||
1139 | ) | 1141 | ) |
1140 | 1142 | ||
1141 | training_iter += 1 | 1143 | training_iter += 1 |