summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 7b34fce..79b3d2c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -489,7 +489,7 @@ class Checkpointer:
489 generator=generator, 489 generator=generator,
490 ) 490 )
491 491
492 with torch.inference_mode(): 492 with torch.autocast("cuda"), torch.inference_mode():
493 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 493 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
494 all_samples = [] 494 all_samples = []
495 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 495 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
@@ -522,7 +522,7 @@ class Checkpointer:
522 negative_prompt=nprompt, 522 negative_prompt=nprompt,
523 height=self.sample_image_size, 523 height=self.sample_image_size,
524 width=self.sample_image_size, 524 width=self.sample_image_size,
525 latents=latents[:len(prompt)] if latents is not None else None, 525 latents_or_image=latents[:len(prompt)] if latents is not None else None,
526 generator=generator if latents is not None else None, 526 generator=generator if latents is not None else None,
527 guidance_scale=guidance_scale, 527 guidance_scale=guidance_scale,
528 eta=eta, 528 eta=eta,
@@ -768,7 +768,7 @@ def main():
768 ).to(accelerator.device) 768 ).to(accelerator.device)
769 pipeline.set_progress_bar_config(dynamic_ncols=True) 769 pipeline.set_progress_bar_config(dynamic_ncols=True)
770 770
771 with torch.inference_mode(): 771 with torch.autocast("cuda"), torch.inference_mode():
772 for batch in batched_data: 772 for batch in batched_data:
773 image_name = [item.class_image_path for item in batch] 773 image_name = [item.class_image_path for item in batch]
774 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] 774 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
@@ -1018,7 +1018,7 @@ def main():
1018 text_encoder.eval() 1018 text_encoder.eval()
1019 val_loss = 0.0 1019 val_loss = 0.0
1020 1020
1021 with torch.inference_mode(): 1021 with torch.autocast("cuda"), torch.inference_mode():
1022 for step, batch in enumerate(val_dataloader): 1022 for step, batch in enumerate(val_dataloader):
1023 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 1023 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
1024 latents = latents * 0.18215 1024 latents = latents * 0.18215