summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
committerVolpeon <git@volpeon.ink>2023-06-24 21:00:29 +0200
commit12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 (patch)
treeb0fcf8ad1d26c40d784ddc154622f6d01ecac082 /training/strategy
parentNew loss scaling (diff)
downloadtextual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.gz
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.tar.bz2
textual-inversion-diff-12b9aca96a36dd77a6b2b99bbc1743d87a7ce733.zip
Update
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(
98 98
99 if cycle < train_text_encoder_cycles: 99 if cycle < train_text_encoder_cycles:
100 text_encoder.train() 100 text_encoder.train()
101 tokenizer.train()
102 101
103 yield 102 yield
104 103
@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
155 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
156 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
157 156
157 text_encoder_.text_model.embeddings.persist(False)
158
158 with ema_context(): 159 with ema_context():
159 pipeline = VlpnStableDiffusion( 160 pipeline = VlpnStableDiffusion(
160 text_encoder=text_encoder_, 161 text_encoder=text_encoder_,