path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2023-01-03 12:40:16 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-03 12:40:16 +0100
commit     a72b6260c117cabe4fcb2996cce4f870986df99b (patch)
tree       7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_dreambooth.py
parent     Fixed LR finder (diff)
download   textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz
           textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2
           textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip
Added vector dropout
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  24
1 file changed, 19 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 218018b..f26b7f5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -108,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=int,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
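One quirk worth noting in this hunk: the new flag is declared with type=int even though its default is the float 0.1. argparse only applies type to values given on the command line, so the default survives as a float, but an explicit fractional value such as --vector_dropout 0.25 would fail int() conversion. A minimal standalone sketch of the presumably intended parsing, with type=float assumed (hypothetical example, not code from the repository):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vector_dropout",
        type=float,  # assumption: a probability in [0, 1]; the diff declares type=int
        default=0.1,
        help="Vector dropout probability.",
    )

    args = parser.parse_args(["--vector_dropout", "0.25"])
    assert 0.0 <= args.vector_dropout <= 1.0  # parses cleanly as 0.25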
@@ -556,6 +562,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)

     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
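The tokenizer is now configured once, right after loading, instead of being toggled at each call site. The diff does not show MultiCLIPTokenizer's internals, but "vector dropout" in a multi-vector textual-inversion setting typically means randomly omitting some of the sub-tokens that make up a multi-vector placeholder during training. A plausible sketch of the two setters and the dropout step — purely illustrative; the class name and method bodies are assumptions, not the repository's actual code:

    import random

    class VectorDropoutTokenizerSketch:
        """Illustrative stand-in; MultiCLIPTokenizer's real internals may differ."""

        def __init__(self):
            self.vector_shuffle = False
            self.vector_dropout = 0.0
            self.training = False

        def set_use_vector_shuffle(self, enable):
            self.vector_shuffle = enable

        def set_dropout(self, dropout):
            self.vector_dropout = dropout

        def apply_vector_dropout(self, sub_token_ids):
            # Drop each sub-token of a multi-vector embedding with probability
            # `vector_dropout`, but only in training mode, and always keep at
            # least one id so the placeholder never vanishes from the prompt.
            if not self.training or self.vector_dropout <= 0:
                return sub_token_ids
            kept = [t for t in sub_token_ids if random.random() >= self.vector_dropout]
            return kept if kept else [random.choice(sub_token_ids)]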
@@ -826,6 +834,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs

+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
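The new on_train/on_eval hooks imply that MultiCLIPTokenizer now exposes a train()/eval() switch modeled on torch.nn.Module, replacing the scattered set_use_vector_shuffle calls below. A minimal sketch of such a switch, assuming (this is not shown in the diff) that shuffle and dropout consult a training flag:

    class TrainEvalSwitchSketch:
        # Hypothetical: mirrors torch.nn.Module train/eval semantics.
        training = False

        def train(self, mode=True):
            # Training mode: vector shuffle and vector dropout are active.
            self.training = mode
            return self

        def eval(self):
            # Eval mode: deterministic tokenization for validation/inference.
            return self.train(False)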
@@ -898,8 +912,8 @@ def main():
             train_dataloader,
             val_dataloader,
             loop,
-            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+            on_train=tokenizer.train,
+            on_eval=tokenizer.eval,
         )
         lr_finder.run(end_lr=1e2)

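Dropping the lambdas in favor of bound methods works because Python bound methods are first-class callables: tokenizer.train is equivalent to lambda: tokenizer.train() for a no-argument callback. A toy illustration:

    class Toy:
        def train(self):
            print("train mode on")

    t = Toy()
    callback = t.train  # bound method; no lambda wrapper needed
    callback()          # -> "train mode on"

It also presumably tightens behavior: the old on_eval lambda only disabled vector shuffle, whereas tokenizer.eval would plausibly disable vector dropout as well.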
@@ -953,7 +967,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Epoch X / Y")
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -976,7 +990,7 @@ def main():
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
-        tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+        on_train()

         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -1030,7 +1044,7 @@ def main():

         unet.eval()
         text_encoder.eval()
-        tokenizer.set_use_vector_shuffle(False)
+        on_eval()

         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
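Taken together, the commit makes the tokenizer follow the same train/eval rhythm as the models: stochastic vector shuffle and dropout during training steps, deterministic tokenization during validation. A minimal runnable sketch of the resulting epoch-loop shape, with stubs standing in for the real models and tokenizer (names follow the diff; everything else is simplified):

    class _Stub:
        def train(self):
            print(f"{type(self).__name__}: train mode")

        def eval(self):
            print(f"{type(self).__name__}: eval mode")

    class UNet(_Stub): pass
    class TextEncoder(_Stub): pass
    class Tokenizer(_Stub): pass

    unet, text_encoder, tokenizer = UNet(), TextEncoder(), Tokenizer()

    def on_train():
        tokenizer.train()

    def on_eval():
        tokenizer.eval()

    for epoch in range(2):
        unet.train(); text_encoder.train(); on_train()
        # ... training steps: vector shuffle/dropout active ...
        unet.eval(); text_encoder.eval(); on_eval()
        # ... validation steps: deterministic tokenization ...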