summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
31 seed: int, 31 seed: int,
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 use_emb_decay: bool = False,
35 emb_decay_target: float = 0.4,
36 emb_decay: float = 1e-2,
34 use_ema: bool = False, 37 use_ema: bool = False,
35 ema_inv_gamma: float = 1.0, 38 ema_inv_gamma: float = 1.0,
36 ema_power: int = 1, 39 ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
102 yield 105 yield
103 106
104 @torch.no_grad() 107 @torch.no_grad()
105 def on_after_optimize(zero_ids, lr: float): 108 def on_before_optimize(lr: float, epoch: int):
109 if use_emb_decay:
110 return torch.stack([
111 p
112 for p in text_encoder.text_model.embeddings.token_override_embedding.params
113 if p.grad is not None
114 ])
115
116 @torch.no_grad()
117 def on_after_optimize(w, lr: float):
106 if ema_embeddings is not None: 118 if ema_embeddings is not None:
107 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
108 120
121 if use_emb_decay:
122 lambda_ = emb_decay * lr
123
124 if lambda_ != 0:
125 norm = w[:, :].norm(dim=-1, keepdim=True)
126 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
127
109 def on_log(): 128 def on_log():
110 if ema_embeddings is not None: 129 if ema_embeddings is not None:
111 return {"ema_decay": ema_embeddings.decay} 130 return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
149 on_accum_model=on_accum_model, 168 on_accum_model=on_accum_model,
150 on_train=on_train, 169 on_train=on_train,
151 on_eval=on_eval, 170 on_eval=on_eval,
171 on_before_optimize=on_before_optimize,
152 on_after_optimize=on_after_optimize, 172 on_after_optimize=on_after_optimize,
153 on_log=on_log, 173 on_log=on_log,
154 on_checkpoint=on_checkpoint, 174 on_checkpoint=on_checkpoint,