Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9df160a..55e9934 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -107,18 +107,19 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
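
For context, a minimal standalone sketch of the guard this patch introduces (the helper name stack_params_with_grad is hypothetical, not from the repository): torch.stack raises a RuntimeError when given an empty list, so the patched on_before_optimize collects the matching parameters first and returns None when none of them received a gradient, and on_after_optimize skips the decay step via the added `w is not None` check.

    import torch

    def stack_params_with_grad(embedding_params):
        # Mirrors the patched on_before_optimize body: keep only the
        # parameters that received a gradient this step.
        params = [p for p in embedding_params if p.grad is not None]
        # torch.stack([]) raises RuntimeError, hence the explicit guard.
        return torch.stack(params) if len(params) != 0 else None

    # No parameter has a gradient yet: the old code would have crashed
    # inside torch.stack([]); the guarded version returns None instead.
    embeddings = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
    assert stack_params_with_grad(embeddings) is None

    # Once at least one parameter has a gradient, stacking succeeds.
    embeddings[0].grad = torch.zeros(4)
    w = stack_params_with_grad(embeddings)
    assert w.shape == (1, 4)

    # The downstream consumer then guards on w, as in the patched
    # on_after_optimize (the decay body itself is elided in the diff).
    use_emb_decay, emb_decay, lr = True, 0.01, 1e-4
    if use_emb_decay and w is not None:
        lambda_ = emb_decay * lr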