diff options
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index f57e736..1277939 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -6,6 +6,7 @@ from pathlib import Path | |||
| 6 | import itertools | 6 | import itertools |
| 7 | 7 | ||
| 8 | import torch | 8 | import torch |
| 9 | import torch.nn as nn | ||
| 9 | from torch.utils.data import DataLoader | 10 | from torch.utils.data import DataLoader |
| 10 | 11 | ||
| 11 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
| @@ -186,7 +187,18 @@ def dreambooth_strategy_callbacks( | |||
| 186 | ) | 187 | ) |
| 187 | 188 | ||
| 188 | 189 | ||
| 190 | def dreambooth_prepare( | ||
| 191 | accelerator: Accelerator, | ||
| 192 | text_encoder: CLIPTextModel, | ||
| 193 | unet: UNet2DConditionModel, | ||
| 194 | *args | ||
| 195 | ): | ||
| 196 | prep = [text_encoder, unet] + list(args) | ||
| 197 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
| 198 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 199 | |||
| 200 | |||
| 189 | dreambooth_strategy = TrainingStrategy( | 201 | dreambooth_strategy = TrainingStrategy( |
| 190 | callbacks=dreambooth_strategy_callbacks, | 202 | callbacks=dreambooth_strategy_callbacks, |
| 191 | prepare_unet=True | 203 | prepare=dreambooth_prepare |
| 192 | ) | 204 | ) |
