From 67aaba2159bcda4c0b8538b1580a40f01e8f0964 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 2 Jan 2023 17:34:11 +0100 Subject: Update --- train_dreambooth.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 05f6cb5..1e49474 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -565,8 +565,6 @@ def main(): checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder='scheduler') - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention_xformers(True) @@ -893,7 +891,16 @@ def main(): accelerator.init_trackers("dreambooth", config=config) if args.find_lr: - lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) + lr_finder = LRFinder( + accelerator, + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + loop, + on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), + on_eval=lambda: tokenizer.set_use_vector_shuffle(False) + ) lr_finder.run(min_lr=1e-4) plt.savefig(basepath.joinpath("lr.png")) @@ -965,11 +972,11 @@ def main(): local_progress_bar.reset() unet.train() - if epoch < args.train_text_encoder_epochs: text_encoder.train() elif epoch == args.train_text_encoder_epochs: text_encoder.requires_grad_(False) + tokenizer.set_use_vector_shuffle(args.vector_shuffle) for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): @@ -1023,6 +1030,7 @@ def main(): unet.eval() text_encoder.eval() + tokenizer.set_use_vector_shuffle(False) cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() -- cgit v1.2.3-54-g00ecf