diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
commit | 71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch) | |
tree | 29c704ca549a4c4323403b6cbb0e62f54040ae22 /training/strategy/lora.py | |
parent | Added option to use constant LR on cycles > 1 (diff) | |
download | textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.gz textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.bz2 textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.zip |
Improved automation caps
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() |
147 | 147 | ||
148 | @torch.no_grad() | 148 | @torch.no_grad() |
149 | def on_sample(step): | 149 | def on_sample(cycle, step): |
150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
152 | 152 | ||
153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
154 | 154 | ||
155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ |
156 | 156 | ||