summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py35
1 files changed, 27 insertions, 8 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e88bf90..b4c77f3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks(
61 save_samples_ = partial( 61 save_samples_ = partial(
62 save_samples, 62 save_samples,
63 accelerator=accelerator, 63 accelerator=accelerator,
64 unet=unet,
65 text_encoder=text_encoder,
66 tokenizer=tokenizer, 64 tokenizer=tokenizer,
67 vae=vae, 65 vae=vae,
68 sample_scheduler=sample_scheduler, 66 sample_scheduler=sample_scheduler,
69 train_dataloader=train_dataloader, 67 train_dataloader=train_dataloader,
70 val_dataloader=val_dataloader, 68 val_dataloader=val_dataloader,
71 dtype=weight_dtype,
72 output_dir=sample_output_dir, 69 output_dir=sample_output_dir,
73 seed=seed, 70 seed=seed,
74 batch_size=sample_batch_size, 71 batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks(
94 else: 91 else:
95 return nullcontext() 92 return nullcontext()
96 93
97 def on_model(): 94 def on_accum_model():
98 return unet 95 return unet
99 96
100 def on_prepare(): 97 def on_prepare():
@@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks(
172 @torch.no_grad() 169 @torch.no_grad()
173 def on_sample(step): 170 def on_sample(step):
174 with ema_context(): 171 with ema_context():
175 save_samples_(step=step) 172 unet_ = accelerator.unwrap_model(unet)
173 text_encoder_ = accelerator.unwrap_model(text_encoder)
174
175 orig_unet_dtype = unet_.dtype
176 orig_text_encoder_dtype = text_encoder_.dtype
177
178 unet_.to(dtype=weight_dtype)
179 text_encoder_.to(dtype=weight_dtype)
180
181 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
182
183 unet_.to(dtype=orig_unet_dtype)
184 text_encoder_.to(dtype=orig_text_encoder_dtype)
185
186 del unet_
187 del text_encoder_
188
189 if torch.cuda.is_available():
190 torch.cuda.empty_cache()
176 191
177 return TrainingCallbacks( 192 return TrainingCallbacks(
178 on_prepare=on_prepare, 193 on_prepare=on_prepare,
179 on_model=on_model, 194 on_accum_model=on_accum_model,
180 on_train=on_train, 195 on_train=on_train,
181 on_eval=on_eval, 196 on_eval=on_eval,
182 on_before_optimize=on_before_optimize, 197 on_before_optimize=on_before_optimize,
@@ -191,9 +206,13 @@ def dreambooth_prepare(
191 accelerator: Accelerator, 206 accelerator: Accelerator,
192 text_encoder: CLIPTextModel, 207 text_encoder: CLIPTextModel,
193 unet: UNet2DConditionModel, 208 unet: UNet2DConditionModel,
194 *args 209 optimizer: torch.optim.Optimizer,
210 train_dataloader: DataLoader,
211 val_dataloader: Optional[DataLoader],
212 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
213 **kwargs
195): 214):
196 return accelerator.prepare(text_encoder, unet, *args) 215 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})
197 216
198 217
199dreambooth_strategy = TrainingStrategy( 218dreambooth_strategy = TrainingStrategy(