diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
| commit | 99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch) | |
| tree | 717a4099e9ebfedec702060fed5ed12aaceb0094 /training/strategy | |
| parent | Added cycle LR decay (diff) | |
| download | textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.gz textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.bz2 textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.zip | |
TI via LoRA
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/lora.py | 4 | ||||
| -rw-r--r-- | training/strategy/ti.py | 9 |
2 files changed, 6 insertions, 7 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1517ee8..48236fb 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -93,7 +93,7 @@ def lora_strategy_callbacks( | |||
| 93 | if use_emb_decay: | 93 | if use_emb_decay: |
| 94 | params = [ | 94 | params = [ |
| 95 | p | 95 | p |
| 96 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 96 | for p in text_encoder.text_model.embeddings.parameters() |
| 97 | if p.grad is not None | 97 | if p.grad is not None |
| 98 | ] | 98 | ] |
| 99 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
| @@ -180,7 +180,7 @@ def lora_prepare( | |||
| 180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
| 181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
| 182 | 182 | ||
| 183 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) | 183 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
| 184 | 184 | ||
| 185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 186 | 186 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index ca7cc3d..49236c6 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( | |||
| 72 | 72 | ||
| 73 | if use_ema: | 73 | if use_ema: |
| 74 | ema_embeddings = EMAModel( | 74 | ema_embeddings = EMAModel( |
| 75 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 75 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 76 | inv_gamma=ema_inv_gamma, | 76 | inv_gamma=ema_inv_gamma, |
| 77 | power=ema_power, | 77 | power=ema_power, |
| 78 | max_value=ema_max_decay, | 78 | max_value=ema_max_decay, |
| @@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks( | |||
| 84 | def ema_context(): | 84 | def ema_context(): |
| 85 | if ema_embeddings is not None: | 85 | if ema_embeddings is not None: |
| 86 | return ema_embeddings.apply_temporary( | 86 | return ema_embeddings.apply_temporary( |
| 87 | text_encoder.text_model.embeddings.token_override_embedding.parameters() | 87 | text_encoder.text_model.embeddings.token_embedding.parameters() |
| 88 | ) | 88 | ) |
| 89 | else: | 89 | else: |
| 90 | return nullcontext() | 90 | return nullcontext() |
| @@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( | |||
| 108 | if use_emb_decay: | 108 | if use_emb_decay: |
| 109 | params = [ | 109 | params = [ |
| 110 | p | 110 | p |
| 111 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() |
| 112 | if p.grad is not None | 112 | if p.grad is not None |
| 113 | ] | 113 | ] |
| 114 | return torch.stack(params) if len(params) != 0 else None | 114 | return torch.stack(params) if len(params) != 0 else None |
| @@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( | |||
| 116 | @torch.no_grad() | 116 | @torch.no_grad() |
| 117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
| 118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
| 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
| 120 | 120 | ||
| 121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
| 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
| @@ -203,7 +203,6 @@ def textual_inversion_prepare( | |||
| 203 | text_encoder.text_model.encoder.requires_grad_(False) | 203 | text_encoder.text_model.encoder.requires_grad_(False) |
| 204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | 204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
| 205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
| 206 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | ||
| 207 | 206 | ||
| 208 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 207 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 209 | 208 | ||
