summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
commit67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch)
treee1308417bde00609a5347bc39a8cd6583fd066f8 /train_dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05f6cb5..1e49474 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -565,8 +565,6 @@ def main():
565 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 565 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
566 args.pretrained_model_name_or_path, subfolder='scheduler') 566 args.pretrained_model_name_or_path, subfolder='scheduler')
567 567
568 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
569
570 vae.enable_slicing() 568 vae.enable_slicing()
571 vae.set_use_memory_efficient_attention_xformers(True) 569 vae.set_use_memory_efficient_attention_xformers(True)
572 unet.set_use_memory_efficient_attention_xformers(True) 570 unet.set_use_memory_efficient_attention_xformers(True)
@@ -893,7 +891,16 @@ def main():
893 accelerator.init_trackers("dreambooth", config=config) 891 accelerator.init_trackers("dreambooth", config=config)
894 892
895 if args.find_lr: 893 if args.find_lr:
896 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 894 lr_finder = LRFinder(
895 accelerator,
896 text_encoder,
897 optimizer,
898 train_dataloader,
899 val_dataloader,
900 loop,
901 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
902 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
903 )
897 lr_finder.run(min_lr=1e-4) 904 lr_finder.run(min_lr=1e-4)
898 905
899 plt.savefig(basepath.joinpath("lr.png")) 906 plt.savefig(basepath.joinpath("lr.png"))
@@ -965,11 +972,11 @@ def main():
965 local_progress_bar.reset() 972 local_progress_bar.reset()
966 973
967 unet.train() 974 unet.train()
968
969 if epoch < args.train_text_encoder_epochs: 975 if epoch < args.train_text_encoder_epochs:
970 text_encoder.train() 976 text_encoder.train()
971 elif epoch == args.train_text_encoder_epochs: 977 elif epoch == args.train_text_encoder_epochs:
972 text_encoder.requires_grad_(False) 978 text_encoder.requires_grad_(False)
979 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
973 980
974 for step, batch in enumerate(train_dataloader): 981 for step, batch in enumerate(train_dataloader):
975 with accelerator.accumulate(unet): 982 with accelerator.accumulate(unet):
@@ -1023,6 +1030,7 @@ def main():
1023 1030
1024 unet.eval() 1031 unet.eval()
1025 text_encoder.eval() 1032 text_encoder.eval()
1033 tokenizer.set_use_vector_shuffle(False)
1026 1034
1027 cur_loss_val = AverageMeter() 1035 cur_loss_val = AverageMeter()
1028 cur_acc_val = AverageMeter() 1036 cur_acc_val = AverageMeter()