 train_lora.py | 109
 train_ti.py   |  12
 2 files changed, 61 insertions, 60 deletions
diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
-
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
-
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
-
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
-
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
-
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
-
-    plot_metrics(metrics, output_dir/"lr.png")
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
+
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
+
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
+
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
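The hunk above wraps the entire PTI phase in a guard, so it only runs when placeholder tokens were requested; with an empty token list the script now proceeds straight to LoRA training. A minimal sketch of the pattern, assuming the tokens come from an argparse flag with nargs="*" (the flag shape and the run_pti helper are illustrative, not this script's actual API):

    import argparse

    def run_pti(tokens):
        # Hypothetical stand-in for the guarded PTI block in the diff.
        print(f"running PTI for {tokens}")

    parser = argparse.ArgumentParser()
    # Assumed flag shape; the real script's parser may differ.
    parser.add_argument("--placeholder_tokens", nargs="*", default=[])
    args = parser.parse_args()

    # Same guard as the diff: skip the whole PTI phase when no tokens were given.
    if len(args.placeholder_tokens) != 0:
        run_pti(args.placeholder_tokens)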
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
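Both phases previously wrote their metrics plot to the same output_dir / "lr.png", so the LoRA plot overwrote the PTI plot; each phase now writes into its own subdirectory. The / here is pathlib.Path's join operator. A small sketch with illustrative directory values (lora_output_dir exists elsewhere in the script; its value is assumed here):

    from pathlib import Path

    output_dir = Path("output/run-1")
    pti_output_dir = output_dir / "pti"    # output/run-1/pti
    lora_output_dir = output_dir / "lora"  # assumed value, not shown in the diff

    # The two plots no longer collide:
    print(pti_output_dir / "lr.png")       # output/run-1/pti/lr.png
    print(lora_output_dir / "lr.png")      # output/run-1/lora/lr.png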
diff --git a/train_ti.py b/train_ti.py
index 344b412..c1c0eed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -602,7 +602,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -743,7 +743,7 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir/"checkpoints"
+    checkpoint_output_dir = output_dir / "checkpoints"
 
     trainer = partial(
         train,
@@ -782,11 +782,11 @@ def main():
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
         if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
+            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
-            sample_output_dir = output_dir/"samples"
-            metrics_output_file = output_dir/f"lr.png"
+            sample_output_dir = output_dir / "samples"
+            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
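The train_ti.py hunks are purely cosmetic: spaces added around the pathlib / join operator, and the redundant f-string prefix dropped from the literal f"lr.png". An f-string with no replacement fields evaluates to the plain literal, so the change is behavior-preserving:

    # An f-string without replacement fields is just the plain string:
    assert f"lr.png" == "lr.png"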