From 8ff51a771905d0d14a3c690f54eb644515730348 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 14 Nov 2022 18:41:38 +0100 Subject: Refactoring --- textual_inversion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 999161b..bf591bc 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -437,7 +437,7 @@ class Checkpointer: generator=generator, ) - with torch.inference_mode(): + with torch.autocast("cuda"), torch.inference_mode(): for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.png") @@ -456,7 +456,7 @@ class Checkpointer: negative_prompt=nprompt, height=self.sample_image_size, width=self.sample_image_size, - latents=latents[:len(prompt)] if latents is not None else None, + latents_or_image=latents[:len(prompt)] if latents is not None else None, generator=generator if latents is not None else None, guidance_scale=guidance_scale, eta=eta, @@ -670,7 +670,7 @@ def main(): ).to(accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) - with torch.inference_mode(): + with torch.autocast("cuda"), torch.inference_mode(): for batch in batched_data: image_name = [p.class_image_path for p in batch] prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] @@ -908,7 +908,7 @@ def main(): text_encoder.eval() val_loss = 0.0 - with torch.inference_mode(): + with torch.autocast("cuda"), torch.inference_mode(): for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 -- cgit v1.2.3-54-g00ecf