diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 1277939..e88bf90 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -193,9 +193,7 @@ def dreambooth_prepare( | |||
| 193 | unet: UNet2DConditionModel, | 193 | unet: UNet2DConditionModel, |
| 194 | *args | 194 | *args |
| 195 | ): | 195 | ): |
| 196 | prep = [text_encoder, unet] + list(args) | 196 | return accelerator.prepare(text_encoder, unet, *args) |
| 197 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
| 198 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 199 | 197 | ||
| 200 | 198 | ||
| 201 | dreambooth_strategy = TrainingStrategy( | 199 | dreambooth_strategy = TrainingStrategy( |
