diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/strategy/ti.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7ac5011..b9a5547 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -108,14 +108,11 @@ def textual_inversion_strategy_callbacks( | |||
108 | @torch.no_grad() | 108 | @torch.no_grad() |
109 | def on_before_optimize(lr: float, epoch: int): | 109 | def on_before_optimize(lr: float, epoch: int): |
110 | if use_emb_decay: | 110 | if use_emb_decay: |
111 | return torch.stack([ | 111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
112 | t | 112 | return torch.all(w.grad == 0, dim=1) |
113 | for t in text_encoder.text_model.embeddings.temp_token_embedding | ||
114 | if t.grad is not None | ||
115 | ]) | ||
116 | 113 | ||
117 | @torch.no_grad() | 114 | @torch.no_grad() |
118 | def on_after_optimize(w, lr: float): | 115 | def on_after_optimize(zero_ids, lr: float): |
119 | if ema_embeddings is not None: | 116 | if ema_embeddings is not None: |
120 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
121 | 118 | ||
@@ -123,8 +120,13 @@ def textual_inversion_strategy_callbacks( | |||
123 | lambda_ = emb_decay * lr | 120 | lambda_ = emb_decay * lr |
124 | 121 | ||
125 | if lambda_ != 0: | 122 | if lambda_ != 0: |
126 | norm = w[:, :].norm(dim=-1, keepdim=True) | 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 124 | |
125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
126 | mask[zero_ids] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
128 | 130 | ||
129 | def on_log(): | 131 | def on_log(): |
130 | if ema_embeddings is not None: | 132 | if ema_embeddings is not None: |