diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 19 |
1 files changed, 0 insertions, 19 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 49236c6..f0b84b5 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,28 +104,10 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(epoch: int): | ||
108 | if use_emb_decay: | ||
109 | params = [ | ||
110 | p | ||
111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
112 | if p.grad is not None | ||
113 | ] | ||
114 | return torch.stack(params) if len(params) != 0 else None | ||
115 | |||
116 | @torch.no_grad() | ||
117 | def on_after_optimize(w, lrs: dict[str, float]): | 107 | def on_after_optimize(w, lrs: dict[str, float]): |
118 | if ema_embeddings is not None: | 108 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 109 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
120 | 110 | ||
121 | if use_emb_decay and w is not None: | ||
122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | ||
123 | lambda_ = emb_decay * lr | ||
124 | |||
125 | if lambda_ != 0: | ||
126 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
128 | |||
129 | def on_log(): | 111 | def on_log(): |
130 | if ema_embeddings is not None: | 112 | if ema_embeddings is not None: |
131 | return {"ema_decay": ema_embeddings.decay} | 113 | return {"ema_decay": ema_embeddings.decay} |
@@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks( | |||
166 | return TrainingCallbacks( | 148 | return TrainingCallbacks( |
167 | on_train=on_train, | 149 | on_train=on_train, |
168 | on_eval=on_eval, | 150 | on_eval=on_eval, |
169 | on_before_optimize=on_before_optimize, | ||
170 | on_after_optimize=on_after_optimize, | 151 | on_after_optimize=on_after_optimize, |
171 | on_log=on_log, | 152 | on_log=on_log, |
172 | on_checkpoint=on_checkpoint, | 153 | on_checkpoint=on_checkpoint, |