summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py2
-rw-r--r--training/strategy/ti.py2
3 files changed, 3 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 7cdfc7f..fa51bc7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -149,7 +149,7 @@ def dreambooth_strategy_callbacks(
149 if torch.cuda.is_available(): 149 if torch.cuda.is_available():
150 torch.cuda.empty_cache() 150 torch.cuda.empty_cache()
151 151
152 @on_eval() 152 @torch.no_grad()
153 def on_sample(step): 153 def on_sample(step):
154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0f72a17..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -146,7 +146,7 @@ def lora_strategy_callbacks(
146 if torch.cuda.is_available(): 146 if torch.cuda.is_available():
147 torch.cuda.empty_cache() 147 torch.cuda.empty_cache()
148 148
149 @on_eval() 149 @torch.no_grad()
150 def on_sample(step): 150 def on_sample(step):
151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f00045f..08af89d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -142,7 +142,7 @@ def textual_inversion_strategy_callbacks(
142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
143 ) 143 )
144 144
145 @on_eval() 145 @torch.no_grad()
146 def on_sample(step): 146 def on_sample(step):
147 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 147 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
148 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 148 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)