Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
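In short: this patch wraps three GPU inference blocks (sample generation in Checkpointer, plus class-image generation and the validation loop in main()) in torch.autocast("cuda") alongside the existing torch.inference_mode(), and renames the sampling pipeline's latents argument to latents_or_image.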
diff --git a/textual_inversion.py b/textual_inversion.py
index 999161b..bf591bc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -437,7 +437,7 @@ class Checkpointer:
             generator=generator,
         )
 
-        with torch.inference_mode():
+        with torch.autocast("cuda"), torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
                 file_path = samples_path.joinpath(pool, f"step_{step}.png")
@@ -456,7 +456,7 @@ class Checkpointer:
                     negative_prompt=nprompt,
                     height=self.sample_image_size,
                     width=self.sample_image_size,
-                    latents=latents[:len(prompt)] if latents is not None else None,
+                    latents_or_image=latents[:len(prompt)] if latents is not None else None,
                     generator=generator if latents is not None else None,
                     guidance_scale=guidance_scale,
                     eta=eta,
@@ -670,7 +670,7 @@ def main():
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-    with torch.inference_mode():
+    with torch.autocast("cuda"), torch.inference_mode():
         for batch in batched_data:
             image_name = [p.class_image_path for p in batch]
             prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
@@ -908,7 +908,7 @@ def main():
         text_encoder.eval()
         val_loss = 0.0
 
-        with torch.inference_mode():
+        with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
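For reference, a minimal sketch of the pattern these hunks introduce. The model id and pipeline call below are illustrative assumptions (a stock diffusers StableDiffusionPipeline standing in for this repository's own pipeline class):

import torch
from diffusers import StableDiffusionPipeline

# Stand-in pipeline; the repository uses its own pipeline class with a
# latents_or_image argument, which the stock pipeline does not have.
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
).to("cuda")

# torch.inference_mode() disables autograd tracking entirely, and
# torch.autocast("cuda") runs eligible CUDA ops in reduced precision
# (float16 by default), cutting memory use and speeding up sampling.
with torch.autocast("cuda"), torch.inference_mode():
    images = pipeline(
        prompt="a photograph of an astronaut riding a horse",
        height=512,
        width=512,
        guidance_scale=7.5,
    ).images

Combining the two contexts keeps inference numerics consistent with mixed-precision training while avoiding autograd overhead. (The latents * 0.18215 in the validation hunk is unrelated context: it is the standard Stable Diffusion v1 scaling factor applied to VAE latents.)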
