summary refs log tree commit diff stats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py  3
-rw-r--r--  training/strategy/ti.py  7
2 files changed, 10 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
156 156
157 text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
158 unet_.forward = MethodType(unet_.forward, unet_)
159
157 text_encoder_.text_model.embeddings.persist(False) 160 text_encoder_.text_model.embeddings.persist(False)
158 161
159 with ema_context(): 162 with ema_context():
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
1from typing import Optional 1from typing import Optional
2from types import MethodType
2from functools import partial 3from functools import partial
3from contextlib import contextmanager, nullcontext 4from contextlib import contextmanager, nullcontext
4from pathlib import Path 5from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
139 def on_checkpoint(step, postfix): 140 def on_checkpoint(step, postfix):
140 print(f"Saving checkpoint for step {step}...") 141 print(f"Saving checkpoint for step {step}...")
141 142
143 if postfix == "end":
144 text_encoder_ = accelerator.unwrap_model(
145 text_encoder, keep_fp32_wrapper=False
146 )
147 text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
148
142 with ema_context(): 149 with ema_context():
143 for token, ids in zip(placeholder_tokens, placeholder_token_ids): 150 for token, ids in zip(placeholder_tokens, placeholder_token_ids):
144 text_encoder.text_model.embeddings.save_embed( 151 text_encoder.text_model.embeddings.save_embed(