summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py4
-rw-r--r--training/strategy/ti.py5
2 files changed, 3 insertions, 6 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 1277939..e88bf90 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -193,9 +193,7 @@ def dreambooth_prepare(
193 unet: UNet2DConditionModel, 193 unet: UNet2DConditionModel,
194 *args 194 *args
195): 195):
196 prep = [text_encoder, unet] + list(args) 196 return accelerator.prepare(text_encoder, unet, *args)
197 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
198 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199 197
200 198
201dreambooth_strategy = TrainingStrategy( 199dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a76f98..14bdafd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -176,10 +176,9 @@ def textual_inversion_prepare(
176 elif accelerator.state.mixed_precision == "bf16": 176 elif accelerator.state.mixed_precision == "bf16":
177 weight_dtype = torch.bfloat16 177 weight_dtype = torch.bfloat16
178 178
179 prep = [text_encoder] + list(args) 179 prepped = accelerator.prepare(text_encoder, *args)
180 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
181 unet.to(accelerator.device, dtype=weight_dtype) 180 unet.to(accelerator.device, dtype=weight_dtype)
182 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 181 return (prepped[0], unet) + prepped[1:]
183 182
184 183
185textual_inversion_strategy = TrainingStrategy( 184textual_inversion_strategy = TrainingStrategy(