summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py19
1 files changed, 0 insertions, 19 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 49236c6..f0b84b5 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,28 +104,10 @@ def textual_inversion_strategy_callbacks(
104 yield 104 yield
105 105
106 @torch.no_grad() 106 @torch.no_grad()
107 def on_before_optimize(epoch: int):
108 if use_emb_decay:
109 params = [
110 p
111 for p in text_encoder.text_model.embeddings.token_embedding.parameters()
112 if p.grad is not None
113 ]
114 return torch.stack(params) if len(params) != 0 else None
115
116 @torch.no_grad()
117 def on_after_optimize(w, lrs: dict[str, float]): 107 def on_after_optimize(w, lrs: dict[str, float]):
118 if ema_embeddings is not None: 108 if ema_embeddings is not None:
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) 109 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
120 110
121 if use_emb_decay and w is not None:
122 lr = lrs["emb"] if "emb" in lrs else lrs["0"]
123 lambda_ = emb_decay * lr
124
125 if lambda_ != 0:
126 norm = w[:, :].norm(dim=-1, keepdim=True)
127 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
128
129 def on_log(): 111 def on_log():
130 if ema_embeddings is not None: 112 if ema_embeddings is not None:
131 return {"ema_decay": ema_embeddings.decay} 113 return {"ema_decay": ema_embeddings.decay}
@@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks(
166 return TrainingCallbacks( 148 return TrainingCallbacks(
167 on_train=on_train, 149 on_train=on_train,
168 on_eval=on_eval, 150 on_eval=on_eval,
169 on_before_optimize=on_before_optimize,
170 on_after_optimize=on_after_optimize, 151 on_after_optimize=on_after_optimize,
171 on_log=on_log, 152 on_log=on_log,
172 on_checkpoint=on_checkpoint, 153 on_checkpoint=on_checkpoint,