 train_lora.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 39bf455..0b26965 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -242,6 +242,12 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--pti_gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
         "--lora_r",
         type=int,
         default=8,
@@ -476,6 +482,12 @@ def parse_args():
         help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
+        "--pti_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
         "--sample_steps",
         type=int,
         default=10,
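
The two new flags mirror the existing --gradient_accumulation_steps and --train_batch_size options but apply only to the PTI phase, and both default to 1, so PTI batching is now configured independently of the LoRA phase. A minimal, self-contained sketch of how they parse (stand-in parser, not the full parse_args):

```python
import argparse

# Stand-in parser; flag names, types, and defaults are taken from the
# hunks above, everything else here is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--pti_gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--pti_batch_size", type=int, default=1)

args = parser.parse_args(["--pti_batch_size", "4"])
print(args.pti_gradient_accumulation_steps, args.pti_batch_size)  # -> 1 4
```
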
@@ -694,8 +706,8 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
     args.learning_rate_pti = (
-        args.learning_rate_pti * args.gradient_accumulation_steps *
-        args.train_batch_size * accelerator.num_processes
+        args.learning_rate_pti * args.pti_gradient_accumulation_steps *
+        args.pti_batch_size * accelerator.num_processes
     )
 
     if args.find_lr:
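
This hunk scales the PTI learning rate by the PTI-specific effective batch size (per-device batch * accumulation steps * process count) instead of the LoRA-phase values. Worked arithmetic with invented numbers:

```python
# Mirrors the expression above; all values are invented.
learning_rate_pti = 1e-4
pti_gradient_accumulation_steps = 4
pti_batch_size = 2
num_processes = 2

learning_rate_pti *= (
    pti_gradient_accumulation_steps * pti_batch_size * num_processes
)
print(learning_rate_pti)  # 1e-4 * 16 = 1.6e-3
```
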
@@ -808,7 +820,6 @@ def main():
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -820,7 +831,6 @@ def main():
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
@@ -843,7 +853,6 @@ def main():
     create_lr_scheduler = partial(
         get_scheduler,
         args.lr_scheduler,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
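
With batch_size and gradient_accumulation_steps removed from the shared partial(...) factories above, each training phase now passes its own value at the call site; before, both PTI and LoRA received the single value baked into the partial. A self-contained sketch of the pattern, with a stub in place of get_scheduler and invented values:

```python
from functools import partial

# Stub in place of get_scheduler; only the partial/call-site split is
# the point here.
def get_scheduler_stub(name, gradient_accumulation_steps, min_lr):
    return f"{name}: accum={gradient_accumulation_steps}, min_lr={min_lr}"

# Shared settings are baked in once...
create_lr_scheduler = partial(get_scheduler_stub, "one_cycle", min_lr=1e-6)

# ...and per-phase settings are supplied at each call site.
print(create_lr_scheduler(gradient_accumulation_steps=1))  # PTI phase
print(create_lr_scheduler(gradient_accumulation_steps=8))  # LoRA phase
```
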
@@ -863,6 +872,7 @@ def main():
     pti_sample_output_dir = pti_output_dir / "samples"
 
     pti_datamodule = create_datamodule(
+        batch_size=args.pti_batch_size,
         filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
@@ -872,7 +882,7 @@ def main():
     if num_pti_epochs is None:
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.gradient_accumulation_steps
+        ) * args.pti_gradient_accumulation_steps
         pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
 
     pti_optimizer = create_optimizer(
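
The epoch computation above converts the requested PTI step count into epochs and then multiplies by the PTI accumulation factor, compensating for each optimizer update now spanning several batches. Worked arithmetic with invented numbers:

```python
import math

# Invented numbers; mirrors the expression in the hunk above.
num_pti_steps = 500
train_dataset_len = 120
pti_gradient_accumulation_steps = 4

num_pti_epochs = (
    math.ceil(num_pti_steps / train_dataset_len)
    * pti_gradient_accumulation_steps
)
print(num_pti_epochs)  # ceil(500 / 120) * 4 = 5 * 4 = 20
```
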
@@ -886,6 +896,7 @@ def main():
     )
 
     pti_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         optimizer=pti_optimizer,
         num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
         train_epochs=num_pti_epochs,
@@ -893,12 +904,13 @@ def main():
 
     metrics = trainer(
         strategy=textual_inversion_strategy,
-        project="ti",
+        project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
         optimizer=pti_optimizer,
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_pti_epochs,
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
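
The trainer call now takes gradient_accumulation_steps per phase (and the PTI run is logged under project "pti" rather than "ti"). As a reminder of what the setting controls, here is a generic accumulation loop in plain Python; this is not this repository's trainer internals:

```python
# Generic sketch: N micro-batch gradients are averaged before a single
# optimizer update; numbers are invented.
accumulation_steps = 4
micro_batch_grads = [0.1, 0.3, -0.2, 0.4, 0.5, -0.1, 0.2, 0.0]

grad_buffer = 0.0
updates = 0
for step, grad in enumerate(micro_batch_grads):
    grad_buffer += grad / accumulation_steps  # average over the window
    if (step + 1) % accumulation_steps == 0:
        updates += 1        # one optimizer step per full window
        grad_buffer = 0.0   # reset, as optimizer.zero_grad() would
print(updates)  # 8 micro-batches / 4 = 2 optimizer updates
```
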
@@ -920,6 +932,7 @@ def main():
     lora_sample_output_dir = lora_output_dir / "samples"
 
     lora_datamodule = create_datamodule(
+        batch_size=args.train_batch_size,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -954,6 +967,7 @@ def main():
     )
 
     lora_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         optimizer=lora_optimizer,
         num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
@@ -967,6 +981,7 @@ def main():
         optimizer=lora_optimizer,
         lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
