diff options
Diffstat (limited to 'trainer')
-rw-r--r-- | trainer/base.py | 2 | ||||
-rw-r--r-- | trainer/ti.py | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/trainer/base.py b/trainer/base.py index e700dd6..1f85e71 100644 --- a/trainer/base.py +++ b/trainer/base.py | |||
@@ -74,7 +74,7 @@ class Checkpointer(): | |||
74 | def checkpoint(self, step: int, postfix: str): | 74 | def checkpoint(self, step: int, postfix: str): |
75 | pass | 75 | pass |
76 | 76 | ||
77 | @torch.inference_mode() | 77 | @torch.no_grad() |
78 | def save_samples(self, step: int): | 78 | def save_samples(self, step: int): |
79 | print(f"Saving samples for step {step}...") | 79 | print(f"Saving samples for step {step}...") |
80 | 80 | ||
diff --git a/trainer/ti.py b/trainer/ti.py index 15cf747..388acd3 100644 --- a/trainer/ti.py +++ b/trainer/ti.py | |||
@@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
43 | ) | 43 | ) |
44 | 44 | ||
45 | @torch.inference_mode() | 45 | @torch.no_grad() |
46 | def save_samples(self, step): | 46 | def save_samples(self, step): |
47 | ema_context = self.ema_embeddings.apply_temporary( | 47 | ema_context = self.ema_embeddings.apply_temporary( |
48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |