summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
commit6c38d0088ece492696a7bc94a5cb43a48289452a (patch)
treed84a8fefd52eba5cbf38e64d34962f34dc6d047d /train_ti.py
parentCleanup (diff)
downloadtextual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index d752927..928b721 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase):
531 del text_encoder 531 del text_encoder
532 532
533 @torch.no_grad() 533 @torch.no_grad()
534 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 534 def save_samples(self, step):
535 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 535 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
536 536
537 ema_context = self.ema_embeddings.apply_temporary( 537 ema_context = self.ema_embeddings.apply_temporary(
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase):
550 ).to(self.accelerator.device) 550 ).to(self.accelerator.device)
551 pipeline.set_progress_bar_config(dynamic_ncols=True) 551 pipeline.set_progress_bar_config(dynamic_ncols=True)
552 552
553 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 553 super().save_samples(pipeline, step)
554 554
555 text_encoder.to(dtype=orig_dtype) 555 text_encoder.to(dtype=orig_dtype)
556 556