-rw-r--r--  train_lora.py  109
-rw-r--r--  train_ti.py     12
2 files changed, 61 insertions, 60 deletions
diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
-
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
-
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
-
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
-
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
-
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
-
-    plot_metrics(metrics, output_dir/"lr.png")
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
+
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
+
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
+
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index 344b412..c1c0eed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -602,7 +602,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -743,7 +743,7 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir/"checkpoints"
+    checkpoint_output_dir = output_dir / "checkpoints"
 
     trainer = partial(
         train,
@@ -782,11 +782,11 @@ def main():
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
         if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
+            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
-            sample_output_dir = output_dir/"samples"
-            metrics_output_file = output_dir/f"lr.png"
+            sample_output_dir = output_dir / "samples"
+            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,