diff options
author | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-21 14:08:49 +0100 |
commit | 96638bbd54ca7f91d44c938fae7275d3ecaa6add (patch) | |
tree | b281a0e58820151e8738dfc5294bde5be482956b /training/strategy/ti.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.gz textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.tar.bz2 textual-inversion-diff-96638bbd54ca7f91d44c938fae7275d3ecaa6add.zip |
Fixed TI normalization order
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 15 |
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 09beec4..732cd74 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -116,6 +116,15 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_before_optimize(lr: float, epoch: int): | 117 | def on_before_optimize(lr: float, epoch: int): |
118 | if use_emb_decay: | 118 | if use_emb_decay: |
119 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
120 | return torch.all(w.grad == 0, dim=1) | ||
121 | |||
122 | @torch.no_grad() | ||
123 | def on_after_optimize(zero_ids, lr: float): | ||
124 | if ema_embeddings is not None: | ||
125 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
126 | |||
127 | if use_emb_decay: | ||
119 | lambda_ = emb_decay * lr | 128 | lambda_ = emb_decay * lr |
120 | 129 | ||
121 | if lambda_ != 0: | 130 | if lambda_ != 0: |
@@ -123,15 +132,11 @@ def textual_inversion_strategy_callbacks( | |||
123 | 132 | ||
124 | mask = torch.zeros(w.size(0), dtype=torch.bool) | 133 | mask = torch.zeros(w.size(0), dtype=torch.bool) |
125 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | 134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True |
126 | mask[torch.all(w.grad == 0, dim=1)] = False | 135 | mask[zero_ids] = False |
127 | 136 | ||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | 137 | norm = w[mask, :].norm(dim=-1, keepdim=True) |
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 138 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) |
130 | 139 | ||
131 | def on_after_optimize(lr: float): | ||
132 | if ema_embeddings is not None: | ||
133 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
134 | |||
135 | def on_log(): | 140 | def on_log(): |
136 | if ema_embeddings is not None: | 141 | if ema_embeddings is not None: |
137 | return {"ema_decay": ema_embeddings.decay} | 142 | return {"ema_decay": ema_embeddings.decay} |