summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-08 08:46:47 +0200
committerVolpeon <git@volpeon.ink>2023-04-08 08:46:47 +0200
commit4e77936376e7a1fa9b16ccc6af6650233825161c (patch)
tree7ea64ec7fd8cb23b4f03840ab9ddf47a8c807805 /training/strategy/ti.py
parentFix (diff)
downloadtextual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.tar.gz
textual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.tar.bz2
textual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.zip
Fix TI
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index d735dac..720ebf3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -114,11 +114,12 @@ def textual_inversion_strategy_callbacks(
114 return torch.stack(params) if len(params) != 0 else None 114 return torch.stack(params) if len(params) != 0 else None
115 115
116 @torch.no_grad() 116 @torch.no_grad()
117 def on_after_optimize(w, lr: float): 117 def on_after_optimize(w, lrs: dict[str, float]):
118 if ema_embeddings is not None: 118 if ema_embeddings is not None:
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
120 120
121 if use_emb_decay and w is not None: 121 if use_emb_decay and w is not None:
122 lr = lrs["emb"] or lrs["0"]
122 lambda_ = emb_decay * lr 123 lambda_ = emb_decay * lr
123 124
124 if lambda_ != 0: 125 if lambda_ != 0: