summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--train_ti.py2
-rw-r--r--training/strategy/ti.py18
2 files changed, 11 insertions, 9 deletions
diff --git a/train_ti.py b/train_ti.py
index 6c35d41..ef39c38 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -788,7 +788,7 @@ def main():
         args.num_vectors,
         args.train_data_template
     ):
-        run(i, [placeholder_token], [initializer_token], [num_vectors], data_template)
+        run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
         embeddings.persist()


diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7ac5011..b9a5547 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -108,14 +108,11 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
-                t
-                for t in text_encoder.text_model.embeddings.temp_token_embedding
-                if t.grad is not None
-            ])
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -123,8 +120,13 @@ def textual_inversion_strategy_callbacks(
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
-                norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
 
     def on_log():
         if ema_embeddings is not None: