Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py  27
1 file changed, 19 insertions, 8 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
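
The emb-decay update reflowed in the hunks above is easier to read as a standalone snippet. This is a minimal sketch, not the repository's code: it assumes w holds the trainable placeholder-token embedding weights that on_after_optimize receives, and that lambda_ is emb_decay scaled by the current learning rate (that assignment sits between the hunks and is not shown in this diff); shapes and hyperparameter values are illustrative only.

import torch

# Illustrative values; the real ones come from the strategy's configuration (assumed).
emb_decay = 1e-2        # decay strength (assumed)
emb_decay_target = 0.4  # target per-token embedding norm (assumed)
lr = 1e-3               # learning rate, taken from lrs["emb"] or lrs["0"] in the diff

# w: trainable placeholder-token embeddings, shape (num_tokens, embedding_dim) assumed.
w = torch.randn(4, 768)

lambda_ = emb_decay * lr  # assumed definition; this line is outside the shown hunks
if lambda_ != 0:
    # Per-token L2 norm, kept as a column so it broadcasts over the embedding dim.
    norm = w.norm(dim=-1, keepdim=True)
    # Nudge each embedding along its own direction so its norm moves toward
    # emb_decay_target, mirroring the reflowed w[:].add_(...) call in the diff.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))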