diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c180170..53776ba 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase): | |||
| 523 | torch.cuda.empty_cache() | 523 | torch.cuda.empty_cache() |
| 524 | 524 | ||
| 525 | @torch.no_grad() | 525 | @torch.no_grad() |
| 526 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 526 | def save_samples(self, step): |
| 527 | unet = self.accelerator.unwrap_model(self.unet) | 527 | unet = self.accelerator.unwrap_model(self.unet) |
| 528 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 528 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 529 | 529 | ||
| @@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase): | |||
| 545 | ).to(self.accelerator.device) | 545 | ).to(self.accelerator.device) |
| 546 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 546 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 547 | 547 | ||
| 548 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 548 | super().save_samples(pipeline, step) |
| 549 | 549 | ||
| 550 | unet.to(dtype=orig_unet_dtype) | 550 | unet.to(dtype=orig_unet_dtype) |
| 551 | text_encoder.to(dtype=orig_text_encoder_dtype) | 551 | text_encoder.to(dtype=orig_text_encoder_dtype) |
