From fc11c86142915d6c3935d28a3321b3ae91b613ef Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 22:03:01 +0100 Subject: Update --- trainer/base.py | 2 +- trainer/ti.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'trainer') diff --git a/trainer/base.py b/trainer/base.py index e700dd6..1f85e71 100644 --- a/trainer/base.py +++ b/trainer/base.py @@ -74,7 +74,7 @@ class Checkpointer(): def checkpoint(self, step: int, postfix: str): pass - @torch.inference_mode() + @torch.no_grad() def save_samples(self, step: int): print(f"Saving samples for step {step}...") diff --git a/trainer/ti.py b/trainer/ti.py index 15cf747..388acd3 100644 --- a/trainer/ti.py +++ b/trainer/ti.py @@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer): checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") ) - @torch.inference_mode() + @torch.no_grad() def save_samples(self, step): ema_context = self.ema_embeddings.apply_temporary( self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() -- cgit v1.2.3-70-g09d2