diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d752927..928b721 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase): | |||
531 | del text_encoder | 531 | del text_encoder |
532 | 532 | ||
533 | @torch.no_grad() | 533 | @torch.no_grad() |
534 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 534 | def save_samples(self, step): |
535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
536 | 536 | ||
537 | ema_context = self.ema_embeddings.apply_temporary( | 537 | ema_context = self.ema_embeddings.apply_temporary( |
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase): | |||
550 | ).to(self.accelerator.device) | 550 | ).to(self.accelerator.device) |
551 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 551 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
552 | 552 | ||
553 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 553 | super().save_samples(pipeline, step) |
554 | 554 | ||
555 | text_encoder.to(dtype=orig_dtype) | 555 | text_encoder.to(dtype=orig_dtype) |
556 | 556 | ||