From 6c38d0088ece492696a7bc94a5cb43a48289452a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:35:42 +0100 Subject: Fix --- train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index c180170..53776ba 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase): torch.cuda.empty_cache() @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples(self, step): unet = self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) @@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase): ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + super().save_samples(pipeline, step) unet.to(dtype=orig_unet_dtype) text_encoder.to(dtype=orig_text_encoder_dtype) -- cgit v1.2.3-54-g00ecf