diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 11:02:47 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 11:02:47 +0200 |
commit | f5b86b44565aaaa92543989a85ea5d88ca9b1c0c (patch) | |
tree | df02bdcf757743708001fe70e9db2c3e2b9b4af9 /training/strategy/ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.tar.gz textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.tar.bz2 textual-inversion-diff-f5b86b44565aaaa92543989a85ea5d88ca9b1c0c.zip |
Fix
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 7 |
1 file changed, 4 insertions, 3 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 9df160a..55e9934 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -107,18 +107,19 @@ def textual_inversion_strategy_callbacks( | |||
107 | @torch.no_grad() | 107 | @torch.no_grad() |
108 | def on_before_optimize(lr: float, epoch: int): | 108 | def on_before_optimize(lr: float, epoch: int): |
109 | if use_emb_decay: | 109 | if use_emb_decay: |
110 | return torch.stack([ | 110 | params = [ |
111 | p | 111 | p |
112 | for p in text_encoder.text_model.embeddings.token_override_embedding.params | 112 | for p in text_encoder.text_model.embeddings.token_override_embedding.params |
113 | if p.grad is not None | 113 | if p.grad is not None |
114 | ]) | 114 | ] |
115 | return torch.stack(params) if len(params) != 0 else None | ||
115 | 116 | ||
116 | @torch.no_grad() | 117 | @torch.no_grad() |
117 | def on_after_optimize(w, lr: float): | 118 | def on_after_optimize(w, lr: float): |
118 | if ema_embeddings is not None: | 119 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) | 120 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) |
120 | 121 | ||
121 | if use_emb_decay: | 122 | if use_emb_decay and w is not None: |
122 | lambda_ = emb_decay * lr | 123 | lambda_ = emb_decay * lr |
123 | 124 | ||
124 | if lambda_ != 0: | 125 | if lambda_ != 0: |