path: root/training/strategy/ti.py
author     Volpeon <git@volpeon.ink>  2023-04-07 14:14:00 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-07 14:14:00 +0200
commit     21d70916f66e74a87c631a06b70774954b085b48 (patch)
tree       d1b443b9270f45ae6936f3acb565f767c7c65b1f /training/strategy/ti.py
parent     Run PTI only if placeholder tokens arg isn't empty (diff)
Fix
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py  6
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 55e9934..6a637c3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return text_encoder.text_model.embeddings.token_override_embedding.params
-
     @contextmanager
     def on_train(epoch: int):
+        text_encoder.text_model.embeddings.token_override_embedding.params.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
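In short, the commit drops the on_accum_model accessor and instead flips the override-embedding parameters between train and eval mode directly inside the on_train/on_eval context managers. Below is a minimal, self-contained sketch of that pattern; the nn.Embedding stands in for the repo's token_override_embedding.params, and nullcontext() is a placeholder for the ema_context() weight swap in the real file.

    from contextlib import contextmanager, nullcontext

    import torch.nn as nn

    # Stand-in for text_encoder.text_model.embeddings.token_override_embedding.params;
    # any nn.Module with train()/eval() behaves the same way.
    params = nn.Embedding(num_embeddings=8, embedding_dim=4)


    @contextmanager
    def on_train(epoch: int):
        # New behaviour: put the trainable embedding into train mode for the step.
        params.train()
        yield


    @contextmanager
    def on_eval():
        # Mirror image: eval mode for validation/sampling. nullcontext() stands in
        # for the ema_context() swap that wraps the yield in the real file.
        params.eval()
        with nullcontext():
            yield


    # The trainer enters these callbacks around each phase:
    with on_train(epoch=0):
        assert params.training
    with on_eval():
        assert not params.training

Toggling modes inside the context managers keeps the mode switch paired with the phase that needs it, so callers no longer have to fetch the module via a separate callback to manage its state.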