diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e5e84c8..28fccff 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -137,8 +137,8 @@ def dreambooth_strategy_callbacks( | |||
| 137 | 137 | ||
| 138 | print("Saving model...") | 138 | print("Saving model...") |
| 139 | 139 | ||
| 140 | unet_ = accelerator.unwrap_model(unet, False) | 140 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 141 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 141 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 142 | 142 | ||
| 143 | with ema_context(): | 143 | with ema_context(): |
| 144 | pipeline = VlpnStableDiffusion( | 144 | pipeline = VlpnStableDiffusion( |
| @@ -160,8 +160,8 @@ def dreambooth_strategy_callbacks( | |||
| 160 | @torch.no_grad() | 160 | @torch.no_grad() |
| 161 | def on_sample(step): | 161 | def on_sample(step): |
| 162 | with ema_context(): | 162 | with ema_context(): |
| 163 | unet_ = accelerator.unwrap_model(unet, False) | 163 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 164 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 164 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 165 | 165 | ||
| 166 | orig_unet_dtype = unet_.dtype | 166 | orig_unet_dtype = unet_.dtype |
| 167 | orig_text_encoder_dtype = text_encoder_.dtype | 167 | orig_text_encoder_dtype = text_encoder_.dtype |
