diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 109 |
1 files changed, 55 insertions, 54 deletions
diff --git a/train_lora.py b/train_lora.py index 6de3a75..daf1f6c 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -867,62 +867,63 @@ def main(): | |||
867 | # PTI | 867 | # PTI |
868 | # -------------------------------------------------------------------------------- | 868 | # -------------------------------------------------------------------------------- |
869 | 869 | ||
870 | pti_output_dir = output_dir / "pti" | 870 | if len(args.placeholder_tokens) != 0: |
871 | pti_checkpoint_output_dir = pti_output_dir / "model" | 871 | pti_output_dir = output_dir / "pti" |
872 | pti_sample_output_dir = pti_output_dir / "samples" | 872 | pti_checkpoint_output_dir = pti_output_dir / "model" |
873 | 873 | pti_sample_output_dir = pti_output_dir / "samples" | |
874 | pti_datamodule = create_datamodule( | 874 | |
875 | batch_size=args.pti_batch_size, | 875 | pti_datamodule = create_datamodule( |
876 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 876 | batch_size=args.pti_batch_size, |
877 | ) | 877 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), |
878 | pti_datamodule.setup() | 878 | ) |
879 | 879 | pti_datamodule.setup() | |
880 | num_pti_epochs = args.num_pti_epochs | 880 | |
881 | pti_sample_frequency = args.sample_frequency | 881 | num_pti_epochs = args.num_pti_epochs |
882 | if num_pti_epochs is None: | 882 | pti_sample_frequency = args.sample_frequency |
883 | num_pti_epochs = math.ceil( | 883 | if num_pti_epochs is None: |
884 | args.num_pti_steps / len(pti_datamodule.train_dataset) | 884 | num_pti_epochs = math.ceil( |
885 | ) * args.pti_gradient_accumulation_steps | 885 | args.num_pti_steps / len(pti_datamodule.train_dataset) |
886 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) | 886 | ) * args.pti_gradient_accumulation_steps |
887 | 887 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) | |
888 | pti_optimizer = create_optimizer( | 888 | |
889 | [ | 889 | pti_optimizer = create_optimizer( |
890 | { | 890 | [ |
891 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 891 | { |
892 | "lr": args.learning_rate_pti, | 892 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), |
893 | "weight_decay": 0, | 893 | "lr": args.learning_rate_pti, |
894 | }, | 894 | "weight_decay": 0, |
895 | ] | 895 | }, |
896 | ) | 896 | ] |
897 | ) | ||
897 | 898 | ||
898 | pti_lr_scheduler = create_lr_scheduler( | 899 | pti_lr_scheduler = create_lr_scheduler( |
899 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 900 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
900 | optimizer=pti_optimizer, | 901 | optimizer=pti_optimizer, |
901 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), | 902 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), |
902 | train_epochs=num_pti_epochs, | 903 | train_epochs=num_pti_epochs, |
903 | ) | 904 | ) |
904 | 905 | ||
905 | metrics = trainer( | 906 | metrics = trainer( |
906 | strategy=textual_inversion_strategy, | 907 | strategy=textual_inversion_strategy, |
907 | project="pti", | 908 | project="pti", |
908 | train_dataloader=pti_datamodule.train_dataloader, | 909 | train_dataloader=pti_datamodule.train_dataloader, |
909 | val_dataloader=pti_datamodule.val_dataloader, | 910 | val_dataloader=pti_datamodule.val_dataloader, |
910 | optimizer=pti_optimizer, | 911 | optimizer=pti_optimizer, |
911 | lr_scheduler=pti_lr_scheduler, | 912 | lr_scheduler=pti_lr_scheduler, |
912 | num_train_epochs=num_pti_epochs, | 913 | num_train_epochs=num_pti_epochs, |
913 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, | 914 | gradient_accumulation_steps=args.pti_gradient_accumulation_steps, |
914 | # -- | 915 | # -- |
915 | sample_output_dir=pti_sample_output_dir, | 916 | sample_output_dir=pti_sample_output_dir, |
916 | checkpoint_output_dir=pti_checkpoint_output_dir, | 917 | checkpoint_output_dir=pti_checkpoint_output_dir, |
917 | sample_frequency=pti_sample_frequency, | 918 | sample_frequency=pti_sample_frequency, |
918 | placeholder_tokens=args.placeholder_tokens, | 919 | placeholder_tokens=args.placeholder_tokens, |
919 | placeholder_token_ids=placeholder_token_ids, | 920 | placeholder_token_ids=placeholder_token_ids, |
920 | use_emb_decay=args.use_emb_decay, | 921 | use_emb_decay=args.use_emb_decay, |
921 | emb_decay_target=args.emb_decay_target, | 922 | emb_decay_target=args.emb_decay_target, |
922 | emb_decay=args.emb_decay, | 923 | emb_decay=args.emb_decay, |
923 | ) | 924 | ) |
924 | 925 | ||
925 | plot_metrics(metrics, output_dir/"lr.png") | 926 | plot_metrics(metrics, pti_output_dir / "lr.png") |
926 | 927 | ||
927 | # LORA | 928 | # LORA |
928 | # -------------------------------------------------------------------------------- | 929 | # -------------------------------------------------------------------------------- |
@@ -994,7 +995,7 @@ def main(): | |||
994 | max_grad_norm=args.max_grad_norm, | 995 | max_grad_norm=args.max_grad_norm, |
995 | ) | 996 | ) |
996 | 997 | ||
997 | plot_metrics(metrics, output_dir/"lr.png") | 998 | plot_metrics(metrics, lora_output_dir / "lr.png") |
998 | 999 | ||
999 | 1000 | ||
1000 | if __name__ == "__main__": | 1001 | if __name__ == "__main__": |