diff options
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 8 | 
1 files changed, 4 insertions, 4 deletions
| diff --git a/textual_inversion.py b/textual_inversion.py index 999161b..bf591bc 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -437,7 +437,7 @@ class Checkpointer: | |||
| 437 | generator=generator, | 437 | generator=generator, | 
| 438 | ) | 438 | ) | 
| 439 | 439 | ||
| 440 | with torch.inference_mode(): | 440 | with torch.autocast("cuda"), torch.inference_mode(): | 
| 441 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 441 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 
| 442 | all_samples = [] | 442 | all_samples = [] | 
| 443 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 443 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 
| @@ -456,7 +456,7 @@ class Checkpointer: | |||
| 456 | negative_prompt=nprompt, | 456 | negative_prompt=nprompt, | 
| 457 | height=self.sample_image_size, | 457 | height=self.sample_image_size, | 
| 458 | width=self.sample_image_size, | 458 | width=self.sample_image_size, | 
| 459 | latents=latents[:len(prompt)] if latents is not None else None, | 459 | latents_or_image=latents[:len(prompt)] if latents is not None else None, | 
| 460 | generator=generator if latents is not None else None, | 460 | generator=generator if latents is not None else None, | 
| 461 | guidance_scale=guidance_scale, | 461 | guidance_scale=guidance_scale, | 
| 462 | eta=eta, | 462 | eta=eta, | 
| @@ -670,7 +670,7 @@ def main(): | |||
| 670 | ).to(accelerator.device) | 670 | ).to(accelerator.device) | 
| 671 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 671 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 
| 672 | 672 | ||
| 673 | with torch.inference_mode(): | 673 | with torch.autocast("cuda"), torch.inference_mode(): | 
| 674 | for batch in batched_data: | 674 | for batch in batched_data: | 
| 675 | image_name = [p.class_image_path for p in batch] | 675 | image_name = [p.class_image_path for p in batch] | 
| 676 | prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] | 676 | prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] | 
| @@ -908,7 +908,7 @@ def main(): | |||
| 908 | text_encoder.eval() | 908 | text_encoder.eval() | 
| 909 | val_loss = 0.0 | 909 | val_loss = 0.0 | 
| 910 | 910 | ||
| 911 | with torch.inference_mode(): | 911 | with torch.autocast("cuda"), torch.inference_mode(): | 
| 912 | for step, batch in enumerate(val_dataloader): | 912 | for step, batch in enumerate(val_dataloader): | 
| 913 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 913 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 
| 914 | latents = latents * 0.18215 | 914 | latents = latents * 0.18215 | 
