summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c180170..53776ba 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase):
523 torch.cuda.empty_cache() 523 torch.cuda.empty_cache()
524 524
525 @torch.no_grad() 525 @torch.no_grad()
526 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 526 def save_samples(self, step):
527 unet = self.accelerator.unwrap_model(self.unet) 527 unet = self.accelerator.unwrap_model(self.unet)
528 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 528 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
529 529
@@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase):
545 ).to(self.accelerator.device) 545 ).to(self.accelerator.device)
546 pipeline.set_progress_bar_config(dynamic_ncols=True) 546 pipeline.set_progress_bar_config(dynamic_ncols=True)
547 547
548 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 548 super().save_samples(pipeline, step)
549 549
550 unet.to(dtype=orig_unet_dtype) 550 unet.to(dtype=orig_unet_dtype)
551 text_encoder.to(dtype=orig_text_encoder_dtype) 551 text_encoder.to(dtype=orig_text_encoder_dtype)