diff options
| -rw-r--r-- | training/strategy/ti.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 33f5fb9..1b5adab 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -64,7 +64,7 @@ def textual_inversion_strategy_callbacks( | |||
| 64 | ) | 64 | ) |
| 65 | 65 | ||
| 66 | def on_accum_model(): | 66 | def on_accum_model(): |
| 67 | return text_encoder.text_model.embeddings.overlay | 67 | return text_encoder.text_model.embeddings |
| 68 | 68 | ||
| 69 | @contextmanager | 69 | @contextmanager |
| 70 | def on_train(epoch: int): | 70 | def on_train(epoch: int): |
| @@ -145,6 +145,8 @@ def textual_inversion_prepare( | |||
| 145 | 145 | ||
| 146 | text_encoder.text_model.encoder.requires_grad_(False) | 146 | text_encoder.text_model.encoder.requires_grad_(False) |
| 147 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | 147 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
| 148 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | ||
| 149 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | ||
| 148 | text_encoder.eval() | 150 | text_encoder.eval() |
| 149 | 151 | ||
| 150 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 152 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |
