diff options
author | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 |
commit | 67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch) | |
tree | e1308417bde00609a5347bc39a8cd6583fd066f8 /train_dreambooth.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2 textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 05f6cb5..1e49474 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -565,8 +565,6 @@ def main(): | |||
565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
566 | args.pretrained_model_name_or_path, subfolder='scheduler') | 566 | args.pretrained_model_name_or_path, subfolder='scheduler') |
567 | 567 | ||
568 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
569 | |||
570 | vae.enable_slicing() | 568 | vae.enable_slicing() |
571 | vae.set_use_memory_efficient_attention_xformers(True) | 569 | vae.set_use_memory_efficient_attention_xformers(True) |
572 | unet.set_use_memory_efficient_attention_xformers(True) | 570 | unet.set_use_memory_efficient_attention_xformers(True) |
@@ -893,7 +891,16 @@ def main(): | |||
893 | accelerator.init_trackers("dreambooth", config=config) | 891 | accelerator.init_trackers("dreambooth", config=config) |
894 | 892 | ||
895 | if args.find_lr: | 893 | if args.find_lr: |
896 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 894 | lr_finder = LRFinder( |
895 | accelerator, | ||
896 | text_encoder, | ||
897 | optimizer, | ||
898 | train_dataloader, | ||
899 | val_dataloader, | ||
900 | loop, | ||
901 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
902 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
903 | ) | ||
897 | lr_finder.run(min_lr=1e-4) | 904 | lr_finder.run(min_lr=1e-4) |
898 | 905 | ||
899 | plt.savefig(basepath.joinpath("lr.png")) | 906 | plt.savefig(basepath.joinpath("lr.png")) |
@@ -965,11 +972,11 @@ def main(): | |||
965 | local_progress_bar.reset() | 972 | local_progress_bar.reset() |
966 | 973 | ||
967 | unet.train() | 974 | unet.train() |
968 | |||
969 | if epoch < args.train_text_encoder_epochs: | 975 | if epoch < args.train_text_encoder_epochs: |
970 | text_encoder.train() | 976 | text_encoder.train() |
971 | elif epoch == args.train_text_encoder_epochs: | 977 | elif epoch == args.train_text_encoder_epochs: |
972 | text_encoder.requires_grad_(False) | 978 | text_encoder.requires_grad_(False) |
979 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
973 | 980 | ||
974 | for step, batch in enumerate(train_dataloader): | 981 | for step, batch in enumerate(train_dataloader): |
975 | with accelerator.accumulate(unet): | 982 | with accelerator.accumulate(unet): |
@@ -1023,6 +1030,7 @@ def main(): | |||
1023 | 1030 | ||
1024 | unet.eval() | 1031 | unet.eval() |
1025 | text_encoder.eval() | 1032 | text_encoder.eval() |
1033 | tokenizer.set_use_vector_shuffle(False) | ||
1026 | 1034 | ||
1027 | cur_loss_val = AverageMeter() | 1035 | cur_loss_val = AverageMeter() |
1028 | cur_acc_val = AverageMeter() | 1036 | cur_acc_val = AverageMeter() |