| Field | Value | Date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
| commit | 3575d041f1507811b577fd2c653171fb51c0a386 | |
| tree | 702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /training/strategy | |
| parent | Move Accelerator preparation into strategy | |
Restored LR finder
Diffstat (limited to 'training/strategy')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | training/strategy/dreambooth.py | 4 |
| -rw-r--r-- | training/strategy/ti.py | 5 |

2 files changed, 3 insertions, 6 deletions
```diff
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 1277939..e88bf90 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -193,9 +193,7 @@ def dreambooth_prepare(
     unet: UNet2DConditionModel,
     *args
 ):
-    prep = [text_encoder, unet] + list(args)
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    return accelerator.prepare(text_encoder, unet, *args)
 
 
 dreambooth_strategy = TrainingStrategy(
```
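The one-liner works because `Accelerator.prepare` from Hugging Face Accelerate returns the prepared objects in the same order they were passed in, so forwarding its return value is equivalent to the removed unpack-and-return. A standalone sketch with toy objects (not code from this repository) to illustrate the property:

```python
# Toy setup, for illustration only: prepare() hands back the wrapped objects
# in the same order they were passed, which is what lets dreambooth_prepare()
# collapse its body into `return accelerator.prepare(...)`.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)

# Same order in, same order out.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```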
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a76f98..14bdafd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -176,10 +176,9 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    prep = [text_encoder] + list(args)
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    prepped = accelerator.prepare(text_encoder, *args)
     unet.to(accelerator.device, dtype=weight_dtype)
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    return (prepped[0], unet) + prepped[1:]
 
 
 textual_inversion_strategy = TrainingStrategy(
```
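In `textual_inversion_prepare`, the UNet is not passed to `prepare()`; it is only moved to the accelerator device in `weight_dtype`, and the new return statement splices it back into second position so callers still receive `(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)` in the same order as before. A toy illustration of that tuple splice, using stand-in strings instead of real training objects:

```python
# Stand-in strings instead of real models/optimizers: prepare() returns a
# tuple, the unet (which bypasses prepare) is re-inserted right after the
# text encoder, and everything else keeps its original order.
prepped = ("text_encoder", "optimizer", "train_dataloader",
           "val_dataloader", "lr_scheduler")
unet = "unet"

result = (prepped[0], unet) + prepped[1:]
assert result == ("text_encoder", "unet", "optimizer",
                  "train_dataloader", "val_dataloader", "lr_scheduler")
```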
