diff options
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() |
147 | 147 | ||
148 | @torch.no_grad() | 148 | @torch.no_grad() |
149 | def on_sample(step): | 149 | def on_sample(cycle, step): |
150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
152 | 152 | ||
153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
154 | 154 | ||
155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ |
156 | 156 | ||