From 1abbfd5215a99dba9d699e91baec00e6f02a0bd5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 08:13:39 +0100
Subject: Update

---
 training/strategy/ti.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 597abd0..081180f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -88,7 +88,7 @@ def textual_inversion_strategy_callbacks(
     ema_embeddings = None
 
     def ema_context():
-        if use_ema:
+        if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
                 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
             )
@@ -101,7 +101,7 @@ def textual_inversion_strategy_callbacks(
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.to(accelerator.device)
 
         if gradient_checkpointing:
@@ -120,7 +120,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     def on_after_optimize(lr: float):
-        if use_ema:
+        if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
     @torch.no_grad()
@@ -132,7 +132,7 @@ def textual_inversion_strategy_callbacks(
         )
 
     def on_log():
-        if use_ema:
+        if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-- 
cgit v1.2.3-54-g00ecf
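
Editorial note on the pattern above: each callback now guards on the EMA object itself ("ema_embeddings is not None") instead of the separate "use_ema" flag, so the flag and the object can never fall out of sync, and each branch only dereferences an object it has just verified exists. Below is a minimal, hypothetical sketch of that pattern in plain Python. EMAParameters, its step/apply_temporary methods, and the toy Embedding are stand-ins invented for illustration, not the repository's actual EMA wrapper or API.

# Hypothetical sketch of the guard pattern used in this patch.
# EMAParameters is a toy stand-in, NOT the repository's real EMA wrapper.
from contextlib import nullcontext

import torch


class EMAParameters:
    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in parameters]

    def step(self, parameters):
        # Standard EMA update: shadow <- decay * shadow + (1 - decay) * param.
        with torch.no_grad():
            for s, p in zip(self.shadow, parameters):
                s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    def apply_temporary(self, parameters):
        # Swap the EMA weights in on __enter__, restore the originals on
        # __exit__, mirroring the patched ema_context() above.
        params = list(parameters)
        shadow = self.shadow

        class _Swap:
            def __enter__(self):
                self.backup = [p.detach().clone() for p in params]
                with torch.no_grad():
                    for p, s in zip(params, shadow):
                        p.copy_(s)

            def __exit__(self, *exc):
                with torch.no_grad():
                    for p, b in zip(params, self.backup):
                        p.copy_(b)

        return _Swap()


embedding = torch.nn.Embedding(10, 4)
ema_embeddings = EMAParameters(embedding.parameters())  # or None when EMA is disabled


def ema_context():
    # The None-check replaces the pre-patch `if use_ema:` test.
    if ema_embeddings is not None:
        return ema_embeddings.apply_temporary(embedding.parameters())
    return nullcontext()


with ema_context():
    pass  # evaluate or checkpoint with EMA weights temporarily applied

Guarding on the object rather than on a flag also means a boolean like use_ema can be dropped from the closure entirely once ema_embeddings carries the same information.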