diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-17 12:01:55 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-17 12:01:55 +0100 |
| commit | ca910abc5f50a559fa0769e1c21621464d25eaac (patch) | |
| tree | 85fea4b5243f2c41555c66dd0b7ee04237440aaf | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-ca910abc5f50a559fa0769e1c21621464d25eaac.tar.gz textual-inversion-diff-ca910abc5f50a559fa0769e1c21621464d25eaac.tar.bz2 textual-inversion-diff-ca910abc5f50a559fa0769e1c21621464d25eaac.zip | |
Fix
| -rw-r--r-- | training/strategy/ti.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 1af834b..ba78b98 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -119,10 +119,11 @@ def textual_inversion_strategy_callbacks( | |||
| 119 | 119 | ||
| 120 | def on_after_optimize(lr: float): | 120 | def on_after_optimize(lr: float): |
| 121 | if use_emb_decay: | 121 | if use_emb_decay: |
| 122 | text_encoder.text_model.embeddings.normalize( | 122 | with torch.no_grad(): |
| 123 | emb_decay_target, | 123 | text_encoder.text_model.embeddings.normalize( |
| 124 | min(1.0, emb_decay * lr) | 124 | emb_decay_target, |
| 125 | ) | 125 | min(1.0, emb_decay * lr) |
| 126 | ) | ||
| 126 | 127 | ||
| 127 | if ema_embeddings is not None: | 128 | if ema_embeddings is not None: |
| 128 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 129 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
