summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-20 14:26:17 +0100
committerVolpeon <git@volpeon.ink>2023-01-20 14:26:17 +0100
commit3575d041f1507811b577fd2c653171fb51c0a386 (patch)
tree702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /training/strategy/dreambooth.py
parentMove Accelerator preparation into strategy (diff)
downloadtextual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.gz
textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.bz2
textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.zip
Restored LR finder
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 1277939..e88bf90 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -193,9 +193,7 @@ def dreambooth_prepare(
193 unet: UNet2DConditionModel, 193 unet: UNet2DConditionModel,
194 *args 194 *args
195): 195):
196 prep = [text_encoder, unet] + list(args) 196 return accelerator.prepare(text_encoder, unet, *args)
197 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
198 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199 197
200 198
201dreambooth_strategy = TrainingStrategy( 199dreambooth_strategy = TrainingStrategy(