 train_lora.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 39bf455..0b26965 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -242,6 +242,12 @@ def parse_args():
242 help="Number of updates steps to accumulate before performing a backward/update pass.", 242 help="Number of updates steps to accumulate before performing a backward/update pass.",
243 ) 243 )
244 parser.add_argument( 244 parser.add_argument(
245 "--pti_gradient_accumulation_steps",
246 type=int,
247 default=1,
248 help="Number of updates steps to accumulate before performing a backward/update pass.",
249 )
250 parser.add_argument(
245 "--lora_r", 251 "--lora_r",
246 type=int, 252 type=int,
247 default=8, 253 default=8,
@@ -476,6 +482,12 @@ def parse_args():
476 help="Batch size (per device) for the training dataloader." 482 help="Batch size (per device) for the training dataloader."
477 ) 483 )
478 parser.add_argument( 484 parser.add_argument(
485 "--pti_batch_size",
486 type=int,
487 default=1,
488 help="Batch size (per device) for the training dataloader."
489 )
490 parser.add_argument(
479 "--sample_steps", 491 "--sample_steps",
480 type=int, 492 type=int,
481 default=10, 493 default=10,
@@ -694,8 +706,8 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
     args.learning_rate_pti = (
-        args.learning_rate_pti * args.gradient_accumulation_steps *
-        args.train_batch_size * accelerator.num_processes
+        args.learning_rate_pti * args.pti_gradient_accumulation_steps *
+        args.pti_batch_size * accelerator.num_processes
     )

     if args.find_lr:
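The rewritten scaling above follows the usual linear learning-rate scaling rule, now fed by the PTI-specific knobs instead of the shared ones. A minimal sketch of the arithmetic (the function name and example values are illustrative, not from this diff):

def scale_lr(base_lr, batch_size, accumulation_steps, num_processes):
    # Each optimizer update consumes batch_size * accumulation_steps samples
    # per process, so the learning rate is scaled by the same factor.
    return base_lr * accumulation_steps * batch_size * num_processes

print(scale_lr(1e-4, 1, 4, 2))  # 0.0008: base 1e-4, batch 1, accumulation 4, 2 processes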
@@ -808,7 +820,6 @@ def main():
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -820,7 +831,6 @@ def main():
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
@@ -843,7 +853,6 @@ def main():
     create_lr_scheduler = partial(
         get_scheduler,
         args.lr_scheduler,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -863,6 +872,7 @@ def main():
     pti_sample_output_dir = pti_output_dir / "samples"

     pti_datamodule = create_datamodule(
+        batch_size=args.pti_batch_size,
         filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
@@ -872,7 +882,7 @@ def main():
     if num_pti_epochs is None:
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.gradient_accumulation_steps
+        ) * args.pti_gradient_accumulation_steps
         pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))

     pti_optimizer = create_optimizer(
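The epoch count above is padded by the accumulation factor because each optimizer update now consumes several batches. A rough, self-contained illustration of the arithmetic (all values are assumed for the example):

import math

num_pti_steps = 500                  # desired optimizer updates
dataset_len = 100                    # stand-in for len(pti_datamodule.train_dataset)
pti_gradient_accumulation_steps = 4

# At the default batch size of 1, one epoch yields dataset_len / accumulation
# updates, so reaching num_pti_steps updates takes accumulation times more epochs.
num_pti_epochs = math.ceil(num_pti_steps / dataset_len) * pti_gradient_accumulation_steps
print(num_pti_epochs)  # 20 epochs -> 20 * (100 / 4) = 500 updates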
@@ -886,6 +896,7 @@ def main():
     )

     pti_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         optimizer=pti_optimizer,
         num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
         train_epochs=num_pti_epochs,
@@ -893,12 +904,13 @@ def main():

     metrics = trainer(
         strategy=textual_inversion_strategy,
-        project="ti",
+        project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
         optimizer=pti_optimizer,
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_pti_epochs,
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
@@ -920,6 +932,7 @@ def main():
     lora_sample_output_dir = lora_output_dir / "samples"

     lora_datamodule = create_datamodule(
+        batch_size=args.train_batch_size,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -954,6 +967,7 @@ def main():
     )

     lora_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         optimizer=lora_optimizer,
         num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
@@ -967,6 +981,7 @@ def main():
         optimizer=lora_optimizer,
         lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
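For reference, the two options this commit introduces can be exercised in isolation. A self-contained sketch of how they parse; only the two add_argument calls are taken from the diff, the surrounding scaffolding is assumed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--pti_gradient_accumulation_steps",
    type=int,
    default=1,
    help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
    "--pti_batch_size",
    type=int,
    default=1,
    help="Batch size (per device) for the training dataloader."
)

args = parser.parse_args(["--pti_batch_size", "2", "--pti_gradient_accumulation_steps", "8"])
print(args.pti_batch_size, args.pti_gradient_accumulation_steps)  # 2 8

Both default to 1, so existing invocations keep their behavior unless the PTI phase is tuned explicitly.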