summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
1from typing import Optional 1from typing import Optional
2from types import MethodType
2from functools import partial 3from functools import partial
3from contextlib import contextmanager, nullcontext 4from contextlib import contextmanager, nullcontext
4from pathlib import Path 5from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
139 def on_checkpoint(step, postfix): 140 def on_checkpoint(step, postfix):
140 print(f"Saving checkpoint for step {step}...") 141 print(f"Saving checkpoint for step {step}...")
141 142
143 if postfix == "end":
144 text_encoder_ = accelerator.unwrap_model(
145 text_encoder, keep_fp32_wrapper=False
146 )
147 text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
148
142 with ema_context(): 149 with ema_context():
143 for token, ids in zip(placeholder_tokens, placeholder_token_ids): 150 for token, ids in zip(placeholder_tokens, placeholder_token_ids):
144 text_encoder.text_model.embeddings.save_embed( 151 text_encoder.text_model.embeddings.save_embed(