summaryrefslogtreecommitdiffstats
path: root/trainer
diff options
context:
space:
mode:
Diffstat (limited to 'trainer')
-rw-r--r--trainer/base.py2
-rw-r--r--trainer/ti.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/trainer/base.py b/trainer/base.py
index e700dd6..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer/base.py
@@ -74,7 +74,7 @@ class Checkpointer():
74 def checkpoint(self, step: int, postfix: str): 74 def checkpoint(self, step: int, postfix: str):
75 pass 75 pass
76 76
77 @torch.inference_mode() 77 @torch.no_grad()
78 def save_samples(self, step: int): 78 def save_samples(self, step: int):
79 print(f"Saving samples for step {step}...") 79 print(f"Saving samples for step {step}...")
80 80
diff --git a/trainer/ti.py b/trainer/ti.py
index 15cf747..388acd3 100644
--- a/trainer/ti.py
+++ b/trainer/ti.py
@@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer):
42 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") 42 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
43 ) 43 )
44 44
45 @torch.inference_mode() 45 @torch.no_grad()
46 def save_samples(self, step): 46 def save_samples(self, step):
47 ema_context = self.ema_embeddings.apply_temporary( 47 ema_context = self.ema_embeddings.apply_temporary(
48 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() 48 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()