diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-02 17:34:11 +0100 | 
| commit | 67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch) | |
| tree | e1308417bde00609a5347bc39a8cd6583fd066f8 /train_dreambooth.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2 textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 16 | 
1 files changed, 12 insertions, 4 deletions
| diff --git a/train_dreambooth.py b/train_dreambooth.py index 05f6cb5..1e49474 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -565,8 +565,6 @@ def main(): | |||
| 565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 565 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 
| 566 | args.pretrained_model_name_or_path, subfolder='scheduler') | 566 | args.pretrained_model_name_or_path, subfolder='scheduler') | 
| 567 | 567 | ||
| 568 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
| 569 | |||
| 570 | vae.enable_slicing() | 568 | vae.enable_slicing() | 
| 571 | vae.set_use_memory_efficient_attention_xformers(True) | 569 | vae.set_use_memory_efficient_attention_xformers(True) | 
| 572 | unet.set_use_memory_efficient_attention_xformers(True) | 570 | unet.set_use_memory_efficient_attention_xformers(True) | 
| @@ -893,7 +891,16 @@ def main(): | |||
| 893 | accelerator.init_trackers("dreambooth", config=config) | 891 | accelerator.init_trackers("dreambooth", config=config) | 
| 894 | 892 | ||
| 895 | if args.find_lr: | 893 | if args.find_lr: | 
| 896 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 894 | lr_finder = LRFinder( | 
| 895 | accelerator, | ||
| 896 | text_encoder, | ||
| 897 | optimizer, | ||
| 898 | train_dataloader, | ||
| 899 | val_dataloader, | ||
| 900 | loop, | ||
| 901 | on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle), | ||
| 902 | on_eval=lambda: tokenizer.set_use_vector_shuffle(False) | ||
| 903 | ) | ||
| 897 | lr_finder.run(min_lr=1e-4) | 904 | lr_finder.run(min_lr=1e-4) | 
| 898 | 905 | ||
| 899 | plt.savefig(basepath.joinpath("lr.png")) | 906 | plt.savefig(basepath.joinpath("lr.png")) | 
| @@ -965,11 +972,11 @@ def main(): | |||
| 965 | local_progress_bar.reset() | 972 | local_progress_bar.reset() | 
| 966 | 973 | ||
| 967 | unet.train() | 974 | unet.train() | 
| 968 | |||
| 969 | if epoch < args.train_text_encoder_epochs: | 975 | if epoch < args.train_text_encoder_epochs: | 
| 970 | text_encoder.train() | 976 | text_encoder.train() | 
| 971 | elif epoch == args.train_text_encoder_epochs: | 977 | elif epoch == args.train_text_encoder_epochs: | 
| 972 | text_encoder.requires_grad_(False) | 978 | text_encoder.requires_grad_(False) | 
| 979 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
| 973 | 980 | ||
| 974 | for step, batch in enumerate(train_dataloader): | 981 | for step, batch in enumerate(train_dataloader): | 
| 975 | with accelerator.accumulate(unet): | 982 | with accelerator.accumulate(unet): | 
| @@ -1023,6 +1030,7 @@ def main(): | |||
| 1023 | 1030 | ||
| 1024 | unet.eval() | 1031 | unet.eval() | 
| 1025 | text_encoder.eval() | 1032 | text_encoder.eval() | 
| 1033 | tokenizer.set_use_vector_shuffle(False) | ||
| 1026 | 1034 | ||
| 1027 | cur_loss_val = AverageMeter() | 1035 | cur_loss_val = AverageMeter() | 
| 1028 | cur_acc_val = AverageMeter() | 1036 | cur_acc_val = AverageMeter() | 
