diff options
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a76f98..14bdafd 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -176,10 +176,9 @@ def textual_inversion_prepare( | |||
176 | elif accelerator.state.mixed_precision == "bf16": | 176 | elif accelerator.state.mixed_precision == "bf16": |
177 | weight_dtype = torch.bfloat16 | 177 | weight_dtype = torch.bfloat16 |
178 | 178 | ||
179 | prep = [text_encoder] + list(args) | 179 | prepped = accelerator.prepare(text_encoder, *args) |
180 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
181 | unet.to(accelerator.device, dtype=weight_dtype) | 180 | unet.to(accelerator.device, dtype=weight_dtype) |
182 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 181 | return (prepped[0], unet) + prepped[1:] |
183 | 182 | ||
184 | 183 | ||
185 | textual_inversion_strategy = TrainingStrategy( | 184 | textual_inversion_strategy = TrainingStrategy( |