summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index d752927..928b721 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase):
531 del text_encoder 531 del text_encoder
532 532
533 @torch.no_grad() 533 @torch.no_grad()
534 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 534 def save_samples(self, step):
535 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 535 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
536 536
537 ema_context = self.ema_embeddings.apply_temporary( 537 ema_context = self.ema_embeddings.apply_temporary(
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase):
550 ).to(self.accelerator.device) 550 ).to(self.accelerator.device)
551 pipeline.set_progress_bar_config(dynamic_ncols=True) 551 pipeline.set_progress_bar_config(dynamic_ncols=True)
552 552
553 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 553 super().save_samples(pipeline, step)
554 554
555 text_encoder.to(dtype=orig_dtype) 555 text_encoder.to(dtype=orig_dtype)
556 556