diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
commit | 71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch) | |
tree | 29c704ca549a4c4323403b6cbb0e62f54040ae22 /training/strategy/dreambooth.py | |
parent | Added option to use constant LR on cycles > 1 (diff) | |
download | textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.gz textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.bz2 textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.zip |
Improved automation caps
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- | training/strategy/dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( | |||
148 | torch.cuda.empty_cache() | 148 | torch.cuda.empty_cache() |
149 | 149 | ||
150 | @torch.no_grad() | 150 | @torch.no_grad() |
151 | def on_sample(step): | 151 | def on_sample(cycle, step): |
152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
154 | 154 | ||
@@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( | |||
158 | unet_.to(dtype=weight_dtype) | 158 | unet_.to(dtype=weight_dtype) |
159 | text_encoder_.to(dtype=weight_dtype) | 159 | text_encoder_.to(dtype=weight_dtype) |
160 | 160 | ||
161 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 161 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
162 | 162 | ||
163 | unet_.to(dtype=orig_unet_dtype) | 163 | unet_.to(dtype=orig_unet_dtype) |
164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) |