diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 999161b..bf591bc 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -437,7 +437,7 @@ class Checkpointer: | |||
437 | generator=generator, | 437 | generator=generator, |
438 | ) | 438 | ) |
439 | 439 | ||
440 | with torch.inference_mode(): | 440 | with torch.autocast("cuda"), torch.inference_mode(): |
441 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 441 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
442 | all_samples = [] | 442 | all_samples = [] |
443 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 443 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
@@ -456,7 +456,7 @@ class Checkpointer: | |||
456 | negative_prompt=nprompt, | 456 | negative_prompt=nprompt, |
457 | height=self.sample_image_size, | 457 | height=self.sample_image_size, |
458 | width=self.sample_image_size, | 458 | width=self.sample_image_size, |
459 | latents=latents[:len(prompt)] if latents is not None else None, | 459 | latents_or_image=latents[:len(prompt)] if latents is not None else None, |
460 | generator=generator if latents is not None else None, | 460 | generator=generator if latents is not None else None, |
461 | guidance_scale=guidance_scale, | 461 | guidance_scale=guidance_scale, |
462 | eta=eta, | 462 | eta=eta, |
@@ -670,7 +670,7 @@ def main(): | |||
670 | ).to(accelerator.device) | 670 | ).to(accelerator.device) |
671 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 671 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
672 | 672 | ||
673 | with torch.inference_mode(): | 673 | with torch.autocast("cuda"), torch.inference_mode(): |
674 | for batch in batched_data: | 674 | for batch in batched_data: |
675 | image_name = [p.class_image_path for p in batch] | 675 | image_name = [p.class_image_path for p in batch] |
676 | prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] | 676 | prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] |
@@ -908,7 +908,7 @@ def main(): | |||
908 | text_encoder.eval() | 908 | text_encoder.eval() |
909 | val_loss = 0.0 | 909 | val_loss = 0.0 |
910 | 910 | ||
911 | with torch.inference_mode(): | 911 | with torch.autocast("cuda"), torch.inference_mode(): |
912 | for step, batch in enumerate(val_dataloader): | 912 | for step, batch in enumerate(val_dataloader): |
913 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 913 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
914 | latents = latents * 0.18215 | 914 | latents = latents * 0.18215 |