From 67aaba2159bcda4c0b8538b1580a40f01e8f0964 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 2 Jan 2023 17:34:11 +0100 Subject: Update --- train_ti.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 97dde1e..2b3f017 100644 --- a/train_ti.py +++ b/train_ti.py @@ -535,8 +535,6 @@ def main(): checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder='scheduler') - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention_xformers(True) @@ -845,7 +843,16 @@ def main(): accelerator.init_trackers("textual_inversion", config=config) if args.find_lr: - lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) + lr_finder = LRFinder( + accelerator, + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + loop, + on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), + on_eval=lambda: tokenizer.set_use_vector_shuffle(False) + ) lr_finder.run(min_lr=1e-4) plt.savefig(basepath.joinpath("lr.png")) @@ -915,6 +922,7 @@ def main(): local_progress_bar.reset() text_encoder.train() + tokenizer.set_use_vector_shuffle(args.vector_shuffle) for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): @@ -955,6 +963,7 @@ def main(): accelerator.wait_for_everyone() text_encoder.eval() + tokenizer.set_use_vector_shuffle(False) cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() -- cgit v1.2.3-54-g00ecf