diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- | training/strategy/dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 8aaed3a..d697554 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks( | |||
144 | 144 | ||
145 | print("Saving model...") | 145 | print("Saving model...") |
146 | 146 | ||
147 | unet_ = accelerator.unwrap_model(unet) | 147 | unet_ = accelerator.unwrap_model(unet, False) |
148 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
149 | 149 | ||
150 | with ema_context(): | 150 | with ema_context(): |
151 | pipeline = VlpnStableDiffusion( | 151 | pipeline = VlpnStableDiffusion( |
@@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks( | |||
167 | @torch.no_grad() | 167 | @torch.no_grad() |
168 | def on_sample(step): | 168 | def on_sample(step): |
169 | with ema_context(): | 169 | with ema_context(): |
170 | unet_ = accelerator.unwrap_model(unet) | 170 | unet_ = accelerator.unwrap_model(unet, False) |
171 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 171 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
172 | 172 | ||
173 | orig_unet_dtype = unet_.dtype | 173 | orig_unet_dtype = unet_.dtype |
174 | orig_text_encoder_dtype = text_encoder_.dtype | 174 | orig_text_encoder_dtype = text_encoder_.dtype |