summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/functional.py3
-rw-r--r--training/strategy/dreambooth.py27
-rw-r--r--training/strategy/lora.py2
-rw-r--r--training/strategy/ti.py25
4 files changed, 29 insertions, 28 deletions
diff --git a/training/functional.py b/training/functional.py
index 46d25f6..ff6d3a9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -695,5 +695,8 @@ def train(
695 callbacks=callbacks, 695 callbacks=callbacks,
696 ) 696 )
697 697
698 accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
699 accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
700
698 accelerator.end_training() 701 accelerator.end_training()
699 accelerator.free_memory() 702 accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 42624cd..7cdfc7f 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -113,7 +113,7 @@ def dreambooth_strategy_callbacks(
113 accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) 113 accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
114 114
115 @torch.no_grad() 115 @torch.no_grad()
116 def on_after_optimize(_, lr: float): 116 def on_after_optimize(_, lrs: dict[str, float]):
117 if ema_unet is not None: 117 if ema_unet is not None:
118 ema_unet.step(unet.parameters()) 118 ema_unet.step(unet.parameters())
119 119
@@ -149,25 +149,24 @@ def dreambooth_strategy_callbacks(
149 if torch.cuda.is_available(): 149 if torch.cuda.is_available():
150 torch.cuda.empty_cache() 150 torch.cuda.empty_cache()
151 151
152 @torch.no_grad() 152 @on_eval()
153 def on_sample(step): 153 def on_sample(step):
154 with ema_context(): 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
155 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
156 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
157 156
158 orig_unet_dtype = unet_.dtype 157 orig_unet_dtype = unet_.dtype
159 orig_text_encoder_dtype = text_encoder_.dtype 158 orig_text_encoder_dtype = text_encoder_.dtype
160 159
161 unet_.to(dtype=weight_dtype) 160 unet_.to(dtype=weight_dtype)
162 text_encoder_.to(dtype=weight_dtype) 161 text_encoder_.to(dtype=weight_dtype)
163 162
164 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 163 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
165 164
166 unet_.to(dtype=orig_unet_dtype) 165 unet_.to(dtype=orig_unet_dtype)
167 text_encoder_.to(dtype=orig_text_encoder_dtype) 166 text_encoder_.to(dtype=orig_text_encoder_dtype)
168 167
169 del unet_ 168 del unet_
170 del text_encoder_ 169 del text_encoder_
171 170
172 if torch.cuda.is_available(): 171 if torch.cuda.is_available():
173 torch.cuda.empty_cache() 172 torch.cuda.empty_cache()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 73ec8f2..0f72a17 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -146,7 +146,7 @@ def lora_strategy_callbacks(
146 if torch.cuda.is_available(): 146 if torch.cuda.is_available():
147 torch.cuda.empty_cache() 147 torch.cuda.empty_cache()
148 148
149 @torch.no_grad() 149 @on_eval()
150 def on_sample(step): 150 def on_sample(step):
151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 363c3f9..f00045f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -142,25 +142,24 @@ def textual_inversion_strategy_callbacks(
142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
143 ) 143 )
144 144
145 @torch.no_grad() 145 @on_eval()
146 def on_sample(step): 146 def on_sample(step):
147 with ema_context(): 147 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
148 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 148 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
149 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
150 149
151 orig_unet_dtype = unet_.dtype 150 orig_unet_dtype = unet_.dtype
152 orig_text_encoder_dtype = text_encoder_.dtype 151 orig_text_encoder_dtype = text_encoder_.dtype
153 152
154 unet_.to(dtype=weight_dtype) 153 unet_.to(dtype=weight_dtype)
155 text_encoder_.to(dtype=weight_dtype) 154 text_encoder_.to(dtype=weight_dtype)
156 155
157 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 156 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
158 157
159 unet_.to(dtype=orig_unet_dtype) 158 unet_.to(dtype=orig_unet_dtype)
160 text_encoder_.to(dtype=orig_text_encoder_dtype) 159 text_encoder_.to(dtype=orig_text_encoder_dtype)
161 160
162 del unet_ 161 del unet_
163 del text_encoder_ 162 del text_encoder_
164 163
165 if torch.cuda.is_available(): 164 if torch.cuda.is_available():
166 torch.cuda.empty_cache() 165 torch.cuda.empty_cache()