author     Volpeon <git@volpeon.ink>  2023-01-03 12:40:16 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-03 12:40:16 +0100
commit     a72b6260c117cabe4fcb2996cce4f870986df99b (patch)
tree       7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_ti.py
parent     Fixed LR finder (diff)
Added vector dropout
Diffstat (limited to 'train_ti.py')
 -rw-r--r--  train_ti.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 102c0fa..cacbbc7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
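As a quick standalone check, the new flag parses like any other float option; a minimal sketch (the value passed here is illustrative):

import argparse

# Minimal sketch: the new flag parsed in isolation, mirroring the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--vector_dropout",
    type=float,
    default=0.1,
    help="Vector dropout probability.",
)
args = parser.parse_args(["--vector_dropout", "0.25"])
assert args.vector_dropout == 0.25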
@@ -526,6 +532,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
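set_use_vector_shuffle and set_dropout belong to the repo's MultiCLIPTokenizer, whose implementation is outside this diff. A rough sketch of what the dropout half might do, assuming each placeholder token expands to several sub-token ids (the class, and everything in it besides the set_dropout signature, is hypothetical):

import random

class MultiVectorTokenizerSketch:
    """Hypothetical stand-in for MultiCLIPTokenizer's vector dropout."""

    def __init__(self):
        self.dropout = 0.0     # probability of dropping each sub-token id
        self.training = False  # toggled by train()/eval(), see the next hunk

    def set_dropout(self, dropout: float):
        self.dropout = dropout

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def expand_placeholder(self, ids: list[int]) -> list[int]:
        # In training mode, drop each vector of a multi-vector token
        # independently; keep at least one id so the token never vanishes.
        if self.training and self.dropout > 0:
            kept = [i for i in ids if random.random() >= self.dropout]
            return kept if kept else [random.choice(ids)]
        return ids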
@@ -777,6 +785,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
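The new on_train/on_eval hooks follow the torch.nn.Module train()/eval() convention, so the tokenizer's stochastic behaviour (shuffle, dropout) is active only while training. An illustrative sketch of the callback wiring; the real epoch loop lives elsewhere in train_ti.py and these names are assumptions:

# Illustrative wiring only; loop shape and names are hypothetical.
def run_epoch(train_batches, val_batches, on_train, on_eval):
    on_train()                   # enable vector shuffle + dropout
    for batch in train_batches:
        ...                      # training step
    on_eval()                    # deterministic token ids for validation
    for batch in val_batches:
        ...                      # validation step

Naming the hooks also lets the LR finder (next hunk) and the main training loop share the same toggle path instead of duplicating inline lambdas.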
@@ -850,8 +864,8 @@ def main():
         train_dataloader,
         val_dataloader,
         loop,
-        on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-        on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+        on_train=on_train,
+        on_eval=on_eval,
     )
     lr_finder.run(end_lr=1e2)
 
@@ -903,7 +917,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Epoch X / Y")
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -922,7 +936,7 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+        on_train()
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
@@ -963,7 +977,7 @@ def main():
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        tokenizer.set_use_vector_shuffle(False)
+        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()