path: root/train_lora.py
author    Volpeon <git@volpeon.ink>  2023-04-07 11:31:21 +0200
committer Volpeon <git@volpeon.ink>  2023-04-07 11:31:21 +0200
commit    37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f (patch)
tree      1f18d01cc23418789b6b4b00b38edc0a80b6214a /train_lora.py
parent    Fix (diff)
Run PTI only if placeholder tokens arg isn't empty
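
In short, the commit wraps the entire PTI phase of main() in a guard on args.placeholder_tokens, so a run configured without placeholder tokens skips straight to LoRA training. A minimal, self-contained sketch of the resulting control flow (the argument wiring and the two print stand-ins below are illustrative assumptions, not the script's real helpers):

import argparse

def main():
    parser = argparse.ArgumentParser()
    # Assuming nargs="*": the flag yields an empty list when omitted,
    # so len(...) != 0 is the "arg isn't empty" check from the commit message.
    parser.add_argument("--placeholder_tokens", type=str, nargs="*", default=[])
    args = parser.parse_args()

    if len(args.placeholder_tokens) != 0:
        # Stand-in for the PTI block guarded in this commit.
        print("PTI phase for tokens:", args.placeholder_tokens)

    # The LoRA phase runs either way.
    print("LoRA phase")

if __name__ == "__main__":
    main()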
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  109
1 file changed, 55 insertions, 54 deletions
diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
 
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
 
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
 
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
 
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
 
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
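
The epoch and sample-frequency arithmetic moved inside the guard can be sanity-checked in isolation. The values below are made up for illustration; in the script they come from args and the PTI datamodule:

import math

num_pti_steps = 500                  # hypothetical --num_pti_steps
train_dataset_len = 120              # hypothetical len(pti_datamodule.train_dataset)
pti_grad_accum_steps = 4             # hypothetical --pti_gradient_accumulation_steps
sample_frequency = 10                # hypothetical --sample_frequency
num_train_steps = 2000               # hypothetical --num_train_steps

# Epochs needed to cover num_pti_steps, scaled up by gradient accumulation.
num_pti_epochs = math.ceil(num_pti_steps / train_dataset_len) * pti_grad_accum_steps
# Rescale the per-step sampling interval into whole PTI epochs.
pti_sample_frequency = math.ceil(num_pti_epochs * (sample_frequency / num_train_steps))

print(num_pti_epochs)        # 20
print(pti_sample_frequency)  # 1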