From a72b6260c117cabe4fcb2996cce4f870986df99b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 3 Jan 2023 12:40:16 +0100 Subject: Added vector dropout --- train_dreambooth.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 218018b..f26b7f5 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -107,6 +107,12 @@ def parse_args(): default=0.1, help="Tag dropout probability.", ) + parser.add_argument( + "--vector_dropout", + type=float, + default=0.1, + help="Vector dropout probability.", + ) parser.add_argument( "--vector_shuffle", type=str, @@ -556,6 +562,8 @@ def main(): tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + tokenizer.set_use_vector_shuffle(args.vector_shuffle) + tokenizer.set_dropout(args.vector_dropout) # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') @@ -826,6 +834,12 @@ def main(): num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) val_steps = num_val_steps_per_epoch * num_epochs + def on_train(): + tokenizer.train() + + def on_eval(): + tokenizer.eval() + def loop(batch): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() @@ -898,8 +912,8 @@ def main(): train_dataloader, val_dataloader, loop, - on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), - on_eval=lambda: tokenizer.set_use_vector_shuffle(False) + on_train=tokenizer.train, + on_eval=tokenizer.eval, ) lr_finder.run(end_lr=1e2) @@ -953,7 +967,7 @@ def main(): disable=not accelerator.is_local_main_process, dynamic_ncols=True ) - local_progress_bar.set_description("Epoch X / Y") + 
local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") global_progress_bar = tqdm( range(args.max_train_steps + val_steps), @@ -976,7 +990,7 @@ def main(): text_encoder.train() elif epoch == args.train_text_encoder_epochs: text_encoder.requires_grad_(False) - tokenizer.set_use_vector_shuffle(args.vector_shuffle) + on_train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): @@ -1030,7 +1044,7 @@ def main(): unet.eval() text_encoder.eval() - tokenizer.set_use_vector_shuffle(False) + on_eval() cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() -- cgit v1.2.3-54-g00ecf