diff options
author | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 |
commit | 67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch) | |
tree | e1308417bde00609a5347bc39a8cd6583fd066f8 /train_ti.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2 textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 97dde1e..2b3f017 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -535,8 +535,6 @@ def main(): | |||
535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
536 | args.pretrained_model_name_or_path, subfolder='scheduler') | 536 | args.pretrained_model_name_or_path, subfolder='scheduler') |
537 | 537 | ||
538 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
539 | |||
540 | vae.enable_slicing() | 538 | vae.enable_slicing() |
541 | vae.set_use_memory_efficient_attention_xformers(True) | 539 | vae.set_use_memory_efficient_attention_xformers(True) |
542 | unet.set_use_memory_efficient_attention_xformers(True) | 540 | unet.set_use_memory_efficient_attention_xformers(True) |
@@ -845,7 +843,16 @@ def main(): | |||
845 | accelerator.init_trackers("textual_inversion", config=config) | 843 | accelerator.init_trackers("textual_inversion", config=config) |
846 | 844 | ||
847 | if args.find_lr: | 845 | if args.find_lr: |
848 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 846 | lr_finder = LRFinder( |
847 | accelerator, | ||
848 | text_encoder, | ||
849 | optimizer, | ||
850 | train_dataloader, | ||
851 | val_dataloader, | ||
852 | loop, | ||
853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
855 | ) | ||
849 | lr_finder.run(min_lr=1e-4) | 856 | lr_finder.run(min_lr=1e-4) |
850 | 857 | ||
851 | plt.savefig(basepath.joinpath("lr.png")) | 858 | plt.savefig(basepath.joinpath("lr.png")) |
@@ -915,6 +922,7 @@ def main(): | |||
915 | local_progress_bar.reset() | 922 | local_progress_bar.reset() |
916 | 923 | ||
917 | text_encoder.train() | 924 | text_encoder.train() |
925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
918 | 926 | ||
919 | for step, batch in enumerate(train_dataloader): | 927 | for step, batch in enumerate(train_dataloader): |
920 | with accelerator.accumulate(text_encoder): | 928 | with accelerator.accumulate(text_encoder): |
@@ -955,6 +963,7 @@ def main(): | |||
955 | accelerator.wait_for_everyone() | 963 | accelerator.wait_for_everyone() |
956 | 964 | ||
957 | text_encoder.eval() | 965 | text_encoder.eval() |
966 | tokenizer.set_use_vector_shuffle(False) | ||
958 | 967 | ||
959 | cur_loss_val = AverageMeter() | 968 | cur_loss_val = AverageMeter() |
960 | cur_acc_val = AverageMeter() | 969 | cur_acc_val = AverageMeter() |