summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index 0ae8b31..0d8b8cb 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -873,7 +873,6 @@ def main():
873 seed=args.seed, 873 seed=args.seed,
874 guidance_scale=args.guidance_scale, 874 guidance_scale=args.guidance_scale,
875 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 875 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
876 no_val=args.valid_set_size == 0,
877 offset_noise_strength=args.offset_noise_strength, 876 offset_noise_strength=args.offset_noise_strength,
878 sample_scheduler=sample_scheduler, 877 sample_scheduler=sample_scheduler,
879 sample_batch_size=args.sample_batch_size, 878 sample_batch_size=args.sample_batch_size,
@@ -905,7 +904,6 @@ def main():
905 bucket_max_pixels=args.bucket_max_pixels, 904 bucket_max_pixels=args.bucket_max_pixels,
906 shuffle=not args.no_tag_shuffle, 905 shuffle=not args.no_tag_shuffle,
907 template_key=args.train_data_template, 906 template_key=args.train_data_template,
908 valid_set_size=args.valid_set_size,
909 train_set_pad=args.train_set_pad, 907 train_set_pad=args.train_set_pad,
910 valid_set_pad=args.valid_set_pad, 908 valid_set_pad=args.valid_set_pad,
911 dtype=weight_dtype, 909 dtype=weight_dtype,
@@ -931,6 +929,7 @@ def main():
931 filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens] 929 filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
932 930
933 pti_datamodule = create_datamodule( 931 pti_datamodule = create_datamodule(
932 valid_set_size=0,
934 batch_size=args.train_batch_size, 933 batch_size=args.train_batch_size,
935 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 934 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
936 ) 935 )
@@ -992,6 +991,7 @@ def main():
992 sample_output_dir=pti_sample_output_dir, 991 sample_output_dir=pti_sample_output_dir,
993 checkpoint_output_dir=pti_checkpoint_output_dir, 992 checkpoint_output_dir=pti_checkpoint_output_dir,
994 sample_frequency=pti_sample_frequency, 993 sample_frequency=pti_sample_frequency,
994 no_val=True,
995 ) 995 )
996 996
997 embeddings.persist() 997 embeddings.persist()
@@ -1000,6 +1000,7 @@ def main():
1000 # -------------------------------------------------------------------------------- 1000 # --------------------------------------------------------------------------------
1001 1001
1002 lora_datamodule = create_datamodule( 1002 lora_datamodule = create_datamodule(
1003 valid_set_size=args.valid_set_size,
1003 batch_size=args.train_batch_size, 1004 batch_size=args.train_batch_size,
1004 dropout=args.tag_dropout, 1005 dropout=args.tag_dropout,
1005 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 1006 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
@@ -1136,6 +1137,7 @@ def main():
1136 sample_output_dir=lora_sample_output_dir, 1137 sample_output_dir=lora_sample_output_dir,
1137 checkpoint_output_dir=lora_checkpoint_output_dir, 1138 checkpoint_output_dir=lora_checkpoint_output_dir,
1138 sample_frequency=lora_sample_frequency, 1139 sample_frequency=lora_sample_frequency,
1140 no_val=args.valid_set_size == 0,
1139 ) 1141 )
1140 1142
1141 training_iter += 1 1143 training_iter += 1