diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
commit | 21d70916f66e74a87c631a06b70774954b085b48 (patch) | |
tree | d1b443b9270f45ae6936f3acb565f767c7c65b1f /training/strategy/ti.py | |
parent | Run PTI only if placeholder tokens arg isn't empty (diff) | |
download | textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.gz textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.bz2 textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.zip |
Fix
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 55e9934..6a637c3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks( | |||
89 | else: | 89 | else: |
90 | return nullcontext() | 90 | return nullcontext() |
91 | 91 | ||
92 | def on_accum_model(): | ||
93 | return text_encoder.text_model.embeddings.token_override_embedding.params | ||
94 | |||
95 | @contextmanager | 92 | @contextmanager |
96 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
94 | text_encoder.text_model.embeddings.token_override_embedding.params.train() | ||
97 | tokenizer.train() | 95 | tokenizer.train() |
98 | yield | 96 | yield |
99 | 97 | ||
100 | @contextmanager | 98 | @contextmanager |
101 | def on_eval(): | 99 | def on_eval(): |
100 | text_encoder.text_model.embeddings.token_override_embedding.params.eval() | ||
102 | tokenizer.eval() | 101 | tokenizer.eval() |
103 | 102 | ||
104 | with ema_context(): | 103 | with ema_context(): |
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks( | |||
166 | torch.cuda.empty_cache() | 165 | torch.cuda.empty_cache() |
167 | 166 | ||
168 | return TrainingCallbacks( | 167 | return TrainingCallbacks( |
169 | on_accum_model=on_accum_model, | ||
170 | on_train=on_train, | 168 | on_train=on_train, |
171 | on_eval=on_eval, | 169 | on_eval=on_eval, |
172 | on_before_optimize=on_before_optimize, | 170 | on_before_optimize=on_before_optimize, |