diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 09:35:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:35:42 +0100 |
commit | 6c38d0088ece492696a7bc94a5cb43a48289452a (patch) | |
tree | d84a8fefd52eba5cbf38e64d34962f34dc6d047d /train_ti.py | |
parent | Cleanup (diff) | |
download | textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2 textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip |
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d752927..928b721 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -531,7 +531,7 @@ class Checkpointer(CheckpointerBase): | |||
531 | del text_encoder | 531 | del text_encoder |
532 | 532 | ||
533 | @torch.no_grad() | 533 | @torch.no_grad() |
534 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 534 | def save_samples(self, step): |
535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
536 | 536 | ||
537 | ema_context = self.ema_embeddings.apply_temporary( | 537 | ema_context = self.ema_embeddings.apply_temporary( |
@@ -550,7 +550,7 @@ class Checkpointer(CheckpointerBase): | |||
550 | ).to(self.accelerator.device) | 550 | ).to(self.accelerator.device) |
551 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 551 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
552 | 552 | ||
553 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 553 | super().save_samples(pipeline, step) |
554 | 554 | ||
555 | text_encoder.to(dtype=orig_dtype) | 555 | text_encoder.to(dtype=orig_dtype) |
556 | 556 | ||