summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(
98 98
99 if cycle < train_text_encoder_cycles: 99 if cycle < train_text_encoder_cycles:
100 text_encoder.train() 100 text_encoder.train()
101 tokenizer.train()
102 101
103 yield 102 yield
104 103
@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
155 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
156 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
157 156
157 text_encoder_.text_model.embeddings.persist(False)
158
158 with ema_context(): 159 with ema_context():
159 pipeline = VlpnStableDiffusion( 160 pipeline = VlpnStableDiffusion(
160 text_encoder=text_encoder_, 161 text_encoder=text_encoder_,