| author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
| commit | a72b6260c117cabe4fcb2996cce4f870986df99b | |
| tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_dreambooth.py | |
| parent | Fixed LR finder | |
Added vector dropout
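The dropout itself lives inside `MultiCLIPTokenizer`, which this commit does not touch, so only the wiring is visible in the hunks below. For orientation, here is a minimal sketch of the general idea, assuming "vector dropout" means randomly dropping some of the sub-token IDs that a multi-vector placeholder token expands to; the function name and the keep-at-least-one rule are illustrative assumptions, not the repo's actual code:

```python
import random

def apply_vector_dropout(token_ids: list[int], dropout: float) -> list[int]:
    """Drop each sub-token ID of a multi-vector token with probability `dropout`.

    Keeps at least one ID so the placeholder token never vanishes from the
    prompt entirely. (Illustrative sketch, not the repo's implementation.)
    """
    if dropout <= 0:
        return token_ids
    kept = [tid for tid in token_ids if random.random() >= dropout]
    return kept or [random.choice(token_ids)]
```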
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 24 |
1 file changed, 19 insertions(+), 5 deletions(-)
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 218018b..f26b7f5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -108,6 +108,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_dropout",
+        type=float,
+        default=0.1,
+        help="Vector dropout probability.",
+    )
+    parser.add_argument(
         "--vector_shuffle",
         type=str,
         default="auto",
```
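Note the default of `0.1`: vector dropout is active out of the box and has to be disabled explicitly by passing `--vector_dropout 0`.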
```diff
@@ -556,6 +562,8 @@ def main():
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
```
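The two setters called here belong to `MultiCLIPTokenizer` and are not shown in this diff. A minimal sketch of how they might pair with the `train()`/`eval()` switches used later, assuming a training flag gates both behaviors; every name except the four methods called in the diff is hypothetical:

```python
class MultiCLIPTokenizerSketch:
    """Hypothetical state handling behind the calls in this commit."""

    def __init__(self):
        self.vector_shuffle = "auto"  # shuffle mode, set once at startup
        self.dropout = 0.0            # vector dropout probability
        self.training = False         # gates both shuffle and dropout

    def set_use_vector_shuffle(self, value):
        self.vector_shuffle = value

    def set_dropout(self, dropout: float):
        self.dropout = dropout

    def train(self):
        self.training = True   # shuffle/dropout applied while tokenizing

    def eval(self):
        self.training = False  # deterministic tokenization for validation
```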
```diff
@@ -826,6 +834,12 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    def on_train():
+        tokenizer.train()
+
+    def on_eval():
+        tokenizer.eval()
+
     def loop(batch):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
```
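Wrapping the toggles in `on_train` / `on_eval` mirrors PyTorch's `module.train()` / `module.eval()` convention: the training loop flips a single switch per phase instead of toggling shuffle and dropout individually, which is exactly what the remaining hunks change the call sites to do.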
```diff
@@ -898,8 +912,8 @@ def main():
             train_dataloader,
             val_dataloader,
             loop,
-            on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
-            on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
+            on_train=tokenizer.train,
+            on_eval=tokenizer.eval,
         )
         lr_finder.run(end_lr=1e2)
 
```
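One behavioral difference worth noting: the old lambdas only toggled vector shuffling during the LR search, whereas the bound `tokenizer.train` / `tokenizer.eval` presumably enable and disable vector dropout as well, so the LR finder now sees the same tokenizer behavior as the real training loop.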
```diff
@@ -953,7 +967,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Epoch X / Y")
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -976,7 +990,7 @@ def main():
                 text_encoder.train()
             elif epoch == args.train_text_encoder_epochs:
                 text_encoder.requires_grad_(False)
-        on_train()
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
@@ -1030,7 +1044,7 @@ def main():
 
         unet.eval()
         text_encoder.eval()
-        tokenizer.set_use_vector_shuffle(False)
+        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
```