summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-17 07:20:45 +0100
committerVolpeon <git@volpeon.ink>2023-01-17 07:20:45 +0100
commit5821523a524190490a287c5e2aacb6e72cc3a4cf (patch)
treec0eac536c754f078683be6d59893ad23d70baf51 /training/strategy/dreambooth.py
parentTraining update (diff)
downloadtextual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.gz
textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.bz2
textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.zip
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--training/strategy/dreambooth.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import TrainingCallbacks, save_samples 18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
19 19
20 20
21def dreambooth_strategy( 21def dreambooth_strategy_callbacks(
22 accelerator: Accelerator, 22 accelerator: Accelerator,
23 unet: UNet2DConditionModel, 23 unet: UNet2DConditionModel,
24 text_encoder: CLIPTextModel, 24 text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
185 on_checkpoint=on_checkpoint, 185 on_checkpoint=on_checkpoint,
186 on_sample=on_sample, 186 on_sample=on_sample,
187 ) 187 )
188
189
190dreambooth_strategy = TrainingStrategy(
191 callbacks=dreambooth_strategy_callbacks,
192 prepare_unet=True
193)