summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 999161b..bf591bc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -437,7 +437,7 @@ class Checkpointer:
437 generator=generator, 437 generator=generator,
438 ) 438 )
439 439
440 with torch.inference_mode(): 440 with torch.autocast("cuda"), torch.inference_mode():
441 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 441 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
442 all_samples = [] 442 all_samples = []
443 file_path = samples_path.joinpath(pool, f"step_{step}.png") 443 file_path = samples_path.joinpath(pool, f"step_{step}.png")
@@ -456,7 +456,7 @@ class Checkpointer:
456 negative_prompt=nprompt, 456 negative_prompt=nprompt,
457 height=self.sample_image_size, 457 height=self.sample_image_size,
458 width=self.sample_image_size, 458 width=self.sample_image_size,
459 latents=latents[:len(prompt)] if latents is not None else None, 459 latents_or_image=latents[:len(prompt)] if latents is not None else None,
460 generator=generator if latents is not None else None, 460 generator=generator if latents is not None else None,
461 guidance_scale=guidance_scale, 461 guidance_scale=guidance_scale,
462 eta=eta, 462 eta=eta,
@@ -670,7 +670,7 @@ def main():
670 ).to(accelerator.device) 670 ).to(accelerator.device)
671 pipeline.set_progress_bar_config(dynamic_ncols=True) 671 pipeline.set_progress_bar_config(dynamic_ncols=True)
672 672
673 with torch.inference_mode(): 673 with torch.autocast("cuda"), torch.inference_mode():
674 for batch in batched_data: 674 for batch in batched_data:
675 image_name = [p.class_image_path for p in batch] 675 image_name = [p.class_image_path for p in batch]
676 prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] 676 prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
@@ -908,7 +908,7 @@ def main():
908 text_encoder.eval() 908 text_encoder.eval()
909 val_loss = 0.0 909 val_loss = 0.0
910 910
911 with torch.inference_mode(): 911 with torch.autocast("cuda"), torch.inference_mode():
912 for step, batch in enumerate(val_dataloader): 912 for step, batch in enumerate(val_dataloader):
913 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 913 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
914 latents = latents * 0.18215 914 latents = latents * 0.18215