Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/ti.py  15  +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            text_encoder.text_model.embeddings.normalize(
-                emb_decay_target,
-                min(1.0, emb_decay * lr)
-            )
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.zeros(w.size(0), dtype=torch.bool)
+                mask[text_encoder.text_model.embeddings.temp_token_ids] = True
+                mask[torch.all(w.grad == 0, dim=1)] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_after_optimize(lr: float):
         if ema_embeddings is not None:
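Taken in isolation, the new decay step does the following: for every trainable token embedding that received a gradient in the current step, it nudges the row's L2 norm toward emb_decay_target with strength emb_decay * lr, replacing the old embeddings.normalize() helper. Below is a minimal standalone sketch of the same update; the function name and arguments are illustrative, not the repository's API, and it writes back via w[mask] += ... rather than w[mask].add_(...), since an in-place op on a boolean-indexed result modifies a copy rather than the original weight.

    import torch

    def decay_token_embeddings(
        w: torch.Tensor,              # embedding weight, shape (vocab, dim), with w.grad populated
        trainable_ids: torch.Tensor,  # row indices of the embeddings being trained
        emb_decay: float,
        emb_decay_target: float,
        lr: float,
    ) -> None:
        lambda_ = emb_decay * lr
        if lambda_ == 0:
            return
        with torch.no_grad():
            # Select the trainable rows, then drop any row whose gradient is
            # all zeros, i.e. tokens that did not appear in the current batch.
            mask = torch.zeros(w.size(0), dtype=torch.bool, device=w.device)
            mask[trainable_ids] = True
            mask[torch.all(w.grad == 0, dim=1)] = False

            norm = w[mask].norm(dim=-1, keepdim=True)
            # Move each selected row's norm toward the target: rows above the
            # target shrink, rows below it grow. clamp_min avoids dividing a
            # zero vector by zero.
            w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)

Compared with the removed normalize(emb_decay_target, min(1.0, emb_decay * lr)) call, the net effect of the commit is that decay is applied only to embeddings actually used in the batch, and the decay strength emb_decay * lr is no longer capped at 1.0.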