From fc11c86142915d6c3935d28a3321b3ae91b613ef Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 22:03:01 +0100 Subject: Update --- train_ti.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index deed84c..a4e2dde 100644 --- a/train_ti.py +++ b/train_ti.py @@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer): checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") ) - @torch.inference_mode() + @torch.no_grad() def save_samples(self, step): ema_context = self.ema_embeddings.apply_temporary( self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() @@ -808,7 +808,6 @@ def main(): optimizer=optimizer, lr_scheduler=lr_scheduler, model=text_encoder, - checkpointer=checkpointer, train_dataloader=train_dataloader, val_dataloader=val_dataloader, loss_step=loss_step_, @@ -819,7 +818,9 @@ def main(): on_log=on_log, on_train=on_train, on_after_optimize=on_after_optimize, - on_eval=on_eval + on_eval=on_eval, + on_sample=checkpointer.save_samples, + on_checkpoint=checkpointer.checkpoint, ) -- cgit v1.2.3-54-g00ecf