summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/ti.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 66d3129..09beec4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks(
116 @torch.no_grad() 116 @torch.no_grad()
117 def on_before_optimize(lr: float, epoch: int): 117 def on_before_optimize(lr: float, epoch: int):
118 if use_emb_decay: 118 if use_emb_decay:
119 text_encoder.text_model.embeddings.normalize( 119 lambda_ = emb_decay * lr
120 emb_decay_target, 120
121 min(1.0, emb_decay * lr) 121 if lambda_ != 0:
122 ) 122 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
123
124 mask = torch.zeros(w.size(0), dtype=torch.bool)
125 mask[text_encoder.text_model.embeddings.temp_token_ids] = True
126 mask[torch.all(w.grad == 0, dim=1)] = False
127
128 norm = w[mask, :].norm(dim=-1, keepdim=True)
129 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
123 130
124 def on_after_optimize(lr: float): 131 def on_after_optimize(lr: float):
125 if ema_embeddings is not None: 132 if ema_embeddings is not None: