diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 7b34fce..79b3d2c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -489,7 +489,7 @@ class Checkpointer: | |||
489 | generator=generator, | 489 | generator=generator, |
490 | ) | 490 | ) |
491 | 491 | ||
492 | with torch.inference_mode(): | 492 | with torch.autocast("cuda"), torch.inference_mode(): |
493 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 493 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
494 | all_samples = [] | 494 | all_samples = [] |
495 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 495 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
@@ -522,7 +522,7 @@ class Checkpointer: | |||
522 | negative_prompt=nprompt, | 522 | negative_prompt=nprompt, |
523 | height=self.sample_image_size, | 523 | height=self.sample_image_size, |
524 | width=self.sample_image_size, | 524 | width=self.sample_image_size, |
525 | latents=latents[:len(prompt)] if latents is not None else None, | 525 | latents_or_image=latents[:len(prompt)] if latents is not None else None, |
526 | generator=generator if latents is not None else None, | 526 | generator=generator if latents is not None else None, |
527 | guidance_scale=guidance_scale, | 527 | guidance_scale=guidance_scale, |
528 | eta=eta, | 528 | eta=eta, |
@@ -768,7 +768,7 @@ def main(): | |||
768 | ).to(accelerator.device) | 768 | ).to(accelerator.device) |
769 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 769 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
770 | 770 | ||
771 | with torch.inference_mode(): | 771 | with torch.autocast("cuda"), torch.inference_mode(): |
772 | for batch in batched_data: | 772 | for batch in batched_data: |
773 | image_name = [item.class_image_path for item in batch] | 773 | image_name = [item.class_image_path for item in batch] |
774 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 774 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] |
@@ -1018,7 +1018,7 @@ def main(): | |||
1018 | text_encoder.eval() | 1018 | text_encoder.eval() |
1019 | val_loss = 0.0 | 1019 | val_loss = 0.0 |
1020 | 1020 | ||
1021 | with torch.inference_mode(): | 1021 | with torch.autocast("cuda"), torch.inference_mode(): |
1022 | for step, batch in enumerate(val_dataloader): | 1022 | for step, batch in enumerate(val_dataloader): |
1023 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 1023 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
1024 | latents = latents * 0.18215 | 1024 | latents = latents * 0.18215 |