diff options
author | Volpeon <git@volpeon.ink> | 2023-04-08 08:46:47 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-08 08:46:47 +0200 |
commit | 4e77936376e7a1fa9b16ccc6af6650233825161c (patch) | |
tree | 7ea64ec7fd8cb23b4f03840ab9ddf47a8c807805 | |
parent | Fix (diff) | |
download | textual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.tar.gz textual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.tar.bz2 textual-inversion-diff-4e77936376e7a1fa9b16ccc6af6650233825161c.zip |
Fix TI
-rw-r--r-- | training/strategy/ti.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index d735dac..720ebf3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -114,11 +114,12 @@ def textual_inversion_strategy_callbacks( | |||
114 | return torch.stack(params) if len(params) != 0 else None | 114 | return torch.stack(params) if len(params) != 0 else None |
115 | 115 | ||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_after_optimize(w, lr: float): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) |
120 | 120 | ||
121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
122 | lr = lrs["emb"] or lrs["0"] | ||
122 | lambda_ = emb_decay * lr | 123 | lambda_ = emb_decay * lr |
123 | 124 | ||
124 | if lambda_ != 0: | 125 | if lambda_ != 0: |