From 6c38d0088ece492696a7bc94a5cb43a48289452a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:35:42 +0100 Subject: Fix --- train_ti.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index d752927..928b721 100644 --- a/train_ti.py +++ b/train_ti.py @@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase): del text_encoder @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples(self, step): text_encoder = self.accelerator.unwrap_model(self.text_encoder) ema_context = self.ema_embeddings.apply_temporary( @@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase): ).to(self.accelerator.device) pipeline.set_progress_bar_config(dynamic_ncols=True) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + super().save_samples(pipeline, step) text_encoder.to(dtype=orig_dtype) -- cgit v1.2.3-54-g00ecf