diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 8f266e0..fe56d36 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -521,6 +521,7 @@ def main(): | |||
521 | 521 | ||
522 | if args.gradient_checkpointing: | 522 | if args.gradient_checkpointing: |
523 | unet.enable_gradient_checkpointing() | 523 | unet.enable_gradient_checkpointing() |
524 | text_encoder.gradient_checkpointing_enable() | ||
524 | 525 | ||
525 | # slice_size = unet.config.attention_head_dim // 2 | 526 | # slice_size = unet.config.attention_head_dim // 2 |
526 | # unet.set_attention_slice(slice_size) | 527 | # unet.set_attention_slice(slice_size) |
@@ -875,8 +876,8 @@ def main(): | |||
875 | text_encoder.eval() | 876 | text_encoder.eval() |
876 | val_loss = 0.0 | 877 | val_loss = 0.0 |
877 | 878 | ||
878 | for step, batch in enumerate(val_dataloader): | 879 | with torch.inference_mode(): |
879 | with torch.no_grad(): | 880 | for step, batch in enumerate(val_dataloader): |
880 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 881 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
881 | latents = latents * 0.18215 | 882 | latents = latents * 0.18215 |
882 | 883 | ||
@@ -899,12 +900,12 @@ def main(): | |||
899 | loss = loss.detach().item() | 900 | loss = loss.detach().item() |
900 | val_loss += loss | 901 | val_loss += loss |
901 | 902 | ||
902 | if accelerator.sync_gradients: | 903 | if accelerator.sync_gradients: |
903 | local_progress_bar.update(1) | 904 | local_progress_bar.update(1) |
904 | global_progress_bar.update(1) | 905 | global_progress_bar.update(1) |
905 | 906 | ||
906 | logs = {"val/loss": loss} | 907 | logs = {"val/loss": loss} |
907 | local_progress_bar.set_postfix(**logs) | 908 | local_progress_bar.set_postfix(**logs) |
908 | 909 | ||
909 | val_loss /= len(val_dataloader) | 910 | val_loss /= len(val_dataloader) |
910 | 911 | ||