summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
commit67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch)
treee1308417bde00609a5347bc39a8cd6583fd066f8 /train_ti.py
parentFix (diff)
downloadtextual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 97dde1e..2b3f017 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -535,8 +535,6 @@ def main():
535 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 535 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
536 args.pretrained_model_name_or_path, subfolder='scheduler') 536 args.pretrained_model_name_or_path, subfolder='scheduler')
537 537
538 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
539
540 vae.enable_slicing() 538 vae.enable_slicing()
541 vae.set_use_memory_efficient_attention_xformers(True) 539 vae.set_use_memory_efficient_attention_xformers(True)
542 unet.set_use_memory_efficient_attention_xformers(True) 540 unet.set_use_memory_efficient_attention_xformers(True)
@@ -845,7 +843,16 @@ def main():
845 accelerator.init_trackers("textual_inversion", config=config) 843 accelerator.init_trackers("textual_inversion", config=config)
846 844
847 if args.find_lr: 845 if args.find_lr:
848 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 846 lr_finder = LRFinder(
847 accelerator,
848 text_encoder,
849 optimizer,
850 train_dataloader,
851 val_dataloader,
852 loop,
853 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
854 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
855 )
849 lr_finder.run(min_lr=1e-4) 856 lr_finder.run(min_lr=1e-4)
850 857
851 plt.savefig(basepath.joinpath("lr.png")) 858 plt.savefig(basepath.joinpath("lr.png"))
@@ -915,6 +922,7 @@ def main():
915 local_progress_bar.reset() 922 local_progress_bar.reset()
916 923
917 text_encoder.train() 924 text_encoder.train()
925 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
918 926
919 for step, batch in enumerate(train_dataloader): 927 for step, batch in enumerate(train_dataloader):
920 with accelerator.accumulate(text_encoder): 928 with accelerator.accumulate(text_encoder):
@@ -955,6 +963,7 @@ def main():
955 accelerator.wait_for_everyone() 963 accelerator.wait_for_everyone()
956 964
957 text_encoder.eval() 965 text_encoder.eval()
966 tokenizer.set_use_vector_shuffle(False)
958 967
959 cur_loss_val = AverageMeter() 968 cur_loss_val = AverageMeter()
960 cur_acc_val = AverageMeter() 969 cur_acc_val = AverageMeter()