From 2e654c017780d37f3304436e2feb84b619f1c023 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Apr 2023 22:25:20 +0200 Subject: Improved sparse embeddings --- training/strategy/ti.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'training') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 16baa34..95128da 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks( if use_ema: ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), inv_gamma=ema_inv_gamma, power=ema_power, max_value=ema_max_decay, @@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks( def ema_context(): if ema_embeddings is not None: return ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters() + text_encoder.text_model.embeddings.token_override_embedding.params.parameters() ) else: return nullcontext() def on_accum_model(): - return text_encoder.text_model.embeddings.temp_token_embedding + return text_encoder.text_model.embeddings.token_override_embedding.params @contextmanager def on_train(epoch: int): @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(zero_ids, lr: float): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) def on_log(): if ema_embeddings is not None: -- cgit v1.2.3-54-g00ecf