summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
commit6c38d0088ece492696a7bc94a5cb43a48289452a (patch)
treed84a8fefd52eba5cbf38e64d34962f34dc6d047d /train_dreambooth.py
parentCleanup (diff)
downloadtextual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip
Fix
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c180170..53776ba 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -523,7 +523,7 @@ class Checkpointer(CheckpointerBase):
523 torch.cuda.empty_cache() 523 torch.cuda.empty_cache()
524 524
525 @torch.no_grad() 525 @torch.no_grad()
526 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 526 def save_samples(self, step):
527 unet = self.accelerator.unwrap_model(self.unet) 527 unet = self.accelerator.unwrap_model(self.unet)
528 text_encoder = self.accelerator.unwrap_model(self.text_encoder) 528 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
529 529
@@ -545,7 +545,7 @@ class Checkpointer(CheckpointerBase):
545 ).to(self.accelerator.device) 545 ).to(self.accelerator.device)
546 pipeline.set_progress_bar_config(dynamic_ncols=True) 546 pipeline.set_progress_bar_config(dynamic_ncols=True)
547 547
548 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) 548 super().save_samples(pipeline, step)
549 549
550 unet.to(dtype=orig_unet_dtype) 550 unet.to(dtype=orig_unet_dtype)
551 text_encoder.to(dtype=orig_text_encoder_dtype) 551 text_encoder.to(dtype=orig_text_encoder_dtype)