summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/strategy/ti.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index d735dac..720ebf3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -114,11 +114,12 @@ def textual_inversion_strategy_callbacks(
114 return torch.stack(params) if len(params) != 0 else None 114 return torch.stack(params) if len(params) != 0 else None
115 115
116 @torch.no_grad() 116 @torch.no_grad()
117 def on_after_optimize(w, lr: float): 117 def on_after_optimize(w, lrs: dict[str, float]):
118 if ema_embeddings is not None: 118 if ema_embeddings is not None:
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
120 120
121 if use_emb_decay and w is not None: 121 if use_emb_decay and w is not None:
122 lr = lrs["emb"] or lrs["0"]
122 lambda_ = emb_decay * lr 123 lambda_ = emb_decay * lr
123 124
124 if lambda_ != 0: 125 if lambda_ != 0: