diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 | 
| commit | 67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch) | |
| tree | e1308417bde00609a5347bc39a8cd6583fd066f8 /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2 textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip  | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 15 | 
1 file changed, 12 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 97dde1e..2b3f017 100644 --- a/train_ti.py +++ b/train_ti.py  | |||
| @@ -535,8 +535,6 @@ def main(): | |||
| 535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 535 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 
| 536 | args.pretrained_model_name_or_path, subfolder='scheduler') | 536 | args.pretrained_model_name_or_path, subfolder='scheduler') | 
| 537 | 537 | ||
| 538 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
| 539 | |||
| 540 | vae.enable_slicing() | 538 | vae.enable_slicing() | 
| 541 | vae.set_use_memory_efficient_attention_xformers(True) | 539 | vae.set_use_memory_efficient_attention_xformers(True) | 
| 542 | unet.set_use_memory_efficient_attention_xformers(True) | 540 | unet.set_use_memory_efficient_attention_xformers(True) | 
| @@ -845,7 +843,16 @@ def main(): | |||
| 845 | accelerator.init_trackers("textual_inversion", config=config) | 843 | accelerator.init_trackers("textual_inversion", config=config) | 
| 846 | 844 | ||
| 847 | if args.find_lr: | 845 | if args.find_lr: | 
| 848 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 846 | lr_finder = LRFinder( | 
| 847 | accelerator, | ||
| 848 | text_encoder, | ||
| 849 | optimizer, | ||
| 850 | train_dataloader, | ||
| 851 | val_dataloader, | ||
| 852 | loop, | ||
| 853 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
| 854 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
| 855 | ) | ||
| 849 | lr_finder.run(min_lr=1e-4) | 856 | lr_finder.run(min_lr=1e-4) | 
| 850 | 857 | ||
| 851 | plt.savefig(basepath.joinpath("lr.png")) | 858 | plt.savefig(basepath.joinpath("lr.png")) | 
| @@ -915,6 +922,7 @@ def main(): | |||
| 915 | local_progress_bar.reset() | 922 | local_progress_bar.reset() | 
| 916 | 923 | ||
| 917 | text_encoder.train() | 924 | text_encoder.train() | 
| 925 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
| 918 | 926 | ||
| 919 | for step, batch in enumerate(train_dataloader): | 927 | for step, batch in enumerate(train_dataloader): | 
| 920 | with accelerator.accumulate(text_encoder): | 928 | with accelerator.accumulate(text_encoder): | 
| @@ -955,6 +963,7 @@ def main(): | |||
| 955 | accelerator.wait_for_everyone() | 963 | accelerator.wait_for_everyone() | 
| 956 | 964 | ||
| 957 | text_encoder.eval() | 965 | text_encoder.eval() | 
| 966 | tokenizer.set_use_vector_shuffle(False) | ||
| 958 | 967 | ||
| 959 | cur_loss_val = AverageMeter() | 968 | cur_loss_val = AverageMeter() | 
| 960 | cur_acc_val = AverageMeter() | 969 | cur_acc_val = AverageMeter() | 
