author     Volpeon <git@volpeon.ink>  2022-10-21 09:50:46 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-21 09:50:46 +0200
commit     ca914af018632b6231fb3ee4fcd5cdbdc467c784 (patch)
tree       01af701c5ac740518cdbc4001592a3f9a29cc57a /textual_inversion.py
parent     Dreambooth: Added option to insert a new input token; removed Dreambooth Plus (diff)
Add optional TI functionality to Dreambooth
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  |  15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 8f266e0..fe56d36 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -521,6 +521,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
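
The added line extends gradient checkpointing from the UNet to the text encoder. The two method names differ because the models come from different libraries: diffusers models expose enable_gradient_checkpointing(), while transformers models expose gradient_checkpointing_enable(). A minimal sketch of the pair of calls, assuming the stock diffusers/transformers classes this script appears to use (the model id is illustrative, not taken from this repo):

    # Sketch: gradient checkpointing on both halves of the pipeline.
    from diffusers import UNet2DConditionModel
    from transformers import CLIPTextModel

    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder")

    unet.enable_gradient_checkpointing()          # diffusers naming
    text_encoder.gradient_checkpointing_enable()  # transformers naming

Checkpointing trades compute for memory: activations are recomputed during the backward pass instead of being stored, so it only pays off while the model is actually being trained.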
@@ -875,8 +876,8 @@ def main():
         text_encoder.eval()
         val_loss = 0.0
 
-        for step, batch in enumerate(val_dataloader):
-            with torch.no_grad():
+        with torch.inference_mode():
+            for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
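
This hunk hoists the autograd guard out of the loop and switches it from torch.no_grad() to torch.inference_mode(). Both disable gradient tracking, but inference_mode additionally skips autograd's version-counter bookkeeping, so tensors created inside it are cheaper to produce but can never be used in a later backward pass. A small runnable sketch of the behavioral difference (plain PyTorch, nothing repo-specific):

    import torch

    x = torch.randn(4, 3)

    with torch.no_grad():
        y = x * 2  # ordinary tensor, just no grad recorded

    with torch.inference_mode():
        z = x * 2  # inference tensor: no autograd metadata at all

    print(y.is_inference(), z.is_inference())  # False True

Entering the context once around the whole validation loop, rather than once per batch, also avoids re-entering the context manager on every iteration.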
@@ -899,12 +900,12 @@ def main():
                 loss = loss.detach().item()
                 val_loss += loss
 
-            if accelerator.sync_gradients:
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-            logs = {"val/loss": loss}
-            local_progress_bar.set_postfix(**logs)
+                logs = {"val/loss": loss}
+                local_progress_bar.set_postfix(**logs)
 
     val_loss /= len(val_dataloader)
 
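
The changes in this last hunk appear to be indentation only: the progress-bar and logging lines move one level deeper so they sit inside the new torch.inference_mode() block, which would account for the remaining insertions and deletions in the diffstat. The final context line averages the accumulated per-batch losses into a single validation figure. A self-contained sketch of the same accumulate-then-average pattern (the validate and loss_fn names are illustrative, not from this script):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    def validate(model, val_dataloader, loss_fn):
        model.eval()
        val_loss = 0.0
        with torch.inference_mode():
            for inputs, targets in val_dataloader:
                loss = loss_fn(model(inputs), targets)
                val_loss += loss.item()  # plain float; no graph is kept
        return val_loss / len(val_dataloader)  # mean loss per batch

    model = nn.Linear(3, 1)
    data = TensorDataset(torch.randn(8, 3), torch.randn(8, 1))
    print(validate(model, DataLoader(data, batch_size=4), nn.functional.mse_loss))

Note that dividing by len(val_dataloader) averages over batches, not over samples; with a constant batch size the two only differ at a ragged final batch.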