diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-21 09:50:46 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-21 09:50:46 +0200 |
| commit | ca914af018632b6231fb3ee4fcd5cdbdc467c784 (patch) | |
| tree | 01af701c5ac740518cdbc4001592a3f9a29cc57a /textual_inversion.py | |
| parent | Dreambooth: Added option to insert a new input token; removed Dreambooth Plus (diff) | |
| download | textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.gz textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.bz2 textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.zip | |
Add optional TI functionality to Dreambooth
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 8f266e0..fe56d36 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -521,6 +521,7 @@ def main(): | |||
| 521 | 521 | ||
| 522 | if args.gradient_checkpointing: | 522 | if args.gradient_checkpointing: |
| 523 | unet.enable_gradient_checkpointing() | 523 | unet.enable_gradient_checkpointing() |
| 524 | text_encoder.gradient_checkpointing_enable() | ||
| 524 | 525 | ||
| 525 | # slice_size = unet.config.attention_head_dim // 2 | 526 | # slice_size = unet.config.attention_head_dim // 2 |
| 526 | # unet.set_attention_slice(slice_size) | 527 | # unet.set_attention_slice(slice_size) |
| @@ -875,8 +876,8 @@ def main(): | |||
| 875 | text_encoder.eval() | 876 | text_encoder.eval() |
| 876 | val_loss = 0.0 | 877 | val_loss = 0.0 |
| 877 | 878 | ||
| 878 | for step, batch in enumerate(val_dataloader): | 879 | with torch.inference_mode(): |
| 879 | with torch.no_grad(): | 880 | for step, batch in enumerate(val_dataloader): |
| 880 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 881 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 881 | latents = latents * 0.18215 | 882 | latents = latents * 0.18215 |
| 882 | 883 | ||
| @@ -899,12 +900,12 @@ def main(): | |||
| 899 | loss = loss.detach().item() | 900 | loss = loss.detach().item() |
| 900 | val_loss += loss | 901 | val_loss += loss |
| 901 | 902 | ||
| 902 | if accelerator.sync_gradients: | 903 | if accelerator.sync_gradients: |
| 903 | local_progress_bar.update(1) | 904 | local_progress_bar.update(1) |
| 904 | global_progress_bar.update(1) | 905 | global_progress_bar.update(1) |
| 905 | 906 | ||
| 906 | logs = {"val/loss": loss} | 907 | logs = {"val/loss": loss} |
| 907 | local_progress_bar.set_postfix(**logs) | 908 | local_progress_bar.set_postfix(**logs) |
| 908 | 909 | ||
| 909 | val_loss /= len(val_dataloader) | 910 | val_loss /= len(val_dataloader) |
| 910 | 911 | ||
