From ca914af018632b6231fb3ee4fcd5cdbdc467c784 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 21 Oct 2022 09:50:46 +0200
Subject: Add optional TI functionality to Dreambooth

---
 textual_inversion.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 8f266e0..fe56d36 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -521,6 +521,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
@@ -875,8 +876,8 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            for step, batch in enumerate(val_dataloader):
-                with torch.no_grad():
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
@@ -899,12 +900,12 @@ def main():
                     loss = loss.detach().item()
                     val_loss += loss
 
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
-                local_progress_bar.set_postfix(**logs)
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
--
cgit v1.2.3-54-g00ecf
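
For reference, a minimal sketch of how the validation loop reads after this patch: `torch.inference_mode()` now wraps the entire pass over `val_dataloader` (instead of a per-batch `torch.no_grad()`), and the progress/logging code moves one indentation level deeper so it runs inside that block. Only the control flow shown in the hunks above is taken from the diff; the noise-prediction step in the middle and the `run_validation` name/signature are assumptions based on the standard diffusers training recipe.

```python
import torch
import torch.nn.functional as F


def run_validation(vae, unet, text_encoder, noise_scheduler, val_dataloader,
                   accelerator, local_progress_bar, global_progress_bar):
    # Hypothetical helper for illustration; in the script this code lives
    # directly inside main().
    text_encoder.eval()
    val_loss = 0.0

    # inference_mode() now wraps the whole loop rather than each batch.
    with torch.inference_mode():
        for step, batch in enumerate(val_dataloader):
            latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
            latents = latents * 0.18215

            # Assumed (not shown in the hunks): the usual diffusion objective,
            # i.e. add noise, predict it with the UNet, take the MSE.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(noise_pred, noise, reduction="mean")

            loss = loss.detach().item()
            val_loss += loss

            # Progress reporting now sits inside the inference_mode() block,
            # one indentation level deeper than before the patch.
            if accelerator.sync_gradients:
                local_progress_bar.update(1)
                global_progress_bar.update(1)

            logs = {"val/loss": loss}
            local_progress_bar.set_postfix(**logs)

    val_loss /= len(val_dataloader)
    return val_loss
```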