 train_lora.py | 29 ++++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 39bf455..0b26965 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -242,6 +242,12 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--pti_gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -476,6 +482,12 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
+        "--pti_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
@@ -694,8 +706,8 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
     args.learning_rate_pti = (
-        args.learning_rate_pti * args.gradient_accumulation_steps *
-        args.train_batch_size * accelerator.num_processes
+        args.learning_rate_pti * args.pti_gradient_accumulation_steps *
+        args.pti_batch_size * accelerator.num_processes
     )

     if args.find_lr:
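The hunk above scales the PTI learning rate linearly with the PTI phase's own effective batch size instead of the LoRA phase's. A minimal sketch of that arithmetic, with hypothetical flag values (the formula mirrors the `+` lines above):

    # Sketch only; the values are made up, the formula is from the hunk above.
    pti_base_lr = 1e-4                   # hypothetical --learning_rate_pti
    pti_gradient_accumulation_steps = 4  # --pti_gradient_accumulation_steps
    pti_batch_size = 2                   # --pti_batch_size
    num_processes = 2                    # accelerator.num_processes

    scaled_pti_lr = (
        pti_base_lr * pti_gradient_accumulation_steps *
        pti_batch_size * num_processes
    )
    print(scaled_pti_lr)  # 1e-4 * 4 * 2 * 2 = 0.0016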
@@ -808,7 +820,6 @@ def main():
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -820,7 +831,6 @@ def main():
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
@@ -843,7 +853,6 @@ def main():
     create_lr_scheduler = partial(
         get_scheduler,
         args.lr_scheduler,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -863,6 +872,7 @@ def main():
     pti_sample_output_dir = pti_output_dir / "samples"

     pti_datamodule = create_datamodule(
+        batch_size=args.pti_batch_size,
         filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
@@ -872,7 +882,7 @@ def main():
     if num_pti_epochs is None:
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.gradient_accumulation_steps
+        ) * args.pti_gradient_accumulation_steps
     pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))

     pti_optimizer = create_optimizer(
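Worked through with hypothetical numbers, the epoch computation in the hunk above behaves like this (values illustrative only; the expression is taken from the `+` line):

    import math

    num_pti_steps = 500                  # hypothetical --num_pti_steps
    train_dataset_len = 40               # hypothetical len(pti_datamodule.train_dataset)
    pti_gradient_accumulation_steps = 4  # --pti_gradient_accumulation_steps

    # Each optimizer step consumes `pti_gradient_accumulation_steps` batches,
    # so the epoch count is multiplied back up by the accumulation factor.
    num_pti_epochs = math.ceil(
        num_pti_steps / train_dataset_len
    ) * pti_gradient_accumulation_steps
    print(num_pti_epochs)  # ceil(500 / 40) * 4 = 13 * 4 = 52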
@@ -886,6 +896,7 @@ def main():
     )

     pti_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         optimizer=pti_optimizer,
         num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
         train_epochs=num_pti_epochs,
@@ -893,12 +904,13 @@ def main():

     metrics = trainer(
         strategy=textual_inversion_strategy,
-        project="ti",
+        project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
         optimizer=pti_optimizer,
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_pti_epochs,
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
@@ -920,6 +932,7 @@ def main():
     lora_sample_output_dir = lora_output_dir / "samples"

     lora_datamodule = create_datamodule(
+        batch_size=args.train_batch_size,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -954,6 +967,7 @@ def main():
     )

     lora_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         optimizer=lora_optimizer,
         num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
@@ -967,6 +981,7 @@ def main():
         optimizer=lora_optimizer,
         lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
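Taken together, the commit lets the PTI and LoRA phases run with independent batch sizes and accumulation factors. A hedged sketch of the resulting effective batch sizes (the flag names come from the diff; the values and this standalone computation are illustrative assumptions):

    # Hypothetical flag values; effective batch size = per-device batch size
    # * accumulation steps * process count, matching the LR-scaling code in main().
    pti_batch_size = 1                   # --pti_batch_size (default)
    pti_gradient_accumulation_steps = 8  # --pti_gradient_accumulation_steps
    train_batch_size = 4                 # --train_batch_size
    gradient_accumulation_steps = 2      # --gradient_accumulation_steps
    num_processes = 1                    # accelerator.num_processes

    pti_effective = pti_batch_size * pti_gradient_accumulation_steps * num_processes
    lora_effective = train_batch_size * gradient_accumulation_steps * num_processes
    print(pti_effective, lora_effective)  # 8 8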