diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d752927..928b721 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase): | |||
| 531 | del text_encoder | 531 | del text_encoder |
| 532 | 532 | ||
| 533 | @torch.no_grad() | 533 | @torch.no_grad() |
| 534 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 534 | def save_samples(self, step): |
| 535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 536 | 536 | ||
| 537 | ema_context = self.ema_embeddings.apply_temporary( | 537 | ema_context = self.ema_embeddings.apply_temporary( |
| @@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase): | |||
| 550 | ).to(self.accelerator.device) | 550 | ).to(self.accelerator.device) |
| 551 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 551 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 552 | 552 | ||
| 553 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 553 | super().save_samples(pipeline, step) |
| 554 | 554 | ||
| 555 | text_encoder.to(dtype=orig_dtype) | 555 | text_encoder.to(dtype=orig_dtype) |
| 556 | 556 | ||
