author    | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200
committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200
commit    | 8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree      | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/strategy/ti.py
parent    | Fix LoRA training with DAdan (diff)
download  | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz
          | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2
          | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip
Update
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 27 +++++++++++++++++++--------
1 file changed, 19 insertions(+), 8 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
         if lambda_ != 0:
             norm = w[:, :].norm(dim=-1, keepdim=True)
-            w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            w[:].add_(
+                (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+            )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@
         print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
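
Note on the `ema_embeddings.step(...)` call (only re-wrapped in this commit): it feeds the live token-embedding parameters into an exponential-moving-average tracker. The tracker's implementation is not part of this diff; below is a generic sketch of the usual EMA update such a step performs, where the `ema_step` helper, the `decay` value, and the toy parameter are assumptions.

import torch

@torch.no_grad()
def ema_step(shadow_params, params, decay=0.999):
    # Standard EMA: shadow <- decay * shadow + (1 - decay) * param.
    for s, p in zip(shadow_params, params):
        s.mul_(decay).add_(p, alpha=1.0 - decay)

weight = torch.nn.Parameter(torch.randn(3, 768))
shadow = [weight.detach().clone()]

weight.data.add_(0.01)      # pretend an optimizer step moved the weights
ema_step(shadow, [weight])  # shadow now lags slightly behind the live weights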
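
Note on the embedding-decay update in `on_after_optimize` (likewise only re-wrapped): `w` holds the stacked trainable token embeddings, and the `add_` call moves each vector along its own direction so that its L2 norm drifts toward `emb_decay_target`, with `clamp_min(1e-12)` guarding against division by zero. A minimal standalone sketch of that step, assuming illustrative values for `emb_decay_target`, `lambda_`, and a toy weight tensor:

import torch

emb_decay_target = 0.4   # assumed target L2 norm per embedding vector
lambda_ = 0.1            # assumed decay strength (scaled by the lr upstream)

w = torch.randn(3, 768)  # stand-in for the stacked trainable embeddings

if lambda_ != 0:
    norm = w.norm(dim=-1, keepdim=True)
    # Each row moves along its own unit vector, so its norm shifts by
    # lambda_ * (emb_decay_target - norm), i.e. toward the target.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(w.norm(dim=-1))  # norms pulled toward emb_decay_target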
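
Note on the `accelerator.prepare` change: it is purely cosmetic. `Accelerator.prepare` (from Hugging Face accelerate) returns its arguments, wrapped for the active device and mixed-precision setup, in the same order they were passed, so the parenthesized multi-line unpack is equivalent to the old one-liner. A small sketch with toy stand-ins for the real model, optimizer, loader, and scheduler:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# prepare() hands back the wrapped objects in input order, so the
# multi-line tuple unpack in the diff changes nothing semantically.
(
    model,
    optimizer,
    loader,
    scheduler,
) = accelerator.prepare(model, optimizer, loader, scheduler)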