author    | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200
committer | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200
commit    | 99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch)
tree      | 717a4099e9ebfedec702060fed5ed12aaceb0094 /training/strategy/lora.py
parent    | Added cycle LR decay (diff)
TI via LoRA
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
+                for p in text_encoder.text_model.embeddings.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
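The first hunk widens the sweep for embedding decay: with textual inversion now implemented via LoRA, the callback collects every embedding parameter that received a gradient, rather than only those of the removed token_override_embedding. Below is a minimal sketch of how a decay step could consume these parameters; the toward-unit-norm rule and the emb_decay/lr scaling are assumptions for illustration, not code from this commit.

from typing import Optional

import torch
from torch import nn

def grad_embedding_params(embeddings: nn.Module) -> Optional[torch.Tensor]:
    # Same pattern as the callback above: keep only parameters that were
    # actually updated this step, or None if nothing trained.
    params = [p for p in embeddings.parameters() if p.grad is not None]
    return torch.stack(params) if len(params) != 0 else None

@torch.no_grad()
def apply_emb_decay(embeddings: nn.Module, emb_decay: float, lr: float) -> None:
    # Hypothetical decay rule: nudge each trained embedding vector's
    # norm back toward 1.0, scaled by the current learning rate.
    for p in embeddings.parameters():
        if p.grad is None:
            continue
        norm = p.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        p.add_((p / norm) * (1.0 - norm) * emb_decay * lr)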
@@ -180,7 +180,7 @@ def lora_prepare(
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 