diff options
author | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-03 12:40:16 +0100 |
commit | a72b6260c117cabe4fcb2996cce4f870986df99b (patch) | |
tree | 7c9c7704c6ef60a4ab886d5acbce4e6e22398b56 /train_dreambooth.py | |
parent | Fixed LR finder (diff) | |
download | textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.gz textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.tar.bz2 textual-inversion-diff-a72b6260c117cabe4fcb2996cce4f870986df99b.zip |
Added vector dropout
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 218018b..f26b7f5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -108,6 +108,12 @@ def parse_args(): | |||
108 | help="Tag dropout probability.", | 108 | help="Tag dropout probability.", |
109 | ) | 109 | ) |
110 | parser.add_argument( | 110 | parser.add_argument( |
111 | "--vector_dropout", | ||
112 | type=int, | ||
113 | default=0.1, | ||
114 | help="Vector dropout probability.", | ||
115 | ) | ||
116 | parser.add_argument( | ||
111 | "--vector_shuffle", | 117 | "--vector_shuffle", |
112 | type=str, | 118 | type=str, |
113 | default="auto", | 119 | default="auto", |
@@ -556,6 +562,8 @@ def main(): | |||
556 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) | 562 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) |
557 | elif args.pretrained_model_name_or_path: | 563 | elif args.pretrained_model_name_or_path: |
558 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 564 | tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
565 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
566 | tokenizer.set_dropout(args.vector_dropout) | ||
559 | 567 | ||
560 | # Load models and create wrapper for stable diffusion | 568 | # Load models and create wrapper for stable diffusion |
561 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') | 569 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') |
@@ -826,6 +834,12 @@ def main(): | |||
826 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 834 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
827 | val_steps = num_val_steps_per_epoch * num_epochs | 835 | val_steps = num_val_steps_per_epoch * num_epochs |
828 | 836 | ||
837 | def on_train(): | ||
838 | tokenizer.train() | ||
839 | |||
840 | def on_eval(): | ||
841 | tokenizer.eval() | ||
842 | |||
829 | def loop(batch): | 843 | def loop(batch): |
830 | # Convert images to latent space | 844 | # Convert images to latent space |
831 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 845 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
@@ -898,8 +912,8 @@ def main(): | |||
898 | train_dataloader, | 912 | train_dataloader, |
899 | val_dataloader, | 913 | val_dataloader, |
900 | loop, | 914 | loop, |
901 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | 915 | on_train=tokenizer.train, |
902 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | 916 | on_eval=tokenizer.eval, |
903 | ) | 917 | ) |
904 | lr_finder.run(end_lr=1e2) | 918 | lr_finder.run(end_lr=1e2) |
905 | 919 | ||
@@ -953,7 +967,7 @@ def main(): | |||
953 | disable=not accelerator.is_local_main_process, | 967 | disable=not accelerator.is_local_main_process, |
954 | dynamic_ncols=True | 968 | dynamic_ncols=True |
955 | ) | 969 | ) |
956 | local_progress_bar.set_description("Epoch X / Y") | 970 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
957 | 971 | ||
958 | global_progress_bar = tqdm( | 972 | global_progress_bar = tqdm( |
959 | range(args.max_train_steps + val_steps), | 973 | range(args.max_train_steps + val_steps), |
@@ -976,7 +990,7 @@ def main(): | |||
976 | text_encoder.train() | 990 | text_encoder.train() |
977 | elif epoch == args.train_text_encoder_epochs: | 991 | elif epoch == args.train_text_encoder_epochs: |
978 | text_encoder.requires_grad_(False) | 992 | text_encoder.requires_grad_(False) |
979 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 993 | on_train() |
980 | 994 | ||
981 | for step, batch in enumerate(train_dataloader): | 995 | for step, batch in enumerate(train_dataloader): |
982 | with accelerator.accumulate(unet): | 996 | with accelerator.accumulate(unet): |
@@ -1030,7 +1044,7 @@ def main(): | |||
1030 | 1044 | ||
1031 | unet.eval() | 1045 | unet.eval() |
1032 | text_encoder.eval() | 1046 | text_encoder.eval() |
1033 | tokenizer.set_use_vector_shuffle(False) | 1047 | on_eval() |
1034 | 1048 | ||
1035 | cur_loss_val = AverageMeter() | 1049 | cur_loss_val = AverageMeter() |
1036 | cur_acc_val = AverageMeter() | 1050 | cur_acc_val = AverageMeter() |