summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/ti.py30
1 files changed, 4 insertions, 26 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 gradient_checkpointing: bool = False, 34 gradient_checkpointing: bool = False,
35 use_emb_decay: bool = False,
36 emb_decay_target: float = 0.4,
37 emb_decay: float = 1e-2,
38 use_ema: bool = False, 35 use_ema: bool = False,
39 ema_inv_gamma: float = 1.0, 36 ema_inv_gamma: float = 1.0,
40 ema_power: int = 1, 37 ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
73 70
74 if use_ema: 71 if use_ema:
75 ema_embeddings = EMAModel( 72 ema_embeddings = EMAModel(
76 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 73 text_encoder.text_model.embeddings.overlay.parameters(),
77 inv_gamma=ema_inv_gamma, 74 inv_gamma=ema_inv_gamma,
78 power=ema_power, 75 power=ema_power,
79 max_value=ema_max_decay, 76 max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
85 def ema_context(): 82 def ema_context():
86 if ema_embeddings is not None: 83 if ema_embeddings is not None:
87 return ema_embeddings.apply_temporary( 84 return ema_embeddings.apply_temporary(
88 text_encoder.text_model.embeddings.temp_token_embedding.parameters() 85 text_encoder.text_model.embeddings.overlay.parameters()
89 ) 86 )
90 else: 87 else:
91 return nullcontext() 88 return nullcontext()
92 89
93 def on_accum_model(): 90 def on_accum_model():
94 return text_encoder.text_model.embeddings.temp_token_embedding 91 return text_encoder.text_model.embeddings.overlay
95 92
96 @contextmanager 93 @contextmanager
97 def on_train(epoch: int): 94 def on_train(epoch: int):
@@ -106,27 +103,9 @@ def textual_inversion_strategy_callbacks(
106 yield 103 yield
107 104
108 @torch.no_grad() 105 @torch.no_grad()
109 def on_before_optimize(lr: float, epoch: int):
110 if use_emb_decay:
111 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
112 return torch.all(w.grad == 0, dim=1)
113
114 @torch.no_grad()
115 def on_after_optimize(zero_ids, lr: float): 106 def on_after_optimize(zero_ids, lr: float):
116 if ema_embeddings is not None: 107 if ema_embeddings is not None:
117 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) 108 ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
118
119 if use_emb_decay:
120 lambda_ = emb_decay * lr
121
122 if lambda_ != 0:
123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
124
125 mask = torch.ones(w.shape[0], dtype=torch.bool)
126 mask[zero_ids] = False
127
128 norm = w[mask, :].norm(dim=-1, keepdim=True)
129 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
130 109
131 def on_log(): 110 def on_log():
132 if ema_embeddings is not None: 111 if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
171 on_accum_model=on_accum_model, 150 on_accum_model=on_accum_model,
172 on_train=on_train, 151 on_train=on_train,
173 on_eval=on_eval, 152 on_eval=on_eval,
174 on_before_optimize=on_before_optimize,
175 on_after_optimize=on_after_optimize, 153 on_after_optimize=on_after_optimize,
176 on_log=on_log, 154 on_log=on_log,
177 on_checkpoint=on_checkpoint, 155 on_checkpoint=on_checkpoint,