From 21d70916f66e74a87c631a06b70774954b085b48 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Apr 2023 14:14:00 +0200 Subject: Fix --- training/strategy/ti.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'training/strategy/ti.py') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 55e9934..6a637c3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks( else: return nullcontext() - def on_accum_model(): - return text_encoder.text_model.embeddings.token_override_embedding.params - @contextmanager def on_train(epoch: int): + text_encoder.text_model.embeddings.token_override_embedding.params.train() tokenizer.train() yield @contextmanager def on_eval(): + text_encoder.text_model.embeddings.token_override_embedding.params.eval() tokenizer.eval() with ema_context(): @@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, -- cgit v1.2.3-54-g00ecf