diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index b9a5547..7ac5011 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -108,11 +108,14 @@ def textual_inversion_strategy_callbacks( | |||
108 | @torch.no_grad() | 108 | @torch.no_grad() |
109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(lr: float, epoch: int): |
110 | if use_emb_decay: | 110 | if use_emb_decay: |
111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 111 | return torch.stack([ |
112 | return torch.all(w.grad == 0, dim=1) | 112 | t |
113 | for t in text_encoder.text_model.embeddings.temp_token_embedding | ||
114 | if t.grad is not None | ||
115 | ]) | ||
113 | 116 | ||
114 | @torch.no_grad() | 117 | @torch.no_grad() |
115 | def on_after_optimize(zero_ids, lr: float): | 118 | def on_after_optimize(w, lr: float): |
116 | if ema_embeddings is not None: | 119 | if ema_embeddings is not None: |
117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 120 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
118 | 121 | ||
@@ -120,13 +123,8 @@ def textual_inversion_strategy_callbacks( | |||
120 | lambda_ = emb_decay * lr | 123 | lambda_ = emb_decay * lr |
121 | 124 | ||
122 | if lambda_ != 0: | 125 | if lambda_ != 0: |
123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 126 | norm = w[:, :].norm(dim=-1, keepdim=True) |
124 | 127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | |
125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
126 | mask[zero_ids] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
130 | 128 | ||
131 | def on_log(): | 129 | def on_log(): |
132 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: |