| author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
| commit | a72b6260c117cabe4fcb2996cce4f870986df99b (patch) | |
| tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_ti.py | |
| parent | Fixed LR finder (diff) | |
| download | textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz, textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2, textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip | |
Added vector dropout
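The commit's only description is its subject line, so a short orientation: in multi-vector textual inversion, a placeholder token is expanded into several embedding vectors, and "vector dropout" randomly discards some of those vectors during training as a regularizer, alongside the tag dropout and vector shuffle options visible in the diff below. The MultiCLIPTokenizer internals are not part of this diff, so the following is a minimal hypothetical sketch of the idea, not the repository's actual implementation:

```python
import random

def apply_vector_dropout(vector_ids: list[int], dropout: float) -> list[int]:
    """Hypothetical sketch: drop each of a multi-vector token's embedding ids
    independently with probability `dropout`, but keep at least one id so the
    placeholder token never disappears from the prompt entirely."""
    if dropout <= 0:
        return list(vector_ids)
    kept = [vid for vid in vector_ids if random.random() >= dropout]
    return kept if kept else [random.choice(vector_ids)]
```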
Diffstat (limited to 'train_ti.py')

```
 -rw-r--r--  train_ti.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)
```
```diff
diff --git a/train_ti.py b/train_ti.py
index 102c0fa..cacbbc7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -155,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
```
| @@ -526,6 +532,8 @@ def main(): | |||
| 526 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) | 532 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) |
| 527 | elif args.pretrained_model_name_or_path: | 533 | elif args.pretrained_model_name_or_path: |
| 528 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 534 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
| 535 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
| 536 | tokenizer.set_dropout(args.vector_dropout) | ||
| 529 | 537 | ||
| 530 | # Load models and create wrapper for stable diffusion | 538 | # Load models and create wrapper for stable diffusion |
| 531 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 539 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
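Both setters run unconditionally after either branch has constructed the tokenizer, so the shuffle and dropout settings apply regardless of where it was loaded from. Their implementation lives in MultiCLIPTokenizer, whose code is outside this diff; a rough sketch of the state they presumably manage (the train()/eval() toggles appear in a later hunk):

```python
class VectorTokenizerOptions:
    """Assumed shape of the tokenizer-side state; MultiCLIPTokenizer's real
    implementation is not shown in this diff."""

    def __init__(self) -> None:
        self.vector_shuffle = False  # set via set_use_vector_shuffle(); "auto" also appears as a value
        self.dropout = 0.0           # probability set via set_dropout()
        self.training = True         # flipped by train()/eval()

    def set_use_vector_shuffle(self, mode) -> None:
        self.vector_shuffle = mode

    def set_dropout(self, dropout: float) -> None:
        self.dropout = dropout

    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False
```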
| @@ -777,6 +785,12 @@ def main(): | |||
| 777 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 785 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 778 | val_steps = num_val_steps_per_epoch * num_epochs | 786 | val_steps = num_val_steps_per_epoch * num_epochs |
| 779 | 787 | ||
| 788 | def on_train(): | ||
| 789 | tokenizer.train() | ||
| 790 | |||
| 791 | def on_eval(): | ||
| 792 | tokenizer.eval() | ||
| 793 | |||
| 780 | def loop(batch): | 794 | def loop(batch): |
| 781 | # Convert images to latent space | 795 | # Convert images to latent space |
| 782 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 796 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
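Naming these hooks (rather than the inline lambdas they replace further down) keeps the mode switching in one place, and a single train/eval transition now toggles both vector shuffle and vector dropout. The tokenizer's train()/eval() pairing mirrors the torch.nn.Module convention, where stochastic behavior such as dropout is active only in training mode.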
| @@ -850,8 +864,8 @@ def main(): | |||
| 850 | train_dataloader, | 864 | train_dataloader, |
| 851 | val_dataloader, | 865 | val_dataloader, |
| 852 | loop, | 866 | loop, |
| 853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | 867 | on_train=on_train, |
| 854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | 868 | on_eval=on_eval, |
| 855 | ) | 869 | ) |
| 856 | lr_finder.run(end_lr=1e2) | 870 | lr_finder.run(end_lr=1e2) |
| 857 | 871 | ||
| @@ -903,7 +917,7 @@ def main(): | |||
| 903 | disable=not accelerator.is_local_main_process, | 917 | disable=not accelerator.is_local_main_process, |
| 904 | dynamic_ncols=True | 918 | dynamic_ncols=True |
| 905 | ) | 919 | ) |
| 906 | local_progress_bar.set_description("Epoch X / Y") | 920 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
| 907 | 921 | ||
| 908 | global_progress_bar = tqdm( | 922 | global_progress_bar = tqdm( |
| 909 | range(args.max_train_steps + val_steps), | 923 | range(args.max_train_steps + val_steps), |
| @@ -922,7 +936,7 @@ def main(): | |||
| 922 | local_progress_bar.reset() | 936 | local_progress_bar.reset() |
| 923 | 937 | ||
| 924 | text_encoder.train() | 938 | text_encoder.train() |
| 925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 939 | on_train() |
| 926 | 940 | ||
| 927 | for step, batch in enumerate(train_dataloader): | 941 | for step, batch in enumerate(train_dataloader): |
| 928 | with accelerator.accumulate(text_encoder): | 942 | with accelerator.accumulate(text_encoder): |
| @@ -963,7 +977,7 @@ def main(): | |||
| 963 | accelerator.wait_for_everyone() | 977 | accelerator.wait_for_everyone() |
| 964 | 978 | ||
| 965 | text_encoder.eval() | 979 | text_encoder.eval() |
| 966 | tokenizer.set_use_vector_shuffle(False) | 980 | on_eval() |
| 967 | 981 | ||
| 968 | cur_loss_val = AverageMeter() | 982 | cur_loss_val = AverageMeter() |
| 969 | cur_acc_val = AverageMeter() | 983 | cur_acc_val = AverageMeter() |
