diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/ti.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 33f5fb9..1b5adab 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -64,7 +64,7 @@ def textual_inversion_strategy_callbacks( | |||
64 | ) | 64 | ) |
65 | 65 | ||
66 | def on_accum_model(): | 66 | def on_accum_model(): |
67 | return text_encoder.text_model.embeddings.overlay | 67 | return text_encoder.text_model.embeddings |
68 | 68 | ||
69 | @contextmanager | 69 | @contextmanager |
70 | def on_train(epoch: int): | 70 | def on_train(epoch: int): |
@@ -145,6 +145,8 @@ def textual_inversion_prepare( | |||
145 | 145 | ||
146 | text_encoder.text_model.encoder.requires_grad_(False) | 146 | text_encoder.text_model.encoder.requires_grad_(False) |
147 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | 147 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
148 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | ||
149 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | ||
148 | text_encoder.eval() | 150 | text_encoder.eval() |
149 | 151 | ||
150 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 152 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |