diff options
author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
commit | a72b6260c117cabe4fcb2996cce4f870986df99b (patch) | |
tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_ti.py | |
parent | Fixed LR finder (diff) | |
download | textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2 textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip |
Added vector dropout
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index 102c0fa..cacbbc7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -155,6 +155,12 @@ def parse_args(): | |||
155 | help="Tag dropout probability.", | 155 | help="Tag dropout probability.", |
156 | ) | 156 | ) |
157 | parser.add_argument( | 157 | parser.add_argument( |
158 | "--vector_dropout", | ||
159 | type=int, | ||
160 | default=0.1, | ||
161 | help="Vector dropout probability.", | ||
162 | ) | ||
163 | parser.add_argument( | ||
158 | "--vector_shuffle", | 164 | "--vector_shuffle", |
159 | type=str, | 165 | type=str, |
160 | default="auto", | 166 | default="auto", |
@@ -526,6 +532,8 @@ def main(): | |||
526 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) | 532 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) |
527 | elif args.pretrained_model_name_or_path: | 533 | elif args.pretrained_model_name_or_path: |
528 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 534 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
535 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
536 | tokenizer.set_dropout(args.vector_dropout) | ||
529 | 537 | ||
530 | # Load models and create wrapper for stable diffusion | 538 | # Load models and create wrapper for stable diffusion |
531 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 539 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
@@ -777,6 +785,12 @@ def main(): | |||
777 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 785 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
778 | val_steps = num_val_steps_per_epoch * num_epochs | 786 | val_steps = num_val_steps_per_epoch * num_epochs |
779 | 787 | ||
788 | def on_train(): | ||
789 | tokenizer.train() | ||
790 | |||
791 | def on_eval(): | ||
792 | tokenizer.eval() | ||
793 | |||
780 | def loop(batch): | 794 | def loop(batch): |
781 | # Convert images to latent space | 795 | # Convert images to latent space |
782 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 796 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
@@ -850,8 +864,8 @@ def main(): | |||
850 | train_dataloader, | 864 | train_dataloader, |
851 | val_dataloader, | 865 | val_dataloader, |
852 | loop, | 866 | loop, |
853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | 867 | on_train=on_train, |
854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | 868 | on_eval=on_eval, |
855 | ) | 869 | ) |
856 | lr_finder.run(end_lr=1e2) | 870 | lr_finder.run(end_lr=1e2) |
857 | 871 | ||
@@ -903,7 +917,7 @@ def main(): | |||
903 | disable=not accelerator.is_local_main_process, | 917 | disable=not accelerator.is_local_main_process, |
904 | dynamic_ncols=True | 918 | dynamic_ncols=True |
905 | ) | 919 | ) |
906 | local_progress_bar.set_description("Epoch X / Y") | 920 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
907 | 921 | ||
908 | global_progress_bar = tqdm( | 922 | global_progress_bar = tqdm( |
909 | range(args.max_train_steps + val_steps), | 923 | range(args.max_train_steps + val_steps), |
@@ -922,7 +936,7 @@ def main(): | |||
922 | local_progress_bar.reset() | 936 | local_progress_bar.reset() |
923 | 937 | ||
924 | text_encoder.train() | 938 | text_encoder.train() |
925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 939 | on_train() |
926 | 940 | ||
927 | for step, batch in enumerate(train_dataloader): | 941 | for step, batch in enumerate(train_dataloader): |
928 | with accelerator.accumulate(text_encoder): | 942 | with accelerator.accumulate(text_encoder): |
@@ -963,7 +977,7 @@ def main(): | |||
963 | accelerator.wait_for_everyone() | 977 | accelerator.wait_for_everyone() |
964 | 978 | ||
965 | text_encoder.eval() | 979 | text_encoder.eval() |
966 | tokenizer.set_use_vector_shuffle(False) | 980 | on_eval() |
967 | 981 | ||
968 | cur_loss_val = AverageMeter() | 982 | cur_loss_val = AverageMeter() |
969 | cur_acc_val = AverageMeter() | 983 | cur_acc_val = AverageMeter() |