diff options
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 7 |
1 file changed, 7 insertions, 0 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,4 +1,5 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from types import MethodType | ||
2 | from functools import partial | 3 | from functools import partial |
3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
4 | from pathlib import Path | 5 | from pathlib import Path |
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( | |||
139 | def on_checkpoint(step, postfix): | 140 | def on_checkpoint(step, postfix): |
140 | print(f"Saving checkpoint for step {step}...") | 141 | print(f"Saving checkpoint for step {step}...") |
141 | 142 | ||
143 | if postfix == "end": | ||
144 | text_encoder_ = accelerator.unwrap_model( | ||
145 | text_encoder, keep_fp32_wrapper=False | ||
146 | ) | ||
147 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
148 | |||
142 | with ema_context(): | 149 | with ema_context(): |
143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): | 150 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
144 | text_encoder.text_model.embeddings.save_embed( | 151 | text_encoder.text_model.embeddings.save_embed( |